Yan Bai committed on
Commit
d2006b6
·
1 Parent(s): 0ffae8a
Dockerfile CHANGED
@@ -8,7 +8,7 @@ RUN pip install --no-cache-dir \
  termcolor \
  ipdb
 # Add Megatron-LM core_v0.12.2
- RUN git clone -b core_v0.13.0rc4 --depth 1 https://github.com/NVIDIA/Megatron-LM.git /opt/Megatron-LM
+ RUN git clone -b core_v0.12.2 --depth 1 https://github.com/NVIDIA/Megatron-LM.git /opt/Megatron-LM

 RUN git clone -b estimator_mcore013 --depth 1 https://github.com/ISEEKYAN/mbridge.git /opt/mbridge

@@ -18,10 +18,7 @@ RUN groupadd -g 1000 user && \

 # Copy the code into the working directory
 WORKDIR $HOME/app
- RUN mkdir -p $HOME/app && mv /opt/mbridge/memory_estimator/* $HOME/app && chown -R user:user $HOME/app
-
- RUN echo " " > $HOME/app/__init__.py
- RUN echo "from webui.main import app" > $HOME/app/app.py
+ COPY --chown=user . $HOME/app

 # HF Spaces injects the serving port via $PORT by default
 ENV PYTHONPATH=/opt/Megatron-LM:$PYTHONPATH
@@ -29,4 +26,4 @@ ENV PORT=7860
 EXPOSE 7860

 # Start the FastAPI service
- CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port $PORT"]
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port $PORT"]
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1 @@
1
+ from webui.main import app
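Note: `app.py` exists only so the Dockerfile's `uvicorn app:app` command finds a module named `app` that exposes a FastAPI instance; the real application lives in `webui/main.py`, which is not part of this commit. A minimal sketch of what such a module could look like (the title and route below are illustrative assumptions, not the actual webui code):

```python
# Hypothetical stand-in for webui/main.py -- illustrative only.
from fastapi import FastAPI

app = FastAPI(title="MoE memory estimator")

@app.get("/health")
def health() -> dict:
    # Simple liveness endpoint; HF Spaces routes traffic to the injected $PORT.
    return {"status": "ok"}
```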
estimate_013.py ADDED
@@ -0,0 +1,505 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ """Pretrain GPT."""
3
+ import warnings
4
+
5
+ warnings.filterwarnings("ignore", category=DeprecationWarning)
6
+ warnings.filterwarnings("ignore", category=FutureWarning)
7
+ warnings.filterwarnings("ignore")
8
+ import inspect
9
+ import os
10
+ from contextlib import nullcontext
11
+ from functools import partial
12
+ from typing import List, Optional, Tuple, Union
13
+
14
+ import torch
15
+ from megatron.core import mpu
16
+ from megatron.core.datasets.blended_megatron_dataset_builder import (
17
+ BlendedMegatronDatasetBuilder,
18
+ )
19
+ from megatron.core.datasets.gpt_dataset import (
20
+ GPTDataset,
21
+ GPTDatasetConfig,
22
+ MockGPTDataset,
23
+ )
24
+ from megatron.core.datasets.utils import get_blend_from_list
25
+ from megatron.core.enums import ModelType
26
+ from megatron.core.models.gpt.gpt_layer_specs import (
27
+ get_gpt_decoder_block_spec,
28
+ get_gpt_layer_local_spec,
29
+ get_gpt_layer_with_transformer_engine_spec,
30
+ get_gpt_mtp_block_spec,
31
+ )
32
+ from megatron.core.transformer.spec_utils import import_module
33
+ from megatron.core.utils import StragglerDetector
34
+ from megatron.training import (
35
+ get_args,
36
+ get_timers,
37
+ get_tokenizer,
38
+ pretrain,
39
+ print_rank_0,
40
+ )
41
+ from megatron.training.arguments import core_transformer_config_from_args
42
+ from megatron.training.initialize import initialize_megatron
43
+ from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
44
+ from megatron.training.yaml_arguments import core_transformer_config_from_yaml
45
+ from moe_mem_estimator.base import (
46
+ get_pipeline_model_parallel_rank,
47
+ get_pipeline_model_parallel_world_size,
48
+ get_virtual_pipeline_model_parallel_world_size,
49
+ is_pipeline_first_stage,
50
+ is_pipeline_last_stage,
51
+ set_global_config,
52
+ set_pipeline_model_parallel_rank,
53
+ )
54
+ from moe_mem_estimator.gpt_model import GPTModel
55
+ from moe_mem_estimator.layers import MLASelfAttention, MoELayer
56
+
57
+ torch.distributed.get_rank = lambda: 0
58
+ torch.cuda.get_device_capability = lambda: [8]
59
+
60
+ def estimate_from_config(config, args):
61
+ """
62
+ Estimate memory usage from a given config and args, instead of global state.
63
+ Now supports virtual pipeline model parallelism for more accurate results.
64
+ """
65
+
66
+ args.moe_grouped_gemm = True
67
+ patch_parallel_states()
68
+ if config is None:
69
+ if args.yaml_cfg is not None:
70
+ config = core_transformer_config_from_yaml(args, "language_model")
71
+ else:
72
+ config = core_transformer_config_from_args(args)
73
+
74
+ input_shape = [args.micro_batch_size, args.seq_length]
75
+
76
+ set_global_config(config)
77
+ print(config)
78
+ # return
79
+ cli_reports = []
80
+
81
+ if config.pipeline_model_parallel_size > 1:
82
+ for pp_rank in range(config.pipeline_model_parallel_size):
83
+ set_pipeline_model_parallel_rank(pp_rank)
84
+ print(
85
+ f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
86
+ )
87
+ input_shape, rpt = report_memory_usage_one_pp_rank(
88
+ input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
89
+ )
90
+ cli_reports.append(rpt)
91
+ else:
92
+ set_pipeline_model_parallel_rank(0)
93
+ _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
94
+ cli_reports.append(rpt)
95
+
96
+ aggregated_reports: list[dict] = cli_reports
97
+
98
+ # Return (aggregated per-PP-rank report list, full raw chunk list)
99
+ return aggregated_reports, cli_reports
100
+
101
+
102
+ def _get_transformer_layer_spec(use_te, config):
103
+ """Get transformer layer specification based on configuration.
104
+
105
+ Args:
106
+ use_te (bool): Whether to use Transformer Engine
107
+ args: Training arguments
108
+ config: Model configuration
109
+
110
+ Returns:
111
+ transformer_layer_spec: The transformer layer specification
112
+ """
113
+ if use_te:
114
+ return get_gpt_layer_with_transformer_engine_spec(
115
+ config.num_moe_experts,
116
+ config.moe_grouped_gemm,
117
+ config.qk_layernorm,
118
+ config.multi_latent_attention,
119
+ config.fp8,
120
+ )
121
+ else:
122
+ return get_gpt_layer_local_spec(
123
+ config.num_moe_experts,
124
+ config.moe_grouped_gemm,
125
+ config.qk_layernorm,
126
+ config.multi_latent_attention,
127
+ )
128
+
129
+
130
+ def model_provider(
131
+ args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None
132
+ ) -> GPTModel:
133
+ use_te = True
134
+ if args.num_experts:
135
+ # Define the decoder block spec
136
+ transformer_layer_spec = get_gpt_decoder_block_spec(
137
+ config,
138
+ use_transformer_engine=use_te,
139
+ normalization="LayerNorm",
140
+ qk_l2_norm=False,
141
+ vp_stage=vp_stage,
142
+ )
143
+ else:
144
+ # Define the decoder layer spec
145
+ transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
146
+ mtp_block_spec = None
147
+ # TODO fp8
148
+ model = GPTModel(
149
+ config=config,
150
+ transformer_layer_spec=transformer_layer_spec,
151
+ vocab_size=args.padded_vocab_size,
152
+ max_sequence_length=args.max_position_embeddings,
153
+ pre_process=pre_process,
154
+ post_process=post_process,
155
+ fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False),
156
+ parallel_output=True,
157
+ share_embeddings_and_output_weights=False,
158
+ position_embedding_type="rope",
159
+ rotary_percent=getattr(args, "rotary_percent", 1.0),
160
+ rotary_base=getattr(args, "rotary_base", 10000),
161
+ rope_scaling=getattr(config, "use_rope_scaling", False),
162
+ mtp_block_spec=mtp_block_spec,
163
+ vp_stage=vp_stage,
164
+ )
165
+
166
+ return model
167
+
168
+
169
+ def get_model(
170
+ model_provider_func, args, config, model_type=ModelType.encoder_or_decoder
171
+ ):
172
+ """Build the model."""
173
+ # args = get_args()
174
+ # args.model_type = model_type
175
+
176
+ # Build model.
177
+ if not getattr(args, "virtual_pipeline_model_parallel_size", None):
178
+ args.virtual_pipeline_model_parallel_size = None
179
+ if config.pipeline_model_parallel_layout:
180
+ args.virtual_pipeline_model_parallel_size = (
181
+ config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
182
+ )
183
+ config.virtual_pipeline_model_parallel_size = (
184
+ config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
185
+ )
186
+
187
+ def build_model():
188
+ if (
189
+ get_pipeline_model_parallel_world_size() > 1
190
+ and args.virtual_pipeline_model_parallel_size is not None
191
+ ):
192
+ if model_type == ModelType.encoder_and_decoder:
193
+ assert (
194
+ config.encoder_pipeline_model_parallel_size == 0
195
+ ), "Interleaved schedule not supported for model with encoder on separate PP rank"
196
+ model = []
197
+ for i in range(args.virtual_pipeline_model_parallel_size):
198
+ # Set pre_process and post_process only after virtual rank is set.
199
+ pre_process = is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
200
+ post_process = is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
201
+
202
+ this_model = model_provider_func(
203
+ args,
204
+ config,
205
+ pre_process=pre_process,
206
+ post_process=post_process,
207
+ vp_stage=i,
208
+ )
209
+ this_model.model_type = model_type
210
+ this_model.vp_stage = i
211
+ model.append(this_model)
212
+ else:
213
+ pre_process = is_pipeline_first_stage()
214
+ post_process = is_pipeline_last_stage()
215
+ if model_type == ModelType.encoder_and_decoder:
216
+ if get_pipeline_model_parallel_world_size() > 1:
217
+ rank = get_pipeline_model_parallel_rank()
218
+ first_decoder_rank = config.encoder_pipeline_model_parallel_size
219
+ world_size = get_pipeline_model_parallel_world_size()
220
+ pre_process = rank == 0 or rank == first_decoder_rank
221
+ post_process = (rank == (first_decoder_rank - 1)) or (
222
+ rank == (world_size - 1)
223
+ )
224
+ model = model_provider_func(
225
+ args,
226
+ config,
227
+ pre_process=pre_process,
228
+ post_process=post_process,
229
+ )
230
+ else:
231
+ model = model_provider_func(
232
+ args, config, pre_process=pre_process, post_process=post_process
233
+ )
234
+ model.model_type = model_type
235
+ return model
236
+
237
+ model = build_model()
238
+
239
+ if not isinstance(model, list):
240
+ model = [model]
241
+ return model
242
+
243
+
244
+ NUM_BYTES_IN_MEGABYTE = 1024 * 1024
245
+ NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
246
+
247
+
248
+ def patch_parallel_states():
249
+ from megatron.core import parallel_state
250
+
251
+ parallel_state.is_pipeline_first_stage = is_pipeline_first_stage
252
+ parallel_state.is_pipeline_last_stage = is_pipeline_last_stage
253
+ parallel_state.get_pipeline_model_parallel_rank = get_pipeline_model_parallel_rank
254
+ parallel_state.get_pipeline_model_parallel_world_size = (
255
+ get_pipeline_model_parallel_world_size
256
+ )
257
+ parallel_state.get_virtual_pipeline_model_parallel_world_size = (
258
+ get_virtual_pipeline_model_parallel_world_size
259
+ )
260
+ parallel_state.is_inside_encoder = lambda: False
261
+ parallel_state.get_pipeline_model_parallel_decoder_start = lambda: 0
262
+
263
+
264
+ def report_memory_usage(args, config=None):
265
+ args.moe_grouped_gemm = True
266
+ patch_parallel_states()
267
+ if config is None:
268
+ if args.yaml_cfg is not None:
269
+ config = core_transformer_config_from_yaml(args, "language_model")
270
+ else:
271
+ config = core_transformer_config_from_args(args)
272
+
273
+ input_shape = [args.micro_batch_size, args.seq_length]
274
+
275
+ set_global_config(config)
276
+
277
+ cli_reports = []
278
+
279
+ if config.pipeline_model_parallel_size > 1:
280
+ for pp_rank in range(config.pipeline_model_parallel_size):
281
+ set_pipeline_model_parallel_rank(pp_rank)
282
+ print(
283
+ f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
284
+ )
285
+ input_shape, rpt = report_memory_usage_one_pp_rank(
286
+ input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
287
+ )
288
+ cli_reports.append(rpt)
289
+ else:
290
+ set_pipeline_model_parallel_rank(0)
291
+ _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
292
+ cli_reports.append(rpt)
293
+
294
+ # Optionally pretty print summary
295
+ print("\n===== Summary (per PP rank) =====")
296
+ for r in cli_reports:
297
+ print(
298
+ f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)"
299
+ )
300
+
301
+
302
+ def report_memory_usage_one_pp_rank(
303
+ input_shape: list[int], args, config, pp_rank=0, pp_size=1
304
+ ) -> tuple[list[int], dict]:
305
+ print(f"{input_shape=}")
306
+ model: list[GPTModel] = get_model(model_provider, args, config)
307
+ num_parameter_this_shard_all = 0
308
+ num_parameter_this_shard_sparse_all = 0
309
+ num_activation_all = 0
310
+ output_shape = input_shape
311
+ for vpp_rank, one_chunk in enumerate(model):
312
+ num_parameter_this_shard = one_chunk.num_parameter()
313
+ num_activation = one_chunk.num_activation(output_shape)
314
+ output_shape = one_chunk.mock_forward(output_shape)
315
+ print(f"{output_shape=}")
316
+ num_parameter_this_shard_sparse = 0
317
+ for layer in one_chunk.decoder.layers.modules:
318
+ if isinstance(layer.mlp, MoELayer):
319
+ num_parameter_this_shard_sparse += layer.mlp.num_parameter()
320
+ if (
321
+ "shared_experts" in layer.mlp.__dir__()
322
+ and layer.mlp.shared_experts is not None
323
+ ):
324
+ num_parameter_this_shard_sparse -= (
325
+ layer.mlp.shared_experts.num_parameter()
326
+ )
327
+ num_activation_this_shard_mlp = sum(
328
+ [m.mlp.num_activation() for m in one_chunk.decoder.layers.modules]
329
+ )
330
+ if len(model) > 1:
331
+ if vpp_rank >= 1 and vpp_rank < len(model) - 1:
332
+ num_microbatch_this_pp_rank = pp_size
333
+ elif vpp_rank == 0:
334
+ num_microbatch_this_pp_rank = pp_size + max(
335
+ (pp_size - pp_rank) * 2 - 1 - pp_size, 0
336
+ )
337
+ elif vpp_rank == len(model) - 1:
338
+ num_microbatch_this_pp_rank = min((pp_size - pp_rank) * 2 + 1, pp_size)
339
+ else:
340
+ num_microbatch_this_pp_rank = pp_size - pp_rank
341
+
342
+ num_parameter_this_shard_sparse = 0
343
+ for layer in one_chunk.decoder.layers.modules:
344
+ if isinstance(layer.mlp, MoELayer):
345
+ num_parameter_this_shard_sparse += layer.mlp.num_parameter()
346
+ if (
347
+ "shared_experts" in layer.mlp.__dir__()
348
+ and layer.mlp.shared_experts is not None
349
+ ):
350
+ num_parameter_this_shard_sparse -= (
351
+ layer.mlp.shared_experts.num_parameter()
352
+ )
353
+
354
+ one_chunk.__repr__()
355
+ print(one_chunk)
356
+ print(
357
+ f"Number of parameters in every GPU in billions: "
358
+ f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
359
+ )
360
+ num_parameter_this_shard_all += num_parameter_this_shard
361
+ num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse
362
+ # recompute
363
+ if config.recompute_granularity == "full":
364
+ recompute_num_layers = config.recompute_num_layers
365
+ num_layers = one_chunk.num_layers
366
+ common_act = (
367
+ one_chunk.num_act_pre
368
+ + one_chunk.num_act_between_layers
369
+ * num_layers
370
+ * num_microbatch_this_pp_rank
371
+ ) # recompute with pipeline parallel
372
+ info = "With this recomputing setting, the number of activations reaches its peak when "
373
+ if config.recompute_method == "block":
374
+ num_layers_with_loss = num_layers - recompute_num_layers
375
+ if num_layers_with_loss == 0:
376
+ peak1 = common_act + one_chunk.num_act_post
377
+ peak2 = common_act + one_chunk.num_act_per_layer
378
+ if peak1 > peak2:
379
+ info += "calculating loss"
380
+ else:
381
+ info += "back-propagating loss"
382
+ num_activation = max(peak1, peak2)
383
+ else:
384
+ info += f"calculating loss with {num_layers_with_loss} non-recompute layers"
385
+ num_activation = (
386
+ common_act
387
+ + one_chunk.num_act_post
388
+ + one_chunk.num_act_per_layer
389
+ * num_layers_with_loss
390
+ * num_microbatch_this_pp_rank
391
+ )
392
+ elif config.recompute_method == "uniform":
393
+ peak1 = common_act + one_chunk.num_act_post
394
+ peak2 = (
395
+ (common_act + one_chunk.num_act_per_layer)
396
+ if vpp_rank == 0
397
+ else (common_act)
398
+ )
399
+ if peak1 > peak2:
400
+ info += "calculating loss"
401
+ else:
402
+ info += f"back-propagating loss, recomputing every {recompute_num_layers} layers"
403
+ num_activation = max(peak1, peak2)
404
+ if len(one_chunk.decoder.layers.modules) > 0 and isinstance(
405
+ one_chunk.decoder.layers.modules[0].self_attention, MLASelfAttention
406
+ ): # MLA recompute reaches its peak during the backward pass
407
+ num_activation += one_chunk.decoder.layers.modules[
408
+ 0
409
+ ].self_attention.core_attention.num_activation()
410
+ print(info)
411
+
412
+ else:
413
+ num_activation = (
414
+ num_activation - one_chunk.num_act_post
415
+ ) * num_microbatch_this_pp_rank + one_chunk.num_act_post
416
+
417
+ # CP
418
+ num_activation = num_activation / config.context_parallel_size
419
+ if pp_size == 1:
420
+ print(
421
+ f"Number of activation in every GPU in billions: "
422
+ f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
423
+ )
424
+ else:
425
+ print(
426
+ f"Number of activation per microbatch in every GPU in billions: "
427
+ f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
428
+ f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
429
+ )
430
+ num_activation_all += num_activation
431
+ num_bytes_per_parameter = (
432
+ 18
433
+ if not args.use_distributed_optimizer
434
+ else 6 + (12 / args.data_parallel_size / config.context_parallel_size)
435
+ )
436
+ if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1:
437
+ num_bytes_per_parameter_dense = num_bytes_per_parameter
438
+ num_bytes_per_parameter_moe = (
439
+ 18
440
+ if not args.use_distributed_optimizer
441
+ else 6
442
+ + (
443
+ 12
444
+ / (
445
+ args.world_size
446
+ / config.pipeline_model_parallel_size
447
+ / config.expert_model_parallel_size
448
+ / config.expert_tensor_parallel_size
449
+ )
450
+ )
451
+ )
452
+ print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
453
+ weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
454
+ weight_grad_optim_memory = (
455
+ (num_parameter_this_shard_all - num_parameter_this_shard_sparse_all)
456
+ * num_bytes_per_parameter_dense
457
+ + num_parameter_this_shard_sparse_all * num_bytes_per_parameter_moe
458
+ ) / NUM_BYTES_IN_GIGABYTE
459
+ else:
460
+ print(f"{num_bytes_per_parameter=}")
461
+ weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
462
+ weight_grad_optim_memory = (
463
+ num_parameter_this_shard_all
464
+ * num_bytes_per_parameter
465
+ / NUM_BYTES_IN_GIGABYTE
466
+ )
467
+
468
+ activation_memory = (
469
+ num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
470
+ ) # only support fp16
471
+ total_memory = weight_grad_optim_memory + activation_memory
472
+
473
+ print(
474
+ f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, "
475
+ f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
476
+ )
477
+
478
+ # Build an aggregated report in the same format as estimate_from_config
479
+ model_breakdown_concat = "\n\n".join(
480
+ [f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)]
481
+ )
482
+
483
+ report = {
484
+ "pp_rank": pp_rank,
485
+ "parameters_b": num_parameter_this_shard_all / 1e9,
486
+ "activation_b": num_activation_all / 1e9,
487
+ "weight_grad_gb": round(weight_grad_memory, 2),
488
+ "weight_grad_optim_gb": round(weight_grad_optim_memory, 2),
489
+ "activation_gb": round(activation_memory, 2),
490
+ "total_gb": round(total_memory, 2),
491
+ "model_breakdown": model_breakdown_concat,
492
+ "details": None,
493
+ }
494
+
495
+ return output_shape, report
496
+
497
+
498
+ if __name__ == "__main__":
499
+ initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
500
+
501
+ import ipdb
502
+
503
+ with ipdb.launch_ipdb_on_exception():
504
+ args = get_args()
505
+ report_memory_usage(args)
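For orientation, the byte accounting in `report_memory_usage_one_pp_rank` uses 6 bytes per parameter for the bf16 weight plus its gradient and another 12 bytes of fp32 master weight and Adam moments, which the distributed optimizer shards over the data-parallel (and context-parallel) ranks; activations are tallied at 2 bytes per element. A self-contained sketch of that formula, with made-up sizes, to make the arithmetic concrete:

```python
# Standalone sketch of the bytes-per-parameter model used above.
# Assumptions: bf16 weight+grad = 6 bytes; fp32 master weight + Adam m/v = 12 bytes,
# sharded across data-parallel * context-parallel ranks with the distributed optimizer.
GIB = 1024 ** 3

def bytes_per_parameter(use_distributed_optimizer: bool, dp: int, cp: int = 1) -> float:
    if not use_distributed_optimizer:
        return 18.0
    return 6.0 + 12.0 / (dp * cp)

num_params = 15e9  # hypothetical per-GPU parameter shard
for dist_opt in (False, True):
    gib = num_params * bytes_per_parameter(dist_opt, dp=8) / GIB
    print(f"distributed_optimizer={dist_opt}: ~{gib:.1f} GiB for weights/grads/optimizer")
```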
moe_mem_estimator/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
moe_mem_estimator/base.py ADDED
@@ -0,0 +1,217 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ from abc import ABC
3
+
4
+ from megatron.core.transformer.transformer_config import TransformerConfig
5
+ from termcolor import colored
6
+ from torch.nn.modules.module import _addindent
7
+
8
+
9
+ def prehook_save_input_shape(func):
10
+ def wrapper(self, *input_shapes, **kw_input_shapes):
11
+ if len(input_shapes) + len(kw_input_shapes) == 0:
12
+ if "_input_shape" in self.__dict__:
13
+ return func(self, *self._input_shape, **self._kw_input_shapes)
14
+ else:
15
+ return 0
16
+ self._input_shape = input_shapes
17
+ self._kw_input_shapes = kw_input_shapes
18
+ return func(self, *self._input_shape, **self._kw_input_shapes)
19
+
20
+ return wrapper
21
+
22
+
23
+ class MetaBase(type):
24
+ def __new__(cls, name, bases, attrs):
25
+ if "num_activation" in attrs:
26
+ attrs["num_activation"] = prehook_save_input_shape(attrs["num_activation"])
27
+
28
+ return super().__new__(cls, name, bases, attrs)
29
+
30
+
31
+ class MemEstimator(metaclass=MetaBase):
32
+ def __init__(self, *args, **kwargs):
33
+ self._modules = {}
34
+ pass
35
+
36
+ def __repr__(self):
37
+ # We treat the extra repr like the sub-module, one item per line
38
+ extra_lines = []
39
+ # extra_repr = self.extra_repr()
40
+ # # empty string will be split into list ['']
41
+ # if extra_repr:
42
+ # extra_lines = extra_repr.split("\n")
43
+ child_lines = []
44
+ for key, module in self._modules.items():
45
+ mod_str = repr(module)
46
+ mod_str = _addindent(mod_str, 2)
47
+ child_lines.append("(" + key + "): " + mod_str)
48
+ lines = extra_lines + child_lines
49
+
50
+ stat = (
51
+ "\t/* n_params="
52
+ + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
53
+ + "\tn_act="
54
+ + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
55
+ + " */"
56
+ )
57
+ main_str = self._get_name() + stat + " ("
58
+ if lines:
59
+ # simple one-liner info, which most builtin Modules will use
60
+ if len(extra_lines) == 1 and not child_lines:
61
+ main_str += extra_lines[0]
62
+ else:
63
+ main_str += "\n " + "\n ".join(lines) + "\n"
64
+
65
+ main_str += ")"
66
+ return main_str
67
+ return f"{self.__class__.__name__} n_param={self.num_parameter()}"
68
+
69
+ def dump(self):
70
+ ret = {}
71
+ ret["name"] = self._get_name()
72
+ ret["n_params"] = self.num_parameter()
73
+ ret["n_act"] = self.num_activation()
74
+ modules = {}
75
+ for key, module in self._modules.items():
76
+ modules[key] = module.dump()
77
+ if len(modules) > 0:
78
+ ret["modules"] = modules
79
+ return ret
80
+
81
+ def _get_name(self):
82
+ return self.__class__.__name__
83
+
84
+ def num_parameter(self):
85
+ """
86
+ Calculate the number of model parameters.
87
+ """
88
+ raise NotImplementedError
89
+
90
+ def num_activation(self, input_shape: list[int]):
91
+ """
92
+ Calculate the number of activations for the given input_shape.
93
+ Args:
94
+ input shape
95
+ """
96
+ raise NotImplementedError
97
+
98
+ def mock_forward(self, input_shape: list[int]):
99
+ """
100
+ Mock the forward.
101
+ Args:
102
+ input shape
103
+ return:
104
+ output shape
105
+ """
106
+ raise NotImplementedError
107
+
108
+ def __setattr__(self, name: str, value) -> None:
109
+ if isinstance(value, MemEstimator):
110
+ modules = self.__dict__.get("_modules")
111
+ modules[name] = value
112
+ else:
113
+ pass
114
+ return super().__setattr__(name, value)
115
+
116
+ def __delattr__(self, name):
117
+ modules = self.__dict__.get("_modules")
118
+ if name in modules:
119
+ del modules[name]
120
+ return super().__delattr__(name)
121
+
122
+
123
+ _global_config: TransformerConfig = None
124
+
125
+
126
+ def set_global_config(cfg):
127
+ global _global_config
128
+ _global_config = cfg
129
+
130
+
131
+ def get_tensor_model_parallel_world_size():
132
+ global _global_config
133
+ return _global_config.tensor_model_parallel_size
134
+
135
+
136
+ def get_tensor_model_parallel_rank():
137
+ return 0
138
+
139
+
140
+ def get_expert_tensor_parallel_world_size():
141
+ global _global_config
142
+ return _global_config.expert_tensor_parallel_size
143
+
144
+
145
+ def get_expert_tensor_parallel_rank():
146
+ return 0
147
+
148
+
149
+ _pp_rank = 0
150
+
151
+
152
+ def set_pipeline_model_parallel_rank(rank):
153
+ global _pp_rank
154
+ _pp_rank = rank
155
+
156
+
157
+ def get_pipeline_model_parallel_rank():
158
+ global _pp_rank
159
+ return _pp_rank
160
+
161
+
162
+ def get_virtual_pipeline_model_parallel_rank():
163
+ return 0
164
+
165
+
166
+ def get_pipeline_model_parallel_world_size():
167
+ global _global_config
168
+ return _global_config.pipeline_model_parallel_size
169
+
170
+
171
+ def get_expert_model_parallel_rank():
172
+ return 0
173
+
174
+
175
+ def get_expert_model_parallel_world_size():
176
+ global _global_config
177
+ return _global_config.expert_model_parallel_size
178
+
179
+
180
+ def get_virtual_pipeline_model_parallel_world_size():
181
+ global _global_config
182
+ return _global_config.virtual_pipeline_model_parallel_size
183
+
184
+
185
+ def is_pipeline_first_stage(ignore_virtual=False, vp_stage=None):
186
+ """Return True if in the first pipeline model-parallel stage, False otherwise."""
187
+ if (
188
+ not ignore_virtual
189
+ and get_virtual_pipeline_model_parallel_world_size() is not None
190
+ ):
191
+ if vp_stage != 0:
192
+ return False
193
+ return get_pipeline_model_parallel_rank() == 0
194
+
195
+
196
+ def is_pipeline_last_stage(ignore_virtual=False, vp_stage=None):
197
+ """Return True if in the last pipeline-model-parallel stage, False otherwise."""
198
+ if (
199
+ not ignore_virtual
200
+ and get_virtual_pipeline_model_parallel_world_size() is not None
201
+ ):
202
+ if vp_stage != (get_virtual_pipeline_model_parallel_world_size() - 1):
203
+ return False
204
+ return get_pipeline_model_parallel_rank() == (
205
+ get_pipeline_model_parallel_world_size() - 1
206
+ )
207
+
208
+
209
+ def cum_mul(l: list):
210
+ try:
211
+ ret = 1
212
+ for one in l:
213
+ ret *= one
214
+ return ret
215
+ except Exception:
216
+ return 0
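`MetaBase` wraps every `num_activation` defined on a subclass with `prehook_save_input_shape`, so the first call records its input shape and later argument-free calls (as issued from `__repr__` and `dump`) replay it. A minimal hypothetical subclass demonstrating the pattern; the layer and its shapes are invented for illustration, and this assumes `moe_mem_estimator.base` is importable:

```python
# Toy estimator built on MemEstimator to show the cached-input-shape behaviour.
from moe_mem_estimator.base import MemEstimator, cum_mul

class ToyLinear(MemEstimator):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.weight = (out_features, in_features)  # shape only, no real tensor

    def num_parameter(self):
        return cum_mul(self.weight)

    def num_activation(self, input_shape: list[int]):
        # Output activations: all leading dims times out_features.
        return cum_mul(input_shape[:-1]) * self.weight[0]

    def mock_forward(self, input_shape: list[int]):
        return input_shape[:-1] + [self.weight[0]]

layer = ToyLinear(4096, 11008)
print(layer.num_activation([1, 4096, 4096]))  # first call caches the input shape
print(layer.num_activation())                 # argument-free call reuses the cache
```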
moe_mem_estimator/gpt_model.py ADDED
@@ -0,0 +1,157 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ from typing import Dict, Literal, Optional, Union
3
+
4
+ from megatron.core.model_parallel_config import ModelParallelConfig
5
+ from megatron.core.tensor_parallel.utils import VocabUtility
6
+ from megatron.core.transformer.enums import ModelType
7
+ from megatron.core.transformer.spec_utils import ModuleSpec
8
+ from megatron.core.transformer.transformer_block import (
9
+ TransformerBlockSubmodules,
10
+ _get_block_submodules,
11
+ )
12
+ from megatron.core.transformer.transformer_config import TransformerConfig
13
+
14
+ from .base import (
15
+ MemEstimator,
16
+ cum_mul,
17
+ get_tensor_model_parallel_rank,
18
+ get_tensor_model_parallel_world_size,
19
+ set_global_config,
20
+ )
21
+ from .layers import ColumnParallelLinear, LanguageModelEmbedding, TransformerBlock
22
+
23
+
24
+ class GPTModel(MemEstimator):
25
+ def __init__(
26
+ self,
27
+ config: TransformerConfig,
28
+ transformer_layer_spec: ModuleSpec,
29
+ vocab_size: int,
30
+ max_sequence_length: int,
31
+ pre_process: bool = True,
32
+ post_process: bool = True,
33
+ fp16_lm_cross_entropy: bool = False,
34
+ parallel_output: bool = True,
35
+ share_embeddings_and_output_weights: bool = False,
36
+ position_embedding_type: Literal[
37
+ "learned_absolute", "rope", "none"
38
+ ] = "learned_absolute",
39
+ rotary_percent: float = 1.0,
40
+ rotary_base: int = 10000,
41
+ rope_scaling: bool = False,
42
+ seq_len_interpolation_factor: Optional[float] = None,
43
+ mtp_block_spec: Optional[ModuleSpec] = None,
44
+ vp_stage: Optional[int] = None,
45
+ ):
46
+ super().__init__()
47
+
48
+ self.config = config
49
+ config.use_cpu_initialization = True
50
+
51
+ self.transformer_layer_spec: ModuleSpec = transformer_layer_spec
52
+ self.vocab_size = vocab_size
53
+ self.max_sequence_length = max_sequence_length
54
+ self.pre_process = pre_process
55
+ self.post_process = post_process
56
+ self.fp16_lm_cross_entropy = fp16_lm_cross_entropy
57
+ self.parallel_output = parallel_output
58
+ self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
59
+ self.position_embedding_type = position_embedding_type
60
+
61
+ # megatron core pipelining currently depends on model type
62
+ # TODO: remove this dependency ?
63
+ self.model_type = ModelType.encoder_or_decoder
64
+
65
+ # These 4 attributes are needed for TensorRT-LLM export.
66
+ self.max_position_embeddings = max_sequence_length
67
+ self.rotary_percent = rotary_percent
68
+ self.rotary_base = rotary_base
69
+ self.rotary_scaling = rope_scaling
70
+
71
+ if self.pre_process:
72
+ self.embedding = LanguageModelEmbedding(
73
+ config=self.config,
74
+ vocab_size=self.vocab_size,
75
+ max_sequence_length=self.max_sequence_length,
76
+ position_embedding_type=position_embedding_type,
77
+ )
78
+
79
+ # remove RotaryEmbedding
80
+
81
+ # Transformer.
82
+ self.decoder = TransformerBlock(
83
+ config=self.config,
84
+ spec=transformer_layer_spec,
85
+ pre_process=self.pre_process,
86
+ post_process=self.post_process,
87
+ vp_stage=vp_stage,
88
+ )
89
+
90
+ # Output
91
+ if post_process:
92
+ if self.config.defer_embedding_wgrad_compute:
93
+ self.embedding_activation_buffer = []
94
+ self.grad_output_buffer = []
95
+ else:
96
+ self.embedding_activation_buffer = None
97
+ self.grad_output_buffer = None
98
+
99
+ self.output_layer = ColumnParallelLinear(
100
+ config.hidden_size,
101
+ self.vocab_size,
102
+ config=config,
103
+ init_method=config.init_method,
104
+ bias=False,
105
+ skip_bias_add=False,
106
+ gather_output=not self.parallel_output,
107
+ skip_weight_param_allocation=self.pre_process
108
+ and self.share_embeddings_and_output_weights,
109
+ embedding_activation_buffer=self.embedding_activation_buffer,
110
+ grad_output_buffer=self.grad_output_buffer,
111
+ )
112
+
113
+ def num_parameter(self):
114
+ ret = 0
115
+ if self.pre_process:
116
+ ret += self.embedding.num_parameter()
117
+ ret += self.decoder.num_parameter()
118
+ if self.post_process:
119
+ ret += self.output_layer.num_parameter()
120
+ return ret
121
+
122
+ def num_activation(self, input_shape: list[int]):
123
+ self._inited = True
124
+ ret = 0
125
+
126
+ self.num_act_pre = 0
127
+ self.num_act_post = 0
128
+ self.num_act_per_layer = 0
129
+ self.num_act_between_layers = 0
130
+ self.num_layers = self.decoder.layers.modules.__len__()
131
+
132
+ if self.pre_process:
133
+ self.num_act_pre = self.embedding.num_activation(input_shape)
134
+ ret += self.num_act_pre
135
+ input_shape = self.embedding.mock_forward(input_shape)
136
+ ret += self.decoder.num_activation(input_shape)
137
+ if self.decoder.layers.modules.__len__() > 0:
138
+ self.num_act_per_layer = self.decoder.layers.modules[0].num_activation()
139
+ input_shape = self.decoder.mock_forward(input_shape)
140
+ self.num_act_between_layers = cum_mul(input_shape)
141
+
142
+ if self.post_process:
143
+ self.num_act_post = self.output_layer.num_activation(input_shape)
144
+ softmax_activation = (
145
+ self.output_layer.num_activation(input_shape) * 2
146
+ ) # because the softmax is computed in fp32
147
+ self.num_act_post += softmax_activation
148
+ ret += self.num_act_post
149
+ return ret
150
+
151
+ def mock_forward(self, input_shape: list[int]):
152
+ if self.pre_process:
153
+ input_shape = self.embedding.mock_forward(input_shape)
154
+ input_shape = self.decoder.mock_forward(input_shape)
155
+ if self.post_process:
156
+ input_shape = self.output_layer.mock_forward(input_shape)
157
+ return input_shape
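In `GPTModel.num_activation`, the post-processing term counts the parallel logits once for the output projection and twice more for the softmax, since the cross-entropy softmax runs in fp32 while everything else is tallied as 2-byte elements. A rough, self-contained check of that term with hypothetical sizes (the vocab, sequence length, and TP degree below are assumptions, not values from this commit):

```python
# Back-of-the-envelope check of the post-process activation term.
micro_batch, seq_len = 1, 4096
padded_vocab, tp_size = 129280, 8  # hypothetical

logits = micro_batch * seq_len * (padded_vocab // tp_size)  # output_layer activations
softmax_fp32 = 2 * logits                                   # fp32 softmax counted twice
num_act_post = logits + softmax_fp32

print(f"post-process activations ~= {num_act_post / 1e9:.2f}B elements "
      f"(~{num_act_post * 2 / 1024**3:.2f} GiB at 2 bytes/element)")
```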
moe_mem_estimator/layers.py ADDED
@@ -0,0 +1,1940 @@
1
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
2
+ import math
3
+ import types
4
+ import warnings
5
+ from copy import deepcopy
6
+ from typing import Dict, Literal, Optional, Union
7
+
8
+ from megatron.core.extensions.transformer_engine import (
9
+ _get_extra_te_kwargs,
10
+ condition_init_method,
11
+ get_expert_parallel_rng_tracker_name,
12
+ )
13
+ from megatron.core.model_parallel_config import ModelParallelConfig
14
+ from megatron.core.models.common.embeddings import (
15
+ _yarn_get_mscale,
16
+ apply_rotary_pos_emb,
17
+ )
18
+ from megatron.core.tensor_parallel.utils import VocabUtility
19
+ from megatron.core.transformer import transformer_layer
20
+ from megatron.core.transformer.enums import AttnMaskType
21
+ from megatron.core.transformer.mlp import MLPSubmodules
22
+ from megatron.core.transformer.spec_utils import ModuleSpec, import_module
23
+ from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
24
+ from megatron.core.transformer.transformer_config import (
25
+ MLATransformerConfig,
26
+ TransformerConfig,
27
+ )
28
+ from megatron.core.utils import divide
29
+
30
+ from .base import (
31
+ MemEstimator,
32
+ _addindent,
33
+ colored,
34
+ cum_mul,
35
+ get_expert_model_parallel_rank,
36
+ get_expert_model_parallel_world_size,
37
+ get_expert_tensor_parallel_rank,
38
+ get_expert_tensor_parallel_world_size,
39
+ get_pipeline_model_parallel_rank,
40
+ get_pipeline_model_parallel_world_size,
41
+ get_tensor_model_parallel_rank,
42
+ get_tensor_model_parallel_world_size,
43
+ is_pipeline_first_stage,
44
+ is_pipeline_last_stage,
45
+ set_global_config,
46
+ )
47
+
48
+
49
+ class LanguageModelEmbedding(MemEstimator):
50
+ def __init__(
51
+ self,
52
+ config: TransformerConfig,
53
+ vocab_size: int,
54
+ max_sequence_length: int,
55
+ position_embedding_type: Literal[
56
+ "learned_absolute", "rope", "none"
57
+ ] = "learned_absolute",
58
+ num_tokentypes: int = 0,
59
+ ):
60
+ super().__init__()
61
+
62
+ self.config: TransformerConfig = config
63
+ self.vocab_size: int = vocab_size
64
+ self.max_sequence_length: int = max_sequence_length
65
+ self.add_position_embedding: bool = (
66
+ position_embedding_type == "learned_absolute"
67
+ )
68
+ self.num_tokentypes = num_tokentypes
69
+ self.reduce_scatter_embeddings = (
70
+ (not self.add_position_embedding)
71
+ and self.num_tokentypes <= 0
72
+ and self.config.sequence_parallel
73
+ )
74
+ # Word embeddings (parallel).
75
+ self.word_embeddings = VocabParallelEmbedding(
76
+ num_embeddings=self.vocab_size,
77
+ embedding_dim=self.config.hidden_size,
78
+ init_method=self.config.init_method,
79
+ reduce_scatter_embeddings=self.reduce_scatter_embeddings,
80
+ config=self.config,
81
+ )
82
+
83
+ # TODO if self.add_position_embedding:
84
+
85
+ # TODO if self.num_tokentypes > 0:
86
+
87
+ self.embedding_dropout = Dropout(self.config.hidden_dropout)
88
+
89
+ def num_parameter(self):
90
+ ret = self.word_embeddings.num_parameter()
91
+ ret += self.embedding_dropout.num_parameter()
92
+ return ret
93
+
94
+ def num_activation(self, input_shape: list[int]):
95
+ ret = self.word_embeddings.num_activation(input_shape)
96
+ input_shape = self.word_embeddings.mock_forward(input_shape)
97
+ ret += self.embedding_dropout.num_activation(input_shape)
98
+ return ret
99
+
100
+ def mock_forward(self, input_shape: list[int]):
101
+ input_shape = self.word_embeddings.mock_forward(input_shape)
102
+ return input_shape
103
+
104
+
105
+ class VocabParallelEmbedding(MemEstimator):
106
+ def __init__(
107
+ self,
108
+ num_embeddings: int,
109
+ embedding_dim: int,
110
+ *,
111
+ init_method,
112
+ reduce_scatter_embeddings: bool = False,
113
+ config: ModelParallelConfig,
114
+ ):
115
+ super().__init__()
116
+ # Keep the input dimensions.
117
+ self.num_embeddings = num_embeddings
118
+ self.embedding_dim = embedding_dim
119
+ self.reduce_scatter_embeddings = reduce_scatter_embeddings
120
+ self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
121
+ # Divide the weight matrix along the vocabulary dimension.
122
+ (self.vocab_start_index, self.vocab_end_index) = (
123
+ VocabUtility.vocab_range_from_global_vocab_size(
124
+ self.num_embeddings,
125
+ get_tensor_model_parallel_rank(),
126
+ self.tensor_model_parallel_size,
127
+ )
128
+ )
129
+ self.num_embeddings_per_partition = (
130
+ self.vocab_end_index - self.vocab_start_index
131
+ )
132
+ self.deterministic_mode = config.deterministic_mode
133
+ self.weight = (self.num_embeddings_per_partition, self.embedding_dim)
134
+
135
+ def num_parameter(self):
136
+ return self.weight[0] * self.weight[1]
137
+
138
+ def num_activation(self, input_shape: list[int]):
139
+ return cum_mul(input_shape) * self.weight[1]
140
+
141
+ def mock_forward(self, input_shape: list[int]):
142
+ return input_shape + [self.weight[1]]
143
+
144
+
145
+ class Dropout(MemEstimator):
146
+ def __init__(self, p=0, *args, **kwargs):
147
+ super().__init__()
148
+ self.p = p
149
+
150
+ def num_parameter(self):
151
+ return 0
152
+
153
+ def num_activation(self, input_shape: list[int]):
154
+ if self.p == 0:
155
+ return 0
156
+ return cum_mul(input_shape[:])
157
+
158
+ def mock_forward(self, input_shape: list[int]):
159
+ return input_shape
160
+
161
+
162
+ class ColumnParallelLinear(MemEstimator):
163
+ def __init__(
164
+ self,
165
+ input_size,
166
+ output_size,
167
+ *,
168
+ config: ModelParallelConfig,
169
+ init_method,
170
+ bias=True,
171
+ gather_output=False,
172
+ stride=1,
173
+ keep_master_weight_for_test=False,
174
+ skip_bias_add=False,
175
+ skip_weight_param_allocation: bool = False,
176
+ embedding_activation_buffer=None,
177
+ grad_output_buffer=None,
178
+ is_expert: bool = False,
179
+ tp_comm_buffer_name: str = None, # Not used
180
+ disable_grad_reduce: bool = False,
181
+ is_mla: bool = False,
182
+ ):
183
+ super().__init__()
184
+
185
+ if is_mla and config.sequence_parallel:
186
+ tp_size = get_tensor_model_parallel_world_size()
187
+ output_size = divide(output_size, tp_size)
188
+ parallel_mode = None
189
+ tp_size = 1
190
+ tp_group = None
191
+ # Keep input parameters
192
+ self.input_size = input_size
193
+ self.output_size = output_size
194
+ self.gather_output = gather_output
195
+ # Divide the weight matrix along the last dimension.
196
+ self.skip_bias_add = skip_bias_add
197
+ self.is_expert = is_expert
198
+ self.expert_parallel = config.expert_model_parallel_size > 1
199
+ self.embedding_activation_buffer = embedding_activation_buffer
200
+ self.grad_output_buffer = grad_output_buffer
201
+ self.config = config
202
+ self.disable_grad_reduce = disable_grad_reduce
203
+
204
+ if is_expert:
205
+ world_size = get_expert_tensor_parallel_world_size()
206
+ rank = get_expert_tensor_parallel_rank()
207
+ else:
208
+ world_size = get_tensor_model_parallel_world_size()
209
+ rank = get_tensor_model_parallel_rank()
210
+
211
+ self.output_size_per_partition = divide(output_size, world_size)
212
+
213
+ # Parameters.
214
+ # Note: torch.nn.functional.linear performs XA^T + b and as a result
215
+ # we allocate the transpose.
216
+ # Initialize weight.
217
+ if not skip_weight_param_allocation:
218
+ self.weight = (self.output_size_per_partition, self.input_size)
219
+ else:
220
+ self.weight = (self.output_size_per_partition, self.input_size)
221
+
222
+ if bias:
223
+ self.bias = [self.output_size_per_partition]
224
+ else:
225
+ self.bias = None
226
+
227
+ self.sequence_parallel = config.sequence_parallel
228
+ if self.sequence_parallel and world_size <= 1:
229
+ warnings.warn(
230
+ "`sequence_parallel` is set to `True`, but tensor model parallel size "
231
+ f"is {world_size}. Disabling sequence parallel."
232
+ )
233
+ self.sequence_parallel = False
234
+
235
+ self.allreduce_dgrad = (
236
+ world_size > 1
237
+ and not self.sequence_parallel
238
+ and not self.disable_grad_reduce
239
+ )
240
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
241
+
242
+ def num_parameter(self):
243
+ ret = cum_mul(self.weight)
244
+ if self.bias is not None:
245
+ ret += self.bias[0]
246
+ return ret
247
+
248
+ def num_activation(self, input_shape: list[int]):
249
+ return cum_mul(input_shape[:-1]) * self.weight[0]
250
+
251
+ def mock_forward(self, input_shape: list[int]):
252
+ try:
253
+ assert self.weight[-1] == input_shape[-1]
254
+ except:
255
+
256
+ print(f"{self.weight=} {input_shape=}")
257
+ __import__("ipdb").set_trace()
258
+ raise
259
+ return input_shape[:-1] + [self.weight[0]]
260
+
261
+
262
+ class RowParallelLinear(MemEstimator):
263
+ def __init__(
264
+ self,
265
+ input_size: int,
266
+ output_size: int,
267
+ *,
268
+ config: ModelParallelConfig,
269
+ init_method,
270
+ bias: bool,
271
+ input_is_parallel: bool,
272
+ skip_bias_add: bool,
273
+ stride: int = 1,
274
+ keep_master_weight_for_test: bool = False,
275
+ is_expert: bool = False,
276
+ tp_comm_buffer_name: str = None, # Not used
277
+ ):
278
+ super().__init__()
279
+
280
+ # Keep input parameters
281
+ self.input_size = input_size
282
+ self.output_size = output_size
283
+ self.input_is_parallel = input_is_parallel
284
+ self.skip_bias_add = skip_bias_add
285
+ self.config = config
286
+ self.is_expert = is_expert
287
+ self.expert_parallel = config.expert_model_parallel_size > 1
288
+ self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
289
+ self.sequence_parallel = config.sequence_parallel
290
+ if self.sequence_parallel and not self.input_is_parallel:
291
+ raise RuntimeError(
292
+ "To enable `sequence_parallel`, `input_is_parallel` must be `True`"
293
+ )
294
+
295
+ # Divide the weight matrix along the last dimension.
296
+ if self.is_expert:
297
+ world_size = get_expert_tensor_parallel_world_size()
298
+ rank = get_expert_tensor_parallel_rank()
299
+ else:
300
+ world_size = get_tensor_model_parallel_world_size()
301
+ rank = get_tensor_model_parallel_rank()
302
+
303
+ self.input_size_per_partition = divide(input_size, world_size)
304
+
305
+ self.weight = (self.output_size, self.input_size_per_partition)
306
+ if bias:
307
+ self.bias = [self.output_size]
308
+ else:
309
+ self.bias = None
310
+
311
+ def num_parameter(self):
312
+ ret = cum_mul(self.weight)
313
+ if self.bias is not None:
314
+ ret += self.bias[0]
315
+ return ret
316
+
317
+ def num_activation(self, input_shape: list[int]):
318
+ return cum_mul(input_shape[:-1]) * self.weight[1]
319
+
320
+ def mock_forward(self, input_shape: list[int]):
321
+ assert self.weight[0] == input_shape[-1]
322
+ return input_shape[:-1] + [self.weight[1]]
323
+
324
+
325
+ class RMSNorm(MemEstimator):
326
+ def __init__(self, hidden_size: int, *args, **kwargs):
327
+ super().__init__()
328
+ self.weight = hidden_size
329
+
330
+ def num_parameter(self):
331
+ return self.weight
332
+
333
+ def num_activation(self, input_shape: list[int]):
334
+ return cum_mul(input_shape[:])
335
+
336
+ def mock_forward(self, input_shape: list[int]):
337
+ return input_shape
338
+
339
+
340
+ class GetBiasDropoutAdd(MemEstimator):
341
+ def __init__(self, *args, **kwargs):
342
+ super().__init__()
343
+
344
+ def num_parameter(self):
345
+ return 0
346
+
347
+ def num_activation(self, input_shape: list[int]):
348
+ return cum_mul(input_shape[:])
349
+
350
+ def mock_forward(self, input_shape: list[int]):
351
+ return input_shape
352
+
353
+
354
+ get_bias_dropout_add = GetBiasDropoutAdd()
355
+
356
+
357
+ class MLP(MemEstimator):
358
+
359
+ def __init__(
360
+ self,
361
+ config: TransformerConfig,
362
+ submodules,
363
+ is_expert: bool = False,
364
+ input_size: int = None,
365
+ ):
366
+ super().__init__()
367
+
368
+ self.config: TransformerConfig = config
369
+
370
+ self.input_size = input_size if input_size != None else self.config.hidden_size
371
+
372
+ # If this is a gated linear unit we double the output width, see https://arxiv.org/pdf/2002.05202.pdf
373
+ ffn_hidden_size = self.config.ffn_hidden_size
374
+ if self.config.gated_linear_unit:
375
+ ffn_hidden_size *= 2
376
+
377
+ self.linear_fc1 = build_module(
378
+ submodules.linear_fc1,
379
+ self.input_size,
380
+ ffn_hidden_size,
381
+ config=self.config,
382
+ init_method=self.config.init_method,
383
+ gather_output=False,
384
+ bias=self.config.add_bias_linear,
385
+ skip_bias_add=True,
386
+ is_expert=is_expert,
387
+ tp_comm_buffer_name="fc1",
388
+ )
389
+
390
+ self.activation_func = self.config.activation_func
391
+
392
+ self.linear_fc2 = build_module(
393
+ submodules.linear_fc2,
394
+ self.config.ffn_hidden_size,
395
+ self.config.hidden_size,
396
+ config=self.config,
397
+ init_method=self.config.output_layer_init_method,
398
+ bias=self.config.add_bias_linear,
399
+ input_is_parallel=True,
400
+ skip_bias_add=True,
401
+ is_expert=is_expert,
402
+ tp_comm_buffer_name="fc2",
403
+ )
404
+
405
+ def num_parameter(self):
406
+ return self.linear_fc1.num_parameter() + self.linear_fc2.num_parameter()
407
+
408
+ def num_activation(self, input_shape: list[int]):
409
+ result = 0
410
+ result += self.linear_fc1.num_activation(input_shape)
411
+ intermediate_shape = self.linear_fc1.mock_forward(input_shape)
412
+ result += cum_mul(intermediate_shape) / 2 # activation layer
413
+ self.linear_fc2.num_activation(intermediate_shape)
414
+
415
+ return result
416
+
417
+ def mock_forward(self, input_shape: list[int]):
418
+ intermediate_shape = self.linear_fc1.mock_forward(input_shape)
419
+ output_shape = self.linear_fc2.mock_forward(intermediate_shape)
420
+ return output_shape
421
+
422
+
423
+ class ModuleList(MemEstimator):
424
+ def __init__(self, modules: list[MemEstimator] = None):
425
+ super().__init__()
426
+ if modules is None:
427
+ modules = []
428
+ self.modules = modules
429
+
430
+ def __repr__(self):
431
+ """Return a custom repr for ModuleList that compresses repeated module representations."""
432
+ list_of_reprs = [repr(item) for item in self.modules]
433
+ if len(list_of_reprs) == 0:
434
+ return self._get_name() + "()"
435
+
436
+ start_end_indices = [[0, 0]]
437
+ repeated_blocks = [list_of_reprs[0]]
438
+ for i, r in enumerate(list_of_reprs[1:], 1):
439
+ if r == repeated_blocks[-1]:
440
+ start_end_indices[-1][1] += 1
441
+ continue
442
+
443
+ start_end_indices.append([i, i])
444
+ repeated_blocks.append(r)
445
+
446
+ lines = []
447
+ stat = (
448
+ "\t/* n_params="
449
+ + colored(f"{self.num_parameter()/1024/1024:.2f}M", "red")
450
+ + "\tn_act="
451
+ + colored(f"{self.num_activation()/1024/1024:.2f}M", "green")
452
+ + " */"
453
+ )
454
+ main_str = self._get_name() + stat + " ("
455
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
456
+ local_repr = f"({start_id}): {b}" # default repr
457
+
458
+ if start_id != end_id:
459
+ n = end_id - start_id + 1
460
+ local_repr = f"({start_id}-{end_id}): {n} x {b}"
461
+
462
+ local_repr = _addindent(local_repr, 2)
463
+ lines.append(local_repr)
464
+
465
+ main_str += "\n " + "\n ".join(lines) + "\n"
466
+ main_str += ")"
467
+ return main_str
468
+
469
+ def dump(self):
470
+ list_of_reprs = [repr(item) for item in self.modules]
471
+ if len(list_of_reprs) == 0:
472
+ return self._get_name() + "()"
473
+ list_of_dumps = [item.dump() for item in self.modules]
474
+
475
+ start_end_indices = [[0, 0]]
476
+ repeated_blocks = [list_of_reprs[0]]
477
+ repeated_blocks_dump = [list_of_dumps[0]]
478
+ for i, r in enumerate(list_of_reprs[1:], 1):
479
+ if r == repeated_blocks[-1]:
480
+ start_end_indices[-1][1] += 1
481
+ continue
482
+
483
+ start_end_indices.append([i, i])
484
+ repeated_blocks.append(r)
485
+ repeated_blocks_dump.append(list_of_dumps[i])
486
+ modules = {}
487
+ for (start_id, end_id), b in zip(start_end_indices, repeated_blocks_dump):
488
+ key = f"({start_id})"
489
+ if start_id != end_id:
490
+ n = end_id - start_id + 1
491
+ key = f"({start_id}-{end_id}) {n} layers"
492
+ modules[key] = b
493
+
494
+ ret = {}
495
+ ret["name"] = self._get_name()
496
+ ret["n_params"] = self.num_parameter()
497
+ ret["n_act"] = self.num_activation()
498
+ if len(modules) > 0:
499
+ ret["modules"] = modules
500
+ return ret
501
+
502
+ def append(self, m: MemEstimator):
503
+ self.modules.append(m)
504
+
505
+ def __len__(
506
+ self,
507
+ ):
508
+ return self.modules.__len__()
509
+
510
+ def num_parameter(self):
511
+ return sum([x.num_parameter() for x in self.modules])
512
+
513
+ def num_activation(self, input_shape: list[int]):
514
+ result = 0
515
+ for m in self.modules:
516
+ result += m.num_activation(input_shape)
517
+ input_shape = m.mock_forward(input_shape)
518
+
519
+ return result
520
+
521
+ def mock_forward(self, input_shape: list[int]):
522
+ for m in self.modules:
524
+ input_shape = m.mock_forward(input_shape)
525
+ return input_shape
526
+
527
+
528
+ class SequentialMLP(MemEstimator):
529
+ def __init__(self, num_local_experts, config: TransformerConfig, submodules):
530
+ super().__init__()
531
+ self.config = config
532
+ self.add_bias = config.add_bias_linear
533
+ self.moe_extended_tp = config.moe_extended_tp
534
+ self.num_local_experts = num_local_experts
535
+ self.local_experts = ModuleList()
536
+ for _ in range(self.num_local_experts):
537
+ expert = MLP(self.config, submodules, is_expert=True)
538
+ self.local_experts.append(expert)
539
+
540
+ def num_parameter(self):
541
+ return self.local_experts.num_parameter()
542
+
543
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
544
+ # assume all the inputs are routed equally
545
+ all_tokens = input_shape[1]
546
+ result = 0
547
+ for m in self.local_experts.modules:
548
+ result += m.num_activation(
549
+ input_shape[:1]
550
+ + [all_tokens // self.num_local_experts]
551
+ + input_shape[2:]
552
+ )
553
+ return result
554
+
555
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
556
+ # assume all the inputs are routed to the first expert
557
+ input_shape = self.local_experts.modules[0].mock_forward(input_shape)
558
+ return input_shape
559
+
560
+
561
+ class TEGroupedMLP(MemEstimator):
562
+ """An efficient implementation of the Experts layer using TE's GroupedLinear.
563
+
564
+ Executes multiple experts in parallel to maximize computational efficiency.
565
+ """
566
+
567
+ def __init__(self, num_local_experts, config: TransformerConfig, submodules):
568
+ super().__init__()
569
+ self.config = config
570
+ self.moe_extended_tp = config.moe_extended_tp
571
+ self.num_local_experts = num_local_experts
572
+ self.input_size = self.config.hidden_size
573
+
574
+ # Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
575
+ ffn_hidden_size = self.config.moe_ffn_hidden_size
576
+ if self.config.gated_linear_unit:
577
+ ffn_hidden_size *= 2
578
+
579
+ self.linear_fc1 = build_module(
580
+ submodules.linear_fc1,
581
+ self.num_local_experts,
582
+ self.input_size,
583
+ ffn_hidden_size,
584
+ config=self.config,
585
+ init_method=self.config.init_method,
586
+ bias=self.config.add_bias_linear,
587
+ skip_bias_add=True,
588
+ is_expert=True,
589
+ tp_comm_buffer_name="fc1",
590
+ )
591
+
592
+ self.activation_func = self.config.activation_func
593
+
594
+ self.activation_recompute = (
595
+ self.config.recompute_granularity == "selective"
596
+ and "moe_act" in self.config.recompute_modules
597
+ )
598
+ self.linear_fc2 = build_module(
599
+ submodules.linear_fc2,
600
+ self.num_local_experts,
601
+ self.config.moe_ffn_hidden_size,
602
+ self.config.hidden_size,
603
+ config=self.config,
604
+ init_method=self.config.output_layer_init_method,
605
+ bias=self.config.add_bias_linear,
606
+ skip_bias_add=True,
607
+ is_expert=True,
608
+ tp_comm_buffer_name="fc2",
609
+ )
610
+ # TODO if self.config.fp8:
611
+
612
+ def num_parameter(self):
613
+ ret = self.linear_fc1.num_parameter()
614
+ ret += self.linear_fc2.num_parameter()
615
+ return ret
616
+
617
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
618
+ ret = 0
619
+ if not self.activation_recompute:
620
+ ret += self.linear_fc1.num_activation(input_shape)
621
+ input_shape = self.linear_fc1.mock_forward(input_shape)
622
+
623
+ # activation
624
+ if not self.activation_recompute:
625
+ ret += cum_mul(input_shape) / 2 # swiglu or gelu
626
+ input_shape = deepcopy(input_shape)
627
+ input_shape[-1] //= 2
628
+
629
+ ret += self.linear_fc2.num_activation(input_shape)
630
+ return ret
631
+
632
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
633
+ # fc1 -> activation -> fc2 preserves the model hidden size, so the shape is unchanged
634
+ return input_shape
636
+
637
+
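A quick width check for the grouped-experts estimate above (numbers are mine, purely illustrative): with hidden size 4096, `moe_ffn_hidden_size=1536` and a gated linear unit, `linear_fc1` maps each routed token to 2*1536 features, the gated activation halves that back, and `linear_fc2` returns to the model width; expert tensor parallelism further divides the intermediate widths.

    # Illustrative values only; etp is the expert tensor-parallel size.
    hidden, moe_ffn, etp = 4096, 1536, 1
    fc1_out = 2 * moe_ffn // etp   # gated linear unit doubles the fc1 width
    act_out = fc1_out // 2         # SwiGLU/GeLU gate halves it back
    fc2_out = hidden               # row-parallel fc2 restores the model width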
638
+ class TEGroupedLinear(MemEstimator):
639
+ def __init__(
640
+ self,
641
+ num_gemms: int,
642
+ input_size: int,
643
+ output_size: int,
644
+ *,
645
+ parallel_mode: str,
646
+ config: ModelParallelConfig,
647
+ init_method,
648
+ bias: bool,
649
+ skip_bias_add: bool,
650
+ is_expert: bool = False,
651
+ tp_comm_buffer_name: str = None,
652
+ ):
653
+ super().__init__()
654
+ self.config = config
655
+
656
+ # TE returns a zero length Tensor when bias=False and
657
+ # return_bias=True, but we prefer None. So in that case we
658
+ # tell TE to not return the bias, and return None
659
+ # ourselves. This way our forward always returns two values
660
+ # and we don't have to deal with the zero length Tensor.
661
+ self.te_return_bias = skip_bias_add and bias
662
+ self.is_first_microbatch = True
663
+ self.disable_parameter_transpose_cache = (
664
+ self.config.disable_parameter_transpose_cache
665
+ )
666
+
667
+ extra_kwargs = _get_extra_te_kwargs(config)
668
+ extra_kwargs["ub_name"] = tp_comm_buffer_name
669
+
670
+ self.expert_parallel = self.config.expert_model_parallel_size > 1
671
+ if self.expert_parallel:
672
+ extra_kwargs["rng_tracker_name"] = get_expert_parallel_rng_tracker_name()
673
+
674
+ # For MoE models, the comms between TP and EP group is explicitly handled by
675
+ # MoE token dispatcher. So we disable comms by making TE agnostic of model parallel.
676
+ self.explicit_expert_comm = is_expert and (
677
+ config.tensor_model_parallel_size > 1 or self.expert_parallel
678
+ )
679
+ if is_expert:
680
+ tp_size = get_expert_tensor_parallel_world_size()
681
+ else:
682
+ tp_size = get_tensor_model_parallel_world_size()
683
+ if self.explicit_expert_comm:
684
+ if parallel_mode == "column":
685
+ output_size = divide(output_size, tp_size)
686
+ elif parallel_mode == "row":
687
+ input_size = divide(input_size, tp_size)
688
+ parallel_mode = None
689
+ tp_size = 1
690
+ assert not bias, "bias is not considered for now"
691
+
692
+ self.num_gemms = num_gemms
693
+ self.input_size = input_size
694
+ self.output_size = output_size
695
+
696
+ def num_parameter(self):
697
+ ret = self.num_gemms * self.input_size * self.output_size
698
+ return ret
699
+
700
+ def num_activation(self, input_shape: list[int], tokens_per_expert=None):
701
+ ret = cum_mul(self.mock_forward(input_shape))
702
+ return ret
703
+
704
+ def mock_forward(self, input_shape: list[int], tokens_per_expert=None):
705
+ return input_shape[:-1] + [self.output_size]
706
+
707
+
708
+ class TEColumnParallelGroupedLinear(TEGroupedLinear):
709
+ def __init__(
710
+ self,
711
+ num_gemms: int,
712
+ input_size: int,
713
+ output_size: int,
714
+ *,
715
+ config: ModelParallelConfig,
716
+ init_method,
717
+ bias: bool,
718
+ skip_bias_add: bool,
719
+ is_expert: bool,
720
+ tp_comm_buffer_name: str = None,
721
+ ):
722
+ super().__init__(
723
+ num_gemms=num_gemms,
724
+ input_size=input_size,
725
+ output_size=output_size,
726
+ parallel_mode="column",
727
+ config=config,
728
+ init_method=condition_init_method(config, init_method),
729
+ bias=bias,
730
+ skip_bias_add=skip_bias_add,
731
+ is_expert=is_expert,
732
+ tp_comm_buffer_name=tp_comm_buffer_name,
733
+ )
734
+
735
+
736
+ class TERowParallelGroupedLinear(TEGroupedLinear):
737
+ def __init__(
738
+ self,
739
+ num_gemms: int,
740
+ input_size: int,
741
+ output_size: int,
742
+ *,
743
+ config: ModelParallelConfig,
744
+ init_method,
745
+ bias: bool,
746
+ skip_bias_add: bool,
747
+ is_expert: bool,
748
+ tp_comm_buffer_name: str = None,
749
+ ):
750
+
751
+ super().__init__(
752
+ num_gemms=num_gemms,
753
+ input_size=input_size,
754
+ output_size=output_size,
755
+ parallel_mode="row",
756
+ config=config,
757
+ init_method=condition_init_method(config, init_method),
758
+ bias=bias,
759
+ skip_bias_add=skip_bias_add,
760
+ is_expert=is_expert,
761
+ tp_comm_buffer_name=tp_comm_buffer_name,
762
+ )
763
+
764
+
765
+ class SharedExpertMLP(MLP):
766
+ """
767
+ MLP layer for Shared Experts.
768
+ """
769
+
770
+ def __init__(self, config: TransformerConfig, spec: ModuleSpec):
771
+ config = deepcopy(config)
772
+ assert (
773
+ config.add_bias_linear == False
774
+ ), "bias is not supported in the shared experts, "
775
+ "please set '--disable-bias-linear' instead."
776
+
777
+ config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
778
+ super().__init__(config=config, submodules=spec.submodules)
779
+
780
+ self.use_shared_expert_gate = spec.params.get("gate", False)
781
+ if self.use_shared_expert_gate:
782
+ assert False, "use_shared_expert_gate is not Implemented"
783
+ # self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
784
+ # if config.perform_initialization:
785
+ # if get_cuda_rng_tracker().is_initialized():
786
+ # with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
787
+ # config.init_method(self.gate_weight)
788
+ # else:
789
+ # config.init_method(self.gate_weight)
790
+ # self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
791
+ # setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
792
+ else:
793
+ self.gate_weight = None
794
+
795
+
796
+ class TransformerBlock(MemEstimator):
797
+ """Transformer class."""
798
+
799
+ def __init__(
800
+ self,
801
+ config: TransformerConfig,
802
+ spec: Union[TransformerBlockSubmodules, ModuleSpec],
803
+ post_layer_norm: bool = True,
804
+ pre_process: bool = True,
805
+ post_process: bool = True,
806
+ vp_stage: Optional[int] = None,
807
+ ):
808
+ super().__init__()
809
+ self.config = config
810
+
811
+ self.submodules = _get_block_submodules(config, spec, vp_stage)
812
+ self.post_layer_norm = post_layer_norm
813
+ self.pre_process = pre_process
814
+ self.post_process = post_process
815
+ self.vp_stage = vp_stage
816
+ self.cuda_graphs = {}
817
+ self.current_microbatch = -1
818
+ self.input_tensor = None
819
+ self.checkpoint_core_attention = (
820
+ self.config.recompute_granularity == "selective"
821
+ and "core_attn" in self.config.recompute_modules
822
+ )
823
+
824
+ self._build_layers()
825
+ self.num_layers_per_pipeline_rank = len(self.layers)
826
+ self.tp_only_amax_red = config.tp_only_amax_red
827
+
828
+ def _build_layers(self):
829
+ def build_layer(layer_spec, layer_number):
830
+ return build_module(
831
+ layer_spec,
832
+ config=self.config,
833
+ layer_number=layer_number,
834
+ vp_stage=self.vp_stage,
835
+ )
836
+
837
+ # offset is implicit in TransformerLayer
838
+ self.layers = ModuleList(
839
+ [
840
+ build_layer(layer_spec, i + 1)
841
+ for i, layer_spec in enumerate(self.submodules.layer_specs)
842
+ ]
843
+ )
844
+
845
+ if self.submodules.layer_norm and self.post_process and self.post_layer_norm:
846
+ self.final_layernorm = build_module(
847
+ self.submodules.layer_norm,
848
+ config=self.config,
849
+ hidden_size=self.config.hidden_size,
850
+ eps=self.config.layernorm_epsilon,
851
+ )
852
+ else:
853
+ self.final_layernorm = None # Either this or nn.Identity
854
+
855
+ def num_parameter(self):
856
+ ret = self.layers.num_parameter()
857
+ if self.final_layernorm is not None:
858
+ ret += self.final_layernorm.num_parameter()
859
+
860
+ return ret
861
+
862
+ def num_activation(self, input_shape: list[int]):
863
+ result = self.layers.num_activation(input_shape)
864
+ if self.final_layernorm is not None:
865
+ result += self.final_layernorm.num_activation(input_shape)
866
+ return result
867
+
868
+ def mock_forward(self, input_shape: list[int]):
869
+ return input_shape
870
+
871
+
872
+ class TopKRouter(MemEstimator):
873
+
874
+ def __init__(self, config: TransformerConfig) -> None:
875
+ super().__init__()
876
+ self.config = config
877
+ self.topk = self.config.moe_router_topk
878
+ self.routing_type = self.config.moe_router_load_balancing_type
879
+ self.input_jitter = None
880
+
881
+ def num_parameter(self):
882
+ return 0
883
+
884
+ def num_activation(self, input_shape: list[int]):
885
+ result = cum_mul(input_shape) * 2 # sinkhorn and sinkhorn activation
886
+ return result
887
+
888
+ def mock_forward(self, input_shape: list[int]):
889
+ return input_shape[:-1] + [self.topk]
890
+
891
+
892
+ class MoELayer(MemEstimator):
893
+
894
+ def __init__(
895
+ self, config: TransformerConfig, submodules=None, layer_number: int = None
896
+ ):
897
+ super().__init__()
898
+ self.config = config
899
+ self.submodules = submodules
900
+ self.moe_layer_recompute = config.moe_layer_recompute
901
+
902
+ self.expert_parallel_size = get_expert_model_parallel_world_size()
903
+ assert (
904
+ self.expert_parallel_size > 0
905
+ ), "Expected non-negative expert parallel size"
906
+
907
+ assert self.config.num_moe_experts % self.expert_parallel_size == 0
908
+ self.num_local_experts = (
909
+ self.config.num_moe_experts // self.expert_parallel_size
910
+ )
911
+ local_expert_indices_offset = (
912
+ get_expert_model_parallel_rank() * self.num_local_experts
913
+ )
914
+
915
+ self.moe_layer_recompute = (
916
+ config.recompute_granularity == "selective"
917
+ and "moe" in config.recompute_modules
918
+ )
919
+
920
+ self.router = TopKRouter(config=self.config)
921
+ self.use_shared_expert = (
922
+ self.config.moe_shared_expert_intermediate_size is not None
923
+ )
924
+ self.shared_expert_overlap = self.config.moe_shared_expert_overlap
925
+
926
+ self.local_expert_indices = [
927
+ local_expert_indices_offset + i for i in range(self.num_local_experts)
928
+ ]
929
+ assert all(
930
+ map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices)
931
+ )
932
+
933
+ self.experts = None
934
+ self.shared_experts = None
935
+ self.token_dispatcher = None
936
+ self.layer_number = layer_number
937
+ # Initialize experts
938
+ self.experts = build_module(
939
+ self.submodules.experts, self.num_local_experts, self.config
940
+ )
941
+
942
+ # Initialize shared experts
943
+ if self.use_shared_expert:
944
+ self.shared_experts = SharedExpertMLP(
945
+ self.config, self.submodules.shared_experts
946
+ )
947
+ # if self.shared_expert_overlap:
948
+ # self.token_dispatcher.set_shared_experts(self.shared_experts)
949
+
950
+ def num_parameter(self):
951
+ ret = self.experts.num_parameter() + self.router.num_parameter()
952
+ if self.use_shared_expert:
953
+ ret += self.shared_experts.num_parameter()
954
+ return ret
955
+
956
+ def num_activation(self, input_shape: list[int]):
957
+ if self.moe_layer_recompute:
958
+ return 0
959
+ tp_size = get_tensor_model_parallel_world_size()
960
+ etp_size = get_expert_tensor_parallel_world_size()
961
+ new_input_shape = deepcopy(input_shape)
962
+ new_input_shape[1] = input_shape[1] // tp_size * etp_size
963
+ input_shape = new_input_shape
964
+
965
+ result = self.router.num_activation(input_shape)
966
+ result += cum_mul(input_shape) * self.router.topk # token dispatcher
967
+ moe_input_shape_average = deepcopy(input_shape)
968
+ moe_input_shape_average[1] = int(moe_input_shape_average[1] * self.router.topk)
969
+
970
+ result += self.experts.num_activation(moe_input_shape_average)
971
+ if self.use_shared_expert:
972
+ result += self.shared_experts.num_activation(input_shape)
973
+
974
+ if self.config.moe_layer_recompute:
975
+ result = cum_mul(input_shape) * 2
976
+ return result
977
+
978
+ def mock_forward(self, input_shape: list[int]):
979
+ return input_shape
980
+
981
+
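A worked example of the scaling in `MoELayer.num_activation` above (assumed numbers, batch size 1 for simplicity): the token count is first rescaled from the TP view to the expert-TP view, the router and the dispatcher copy are charged on that rescaled count, and the experts are charged for `topk` times as many routed tokens.

    # Illustrative only; h is the hidden size seen by this rank.
    tokens, tp, etp, topk, h = 4096, 2, 1, 8, 4096
    tokens_local  = tokens // tp * etp        # dispatcher view after TP -> ETP rescaling
    router_act    = 2 * tokens_local * h      # router estimate (scores + activation)
    dispatch_act  = topk * tokens_local * h   # token-dispatcher copy
    expert_tokens = topk * tokens_local       # tokens charged to the local experts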
982
+ class IdentityOp(MemEstimator):
983
+ def num_parameter(self):
984
+ return 0
985
+
986
+ def num_activation(self, input_shape: list[int]):
987
+ return 0
988
+
989
+ def mock_forward(self, input_shape: list[int]):
990
+ return input_shape
991
+
992
+
993
+ IdentityFuncOp = IdentityOp
994
+ TERowParallelLinear = RowParallelLinear
995
+ TEColumnParallelLinear = ColumnParallelLinear
996
+ TELayerNormColumnParallelLinear = ColumnParallelLinear
997
+
998
+
999
+ class TEDotProductAttention(MemEstimator):
1000
+ def __init__(self, config: TransformerConfig, *args, **kwargs):
1001
+ super().__init__()
1002
+ self.config = config
1003
+
1004
+ def num_parameter(self):
1005
+ return 0
1006
+
1007
+ def num_activation(
1008
+ self, q_shape: list[int], k_shape: list[int], v_shape: list[int]
1009
+ ):
1010
+ bs, seqs, heads, dim = q_shape
1011
+ if self.config.multi_latent_attention and False:  # MLA-specific estimate disabled; fall through to the flash-attention branch
1012
+ result = bs * seqs * seqs * heads
1013
+ else:
1014
+ bs, seqs, heads, dim = k_shape
1015
+ result = (
1016
+ bs * seqs * dim * heads * 2 # * self.config.tensor_model_parallel_size
1017
+ ) # flash attention
1018
+ if self.config.context_parallel_size > 1:
1019
+ result *= 2
1020
+ return result
1021
+
1022
+ def mock_forward(
1023
+ self,
1024
+ hidden_size: int,
1025
+ q_shape: list[int],
1026
+ k_shape: list[int],
1027
+ v_shape: list[int],
1028
+ ):
1029
+ seqs, bs, heads, dim = q_shape
1030
+ return [seqs, bs, hidden_size]
1031
+
1032
+
1033
+ class TransformerLayer(MemEstimator):
1034
+ def __init__(
1035
+ self,
1036
+ config: TransformerConfig,
1037
+ submodules,
1038
+ layer_number: int = 1,
1039
+ hidden_dropout: float = None,
1040
+ vp_stage: Optional[int] = None,
1041
+ ):
1042
+ super().__init__()
1043
+ self.config = config
1044
+
1045
+ if config.enable_cuda_graph and self.training:
1046
+ assert (
1047
+ not config.cpu_offloading and config.recompute_granularity is None
1048
+ ), "Cudagraphs not supported"
1049
+ self.cudagraph_manager = CudaGraphManager()
1050
+
1051
+ self.submodules_config = submodules
1052
+ self.layer_number = layer_number + get_transformer_layer_offset(
1053
+ self.config, vp_stage
1054
+ )
1055
+ self.hidden_dropout = (
1056
+ config.hidden_dropout if hidden_dropout is None else hidden_dropout
1057
+ )
1058
+
1059
+ # [Module 1: Input Layernorm] Optional Layernorm on the input data
1060
+ # TODO: add pytorch only layernorm
1061
+ self.input_layernorm = build_module(
1062
+ submodules.input_layernorm,
1063
+ config=self.config,
1064
+ hidden_size=self.config.hidden_size,
1065
+ eps=self.config.layernorm_epsilon,
1066
+ )
1067
+
1068
+ # [Module 2: SelfAttention]
1069
+ self.self_attention = build_module(
1070
+ submodules.self_attention, config=self.config, layer_number=layer_number
1071
+ )
1072
+
1073
+ # [Module 3: BiasDropoutFusion]
1074
+ self.self_attn_bda = build_module(submodules.self_attn_bda)
1075
+
1076
+ # [Module 4: Post SelfAttention] Optional Layernorm after self-attn
1077
+ self.pre_cross_attn_layernorm = build_module(
1078
+ submodules.pre_cross_attn_layernorm,
1079
+ config=self.config,
1080
+ hidden_size=self.config.hidden_size,
1081
+ eps=self.config.layernorm_epsilon,
1082
+ )
1083
+
1084
+ # [Module 5: CrossAttention]
1085
+ self.cross_attention = build_module(
1086
+ submodules.cross_attention, config=self.config, layer_number=layer_number
1087
+ )
1088
+
1089
+ # [Module 6: BiasDropoutFusion]
1090
+ self.cross_attn_bda = build_module(
1091
+ submodules.cross_attn_bda, config=self.config
1092
+ )
1093
+
1094
+ # [Module 7: Pre MLP] Optional Layernorm before MLP
1095
+ self.pre_mlp_layernorm = build_module(
1096
+ submodules.pre_mlp_layernorm,
1097
+ config=self.config,
1098
+ hidden_size=self.config.hidden_size,
1099
+ eps=self.config.layernorm_epsilon,
1100
+ )
1101
+
1102
+ # [Module 8: MLP block]
1103
+ self.mlp = build_module(submodules.mlp, config=self.config)
1104
+ if hasattr(self.mlp, "set_layer_number"):
1105
+ self.mlp.set_layer_number(self.layer_number)
1106
+
1107
+ # [Module 9: BiasDropoutFusion]
1108
+ self.mlp_bda = build_module(submodules.mlp_bda)
1109
+
1110
+ self.recompute_input_layernorm = False
1111
+ self.recompute_pre_mlp_layernorm = False
1112
+ self.recompute_mlp = False
1113
+ if self.config.recompute_granularity == "selective":
1114
+ if "layernorm" in self.config.recompute_modules:
1115
+ if not isinstance(self.input_layernorm, IdentityOp):
1116
+ self.recompute_input_layernorm = True
1117
+ if not isinstance(self.pre_mlp_layernorm, IdentityOp):
1118
+ self.recompute_pre_mlp_layernorm = True
1119
+ if "mlp" in self.config.recompute_modules:
1120
+
1121
+ if not isinstance(self.mlp, MoELayer):
1122
+ self.recompute_mlp = True
1123
+
1124
+ def num_parameter(self):
1125
+ result = self.input_layernorm.num_parameter()
1126
+ result += self.self_attention.num_parameter()
1127
+ result += self.pre_cross_attn_layernorm.num_parameter()
1128
+ result += self.cross_attention.num_parameter()
1129
+ result += self.cross_attn_bda.num_parameter()
1130
+ result += self.pre_mlp_layernorm.num_parameter()
1131
+ result += self.mlp.num_parameter()
1132
+
1133
+ return result
1134
+
1135
+ def num_activation(self, input_shape: list[int]):
1136
+ result = 0
1137
+ result += self.self_attention.num_activation(input_shape)
1138
+ if not self.recompute_mlp:
1139
+ result += self.mlp.num_activation(input_shape)
1140
+ # __import__('ipdb').set_trace()
1141
+ # sequence parallel
1142
+ if self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
1143
+ input_shape = deepcopy(input_shape)
1144
+ input_shape[1] //= self.config.tensor_model_parallel_size
1145
+ if not self.recompute_input_layernorm:
1146
+ result += self.input_layernorm.num_activation(input_shape)
1147
+ if not self.recompute_pre_mlp_layernorm:
1148
+ result += self.pre_mlp_layernorm.num_activation(input_shape)
1149
+ result += self.self_attn_bda.num_activation(input_shape)
1150
+ result += self.mlp_bda.num_activation(input_shape)
1151
+ return result
1152
+
1153
+ def mock_forward(self, input_shape: list[int]):
1154
+ return input_shape
1155
+
1156
+
1157
+ class SelfAttention(MemEstimator):
1158
+
1159
+ def __init__(
1160
+ self,
1161
+ config: TransformerConfig,
1162
+ submodules,
1163
+ layer_number: int,
1164
+ attn_mask_type,
1165
+ ):
1166
+ super().__init__()
1167
+
1168
+ self.config = config
1169
+ self.layer_number = layer_number
1170
+ self.attn_mask_type = attn_mask_type
1171
+ self.attention_type = ""
1172
+
1173
+ # For normal attention without groups, num_query_groups == num_attention_heads,
1174
+ # so these two will be the same
1175
+ self.query_projection_size = (
1176
+ self.config.kv_channels * self.config.num_attention_heads
1177
+ )
1178
+ self.kv_projection_size = self.config.kv_channels * self.config.num_query_groups
1179
+
1180
+ # Per attention head and per partition values.
1181
+ world_size = get_tensor_model_parallel_world_size()
1182
+ self.hidden_size_per_attention_head = divide(
1183
+ self.query_projection_size, self.config.num_attention_heads
1184
+ )
1185
+ self.num_attention_heads_per_partition = divide(
1186
+ self.config.num_attention_heads, world_size
1187
+ )
1188
+ self.num_query_groups_per_partition = divide(
1189
+ self.config.num_query_groups, world_size
1190
+ )
1191
+ self.core_attention = build_module(
1192
+ submodules.core_attention,
1193
+ config=self.config,
1194
+ layer_number=self.layer_number,
1195
+ attn_mask_type=self.attn_mask_type,
1196
+ )
1197
+ self.linear_qkv = build_module(
1198
+ submodules.linear_qkv,
1199
+ self.config.hidden_size,
1200
+ self.query_projection_size + 2 * self.kv_projection_size,
1201
+ config=self.config,
1202
+ init_method=self.config.init_method,
1203
+ gather_output=False,
1204
+ bias=self.config.add_bias_linear or self.config.add_qkv_bias,
1205
+ skip_bias_add=False,
1206
+ is_expert=False,
1207
+ tp_comm_buffer_name="qkv",
1208
+ )
1209
+
1210
+ if submodules.q_layernorm is not None:
1211
+ self.q_layernorm = build_module(
1212
+ submodules.q_layernorm,
1213
+ hidden_size=self.hidden_size_per_attention_head,
1214
+ config=self.config,
1215
+ eps=self.config.layernorm_epsilon,
1216
+ )
1217
+ else:
1218
+ self.q_layernorm = None
1219
+
1220
+ if submodules.k_layernorm is not None:
1221
+ self.k_layernorm = build_module(
1222
+ submodules.k_layernorm,
1223
+ hidden_size=self.hidden_size_per_attention_head,
1224
+ config=self.config,
1225
+ eps=self.config.layernorm_epsilon,
1226
+ )
1227
+ else:
1228
+ self.k_layernorm = None
1229
+ self.linear_proj = build_module(
1230
+ submodules.linear_proj,
1231
+ self.query_projection_size,
1232
+ self.config.hidden_size,
1233
+ config=self.config,
1234
+ init_method=self.config.output_layer_init_method,
1235
+ bias=self.config.add_bias_linear,
1236
+ input_is_parallel=True,
1237
+ skip_bias_add=True,
1238
+ is_expert=False,
1239
+ tp_comm_buffer_name="proj",
1240
+ )
1241
+ self.checkpoint_core_attention = (
1242
+ self.config.recompute_granularity == "selective"
1243
+ )
1244
+
1245
+ def num_parameter(self):
1246
+ result = 0
1247
+ result += self.core_attention.num_parameter()
1248
+ result += self.linear_proj.num_parameter()
1249
+ result += self.linear_qkv.num_parameter()
1250
+ if self.q_layernorm is not None:
1251
+ result += self.q_layernorm.num_parameter()
1252
+ if self.k_layernorm is not None:
1253
+ result += self.k_layernorm.num_parameter()
1254
+
1255
+ return result
1256
+
1257
+ def num_activation(self, input_shape: list[int]):
1258
+ ret = 0
1259
+ ## in estimator: act(linear) = 1.5*cum_mul(input_shape)
1260
+ ## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
1261
+ # ret += self.linear_qkv.num_activation(input_shape)
1262
+ mixed_qkv_shape = self.linear_qkv.mock_forward(input_shape)
1263
+ new_tensor_shape = mixed_qkv_shape[:-1] + [
1264
+ self.num_query_groups_per_partition,
1265
+ (
1266
+ (
1267
+ self.num_attention_heads_per_partition
1268
+ // self.num_query_groups_per_partition
1269
+ + 2
1270
+ )
1271
+ * self.hidden_size_per_attention_head
1272
+ ),
1273
+ ]
1274
+ split_arg_list = [
1275
+ (
1276
+ self.num_attention_heads_per_partition
1277
+ // self.num_query_groups_per_partition
1278
+ * self.hidden_size_per_attention_head
1279
+ ),
1280
+ self.hidden_size_per_attention_head,
1281
+ self.hidden_size_per_attention_head,
1282
+ ]
1283
+ # [sq, b, ng, (np/ng + 2) * hn]
1284
+ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
1285
+ q_shape = new_tensor_shape[:-1] + [split_arg_list[0]]
1286
+ k_shape = new_tensor_shape[:-1] + [split_arg_list[1]]
1287
+ v_shape = new_tensor_shape[:-1] + [split_arg_list[2]]
1288
+ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
1289
+ q_shape = (
1290
+ q_shape[:2]
1291
+ + [cum_mul(q_shape[-2:]) // self.hidden_size_per_attention_head]
1292
+ + [self.hidden_size_per_attention_head]
1293
+ )
1294
+
1295
+ if not self.checkpoint_core_attention:
1296
+ ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
1297
+ ret += self.linear_proj.num_activation(input_shape)
1298
+ ## in reality: act(linear) = cum_mul(input_shape), act(rotary) = cum_mul(input_shape), act(attn_forward_func_with_cp) = cum_mul(input_shape)
1299
+ ret += self.linear_proj.num_activation(input_shape) * 3
1300
+
1301
+ return ret
1302
+
1303
+ def mock_forward(self, input_shape: list[int]):
1304
+ return input_shape
1305
+
1306
+
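The fused QKV width used above is `query_projection_size + 2 * kv_projection_size`, and the shape juggling splits it back into per-group `[np/ng*hn, hn, hn]` chunks. A quick numeric check with assumed GQA settings (not taken from the commit):

    # Example only: 32 attention heads, 8 query groups, head dim 128.
    heads, groups, hn = 32, 8, 128
    qkv_width = heads * hn + 2 * groups * hn   # 4096 + 2048 = 6144 features per token
    per_group = (heads // groups + 2) * hn     # 768 per group; 8 groups * 768 = 6144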
1307
+ class Linear(MemEstimator):
1308
+ def __init__(
1309
+ self,
1310
+ in_features: int,
1311
+ out_features: int,
1312
+ bias: bool = True,
1313
+ device=None,
1314
+ dtype=None,
1315
+ ) -> None:
1316
+
1317
+ super().__init__()
1318
+ self.weight = (in_features, out_features)
1319
+
1320
+ def num_parameter(self):
1321
+ return self.weight[0] * self.weight[1]
1322
+
1323
+ def num_activation(self, input_shape: list[int]):
1324
+ return cum_mul(input_shape[:-1]) * self.weight[1]
1325
+
1326
+ def mock_forward(self, input_shape: list[int]):
1327
+ return input_shape[:-1] + [self.weight[1]]
1328
+
1329
+
1330
+ class MLASelfAttention(MemEstimator):
1331
+ """MLA Self-attention layer class
1332
+
1333
+ Self-attention layer takes input with size [s, b, h]
1334
+ and returns output of the same size.
1335
+ """
1336
+
1337
+ def __init__(
1338
+ self,
1339
+ config: MLATransformerConfig,
1340
+ submodules,
1341
+ layer_number: int,
1342
+ attn_mask_type=AttnMaskType.padding,
1343
+ ) -> None:
1344
+
1345
+ super().__init__()
1346
+ self.config = config
1347
+ self.layer_number = layer_number
1348
+ self.attn_mask_type = attn_mask_type
1349
+ self.attention_type = "self"
1350
+ self.world_size = get_tensor_model_parallel_world_size()
1351
+ # assert (
1352
+ # world_size == 1
1353
+ # ), "MLA is not supported with Tensor Parallelism yet, \
1354
+ # use Expert Parallelism and Pipeline Parallelism for better performance."
1355
+
1356
+ self.query_projection_size = (
1357
+ self.config.v_head_dim * self.config.num_attention_heads
1358
+ )
1359
+
1360
+ self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim
1361
+
1362
+ mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale)
1363
+ self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim)
1364
+
1365
+ # Per attention head and per partition values.
1366
+ world_size = get_tensor_model_parallel_world_size()
1367
+ self.hidden_size_per_attention_head = divide(
1368
+ self.query_projection_size, self.config.num_attention_heads
1369
+ )
1370
+ self.num_attention_heads_per_partition = divide(
1371
+ self.config.num_attention_heads, world_size
1372
+ )
1373
+ self.num_query_groups_per_partition = divide(
1374
+ self.config.num_query_groups, world_size
1375
+ )
1376
+ # TODO Rotary Embedding
1377
+ # self.rotary_pos_emb = YarnRotaryEmbedding(
1378
+ # self.config.qk_pos_emb_head_dim,
1379
+ # rotary_base=self.config.rotary_base,
1380
+ # scaling_factor=self.config.rotary_scaling_factor,
1381
+ # original_max_position_embeddings=self.config.max_position_embeddings,
1382
+ # beta_fast=self.config.beta_fast,
1383
+ # beta_slow=self.config.beta_slow,
1384
+ # mscale=self.config.mscale,
1385
+ # mscale_all_dim=self.config.mscale_all_dim,
1386
+ # )
1387
+
1388
+ self.core_attention = build_module(
1389
+ submodules.core_attention,
1390
+ config=self.config,
1391
+ layer_number=self.layer_number,
1392
+ attn_mask_type=self.attn_mask_type,
1393
+ attention_type=self.attention_type,
1394
+ softmax_scale=self.softmax_scale,
1395
+ k_channels=self.q_head_dim,
1396
+ v_channels=self.config.v_head_dim,
1397
+ )
1398
+
1399
+ if self.config.q_lora_rank is None:
1400
+ # Not projecting the query
1401
+ self.linear_q_proj = build_module(
1402
+ submodules.linear_q_proj,
1403
+ self.config.hidden_size,
1404
+ self.config.num_attention_heads * self.q_head_dim,
1405
+ config=self.config,
1406
+ init_method=self.config.init_method,
1407
+ gather_output=False,
1408
+ bias=False,
1409
+ skip_bias_add=False,
1410
+ is_expert=False,
1411
+ is_mla=True,
1412
+ )
1413
+
1414
+ else:
1415
+ self.linear_q_down_proj = Linear(
1416
+ self.config.hidden_size, self.config.q_lora_rank, bias=False
1417
+ )
1418
+
1419
+ self.linear_q_up_proj = build_module(
1420
+ submodules.linear_q_up_proj,
1421
+ self.config.q_lora_rank,
1422
+ self.config.num_attention_heads * self.q_head_dim,
1423
+ config=self.config,
1424
+ init_method=self.config.init_method,
1425
+ gather_output=False,
1426
+ bias=False,
1427
+ skip_bias_add=False,
1428
+ is_expert=False,
1429
+ is_mla=True,
1430
+ )
1431
+ self.linear_kv_down_proj = Linear(
1432
+ self.config.hidden_size,
1433
+ self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
1434
+ bias=False,
1435
+ )
1436
+
1437
+ self.linear_kv_up_proj = build_module(
1438
+ submodules.linear_kv_up_proj,
1439
+ self.config.kv_lora_rank,
1440
+ self.config.num_attention_heads
1441
+ * (self.config.qk_head_dim + self.config.v_head_dim),
1442
+ config=self.config,
1443
+ init_method=self.config.init_method,
1444
+ gather_output=False,
1445
+ bias=False,
1446
+ skip_bias_add=False,
1447
+ is_expert=False,
1448
+ is_mla=True,
1449
+ )
1450
+
1451
+ if self.config.q_lora_rank is not None:
1452
+ self.q_layernorm = build_module(
1453
+ submodules.q_layernorm,
1454
+ hidden_size=self.config.q_lora_rank,
1455
+ config=self.config,
1456
+ eps=self.config.layernorm_epsilon,
1457
+ )
1458
+
1459
+ self.kv_layernorm = build_module(
1460
+ submodules.kv_layernorm,
1461
+ hidden_size=self.config.kv_lora_rank,
1462
+ config=self.config,
1463
+ eps=self.config.layernorm_epsilon,
1464
+ )
1465
+
1466
+ # Output.
1467
+ self.linear_proj = build_module(
1468
+ submodules.linear_proj,
1469
+ self.query_projection_size,
1470
+ self.config.hidden_size,
1471
+ config=self.config,
1472
+ init_method=self.config.output_layer_init_method,
1473
+ bias=self.config.add_bias_linear,
1474
+ input_is_parallel=True,
1475
+ skip_bias_add=True,
1476
+ is_expert=False,
1477
+ tp_comm_buffer_name="proj",
1478
+ )
1479
+
1480
+ self.checkpoint_core_attention = (
1481
+ self.config.recompute_granularity == "selective"
1482
+ )
1483
+
1484
+ def num_parameter(self):
1485
+ result = 0
1486
+ result += self.core_attention.num_parameter()
1487
+ result += self.linear_proj.num_parameter()
1488
+ if self.config.q_lora_rank is None:
1489
+ result += self.linear_q_proj.num_parameter()
1490
+ else:
1491
+ result += self.linear_q_down_proj.num_parameter()
1492
+ result += self.linear_q_up_proj.num_parameter()
1493
+ result += self.linear_kv_down_proj.num_parameter()
1494
+ result += self.linear_kv_up_proj.num_parameter()
1495
+ result += self.kv_layernorm.num_parameter()
1496
+ if self.config.q_lora_rank is not None:
1497
+ result += self.q_layernorm.num_parameter()
1498
+
1499
+ return result
1500
+
1501
+ def num_activation(self, input_shape: list[int]):
1502
+ q_len, bsz, _ = input_shape
1503
+ ret = 0
1504
+ if self.config.q_lora_rank is not None:
1505
+ ret += self.linear_q_down_proj.num_activation(input_shape)
1506
+ q_compressed_shape = self.linear_q_down_proj.mock_forward(input_shape)
1507
+ ret += self.q_layernorm.num_activation(q_compressed_shape)
1508
+ ret += self.linear_q_up_proj.num_activation(q_compressed_shape)
1509
+ q_shape = self.linear_q_up_proj.mock_forward(q_compressed_shape)
1510
+ else:
1511
+ # hidden_states:[s, b, 2048], q: [s, b, n * 192]
1512
+ ret += self.linear_q_proj.num_activation(input_shape)
1513
+ q_shape = self.linear_q_proj.mock_forward(input_shape)
1514
+
1515
+ # kv_combined: [s, b, 576]
1516
+ ret += self.linear_kv_down_proj.num_activation(input_shape)
1517
+ kv_combined_shape = self.linear_kv_down_proj.mock_forward(input_shape)
1518
+ # kv_compressed:[s, b, 512], k_pos_emb: [s, b, 64]
1519
+ kv_compressed_shape = kv_combined_shape[:-1] + [self.config.kv_lora_rank]
1520
+
1521
+ # kv: [s, b, 2048]
1522
+ ret += self.kv_layernorm.num_activation(kv_compressed_shape)
1523
+ ret += self.linear_kv_up_proj.num_activation(kv_compressed_shape)
1524
+
1525
+ q_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
1526
+ k_shape = [q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim]
1527
+ v_shape = [
1528
+ q_len,
1529
+ bsz,
1530
+ self.num_attention_heads_per_partition,
1531
+ self.config.v_head_dim,
1532
+ ]
1533
+
1534
+ if not self.checkpoint_core_attention:
1535
+ ret += self.core_attention.num_activation(q_shape, k_shape, v_shape)
1536
+
1537
+ ret += self.linear_proj.num_activation(input_shape)
1538
+
1539
+ return ret
1540
+
1541
+ def mock_forward(self, input_shape: list[int]):
1542
+ return input_shape
1543
+
1544
+
1545
+ class TENorm:
1546
+ def __new__(cls, config: TransformerConfig, hidden_size: int, eps: float = 1e-5):
1547
+ from megatron.core.extensions.transformer_engine import _get_extra_te_kwargs, te
1548
+
1549
+ if config.normalization == "LayerNorm":
1550
+ # TODO: add a plain LayerNorm estimator; raise explicitly instead of leaving `instance` unbound
1551
+ raise NotImplementedError("LayerNorm estimation is not implemented yet")
1552
+ elif config.normalization == "RMSNorm":
1553
+ assert hasattr(
1554
+ te.pytorch, "RMSNorm"
1555
+ ), "Transformer-Engine >= v0.11 required to use this feature"
1556
+ instance = RMSNorm(
1557
+ hidden_size=hidden_size,
1558
+ eps=eps,
1559
+ sequence_parallel=config.sequence_parallel,
1560
+ zero_centered_gamma=config.layernorm_zero_centered_gamma,
1561
+ **_get_extra_te_kwargs(config),
1562
+ )
1563
+ else:
1564
+ raise Exception("Only LayerNorm and RMSNorm are curently supported")
1565
+
1566
+ return instance
1567
+
1568
+
1569
+ def build_module(
1570
+ spec_or_module: Union[ModuleSpec, type], *args, **kwargs
1571
+ ) -> MemEstimator:
1572
+ """replace module with MemEstimators"""
1573
+ if isinstance(spec_or_module, types.FunctionType):
1574
+ return globals()[spec_or_module.__name__]
1575
+
1576
+ if isinstance(spec_or_module, ModuleSpec) and isinstance(
1577
+ spec_or_module.module, types.FunctionType
1578
+ ):
1579
+ assert False
1580
+ return spec_or_module.module
1581
+
1582
+ if isinstance(spec_or_module, type):
1583
+ module = spec_or_module
1584
+ elif hasattr(spec_or_module, "module") and isinstance(spec_or_module.module, type):
1585
+ module = spec_or_module.module
1586
+ else:
1587
+ module = import_module(spec_or_module.module)
1588
+
1589
+ if isinstance(module, types.FunctionType):
1590
+ assert False
1591
+ return module
1592
+
1593
+ if hasattr(spec_or_module, "submodules") and spec_or_module.submodules is not None:
1594
+ kwargs["submodules"] = spec_or_module.submodules
1595
+
1596
+ try:
1597
+ module = globals()[module.__name__]
1598
+ return module(
1599
+ *args,
1600
+ **spec_or_module.params if hasattr(spec_or_module, "params") else {},
1601
+ **kwargs,
1602
+ )
1603
+ except Exception as e:
1604
+ # import ipdb
1605
+
1606
+ # ipdb.set_trace()
1607
+ # improve the error message since we hide the module name in the line above
1608
+ import sys
1609
+
1610
+ raise type(e)(f"{str(e)} when instantiating {module.__name__}").with_traceback(
1611
+ sys.exc_info()[2]
1612
+ )
1613
+
1614
+
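In other words, `build_module` resolves a Megatron `ModuleSpec` and then swaps in the estimator class of the same name from this module's globals, so the estimators are built with the same specs the real model would use. A hedged illustration (`cfg` is assumed to be a valid `TransformerConfig`; nothing here is part of the commit):

    spec = ModuleSpec(module=TopKRouter)      # a spec naming a router module
    router = build_module(spec, config=cfg)   # -> this file's TopKRouter estimator, not the real layer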
1615
+ from megatron.core.transformer.transformer_block import (
1616
+ BaseTransformerLayer,
1617
+ LayerNormImpl,
1618
+ TransformerBlockSubmodules,
1619
+ )
1620
+
1621
+
1622
+ def _get_block_submodules(
1623
+ config: TransformerConfig,
1624
+ spec: Union[TransformerBlockSubmodules, ModuleSpec],
1625
+ vp_stage: Optional[int] = None,
1626
+ ) -> TransformerBlockSubmodules:
1627
+ """
1628
+ Retrieve or construct TransformerBlockSubmodules based on the provided specification.
1629
+
1630
+ Args:
1631
+ config (TransformerConfig): Configuration object for the transformer model.
1632
+ spec (Union[TransformerBlockSubmodules, ModuleSpec]): Specification for the
1633
+ transformer block submodules. Can be either a TransformerBlockSubmodules
1634
+ instance or a ModuleSpec.
1635
+
1636
+ Returns:
1637
+ TransformerBlockSubmodules: The submodules for the transformer block.
1638
+ """
1639
+
1640
+ # Transformer block submodules.
1641
+ if isinstance(spec, TransformerBlockSubmodules):
1642
+ return spec
1643
+
1644
+ # ModuleSpec here is generally assumed to be for a transformer layer that
1645
+ # is implemented in `transformer_layer.py` or if it subclasses
1646
+ # `BaseTransformerLayer` from the `transformer_layer.py` file.
1647
+ elif isinstance(spec, ModuleSpec):
1648
+ if issubclass(spec.module, TransformerBlock):
1649
+ return spec.submodules
1650
+ elif issubclass(spec.module, BaseTransformerLayer):
1651
+ num_layers = get_num_layers_to_build(config, vp_stage)
1652
+ return TransformerBlockSubmodules(
1653
+ layer_specs=[spec] * num_layers, layer_norm=LayerNormImpl
1654
+ )
1655
+ else:
1656
+ raise Exception(f"specialize for {spec.module.__name__}.")
1657
+ else:
1658
+ raise Exception(f"specialize for {type(spec).__name__}.")
1659
+
1660
+
1661
+ from megatron.core.transformer.transformer_block import get_num_layers_to_build
1662
+
1663
+
1664
+ def ___get_num_layers_to_build(config: TransformerConfig) -> int:
1665
+ """
1666
+ Determine the number of transformer layers to build for the current pipeline stage.
1667
+ Args:
1668
+ config (TransformerConfig): Configuration object containing transformer model parameters.
1669
+
1670
+ Returns:
1671
+ int: The number of layers to be built for the current pipeline stage.
1672
+ """
1673
+ if (
1674
+ config.num_layers_in_first_pipeline_stage is not None
1675
+ or config.num_layers_in_last_pipeline_stage is not None
1676
+ ):
1677
+
1678
+ assert not (
1679
+ config.account_for_embedding_in_pipeline_split
1680
+ or config.account_for_loss_in_pipeline_split
1681
+ ), " \
1682
+ Does not support standalone embedding stage and standalone loss stage with uneven pp"
1683
+ # Number of layers to distribute over rest of pipeline stages
1684
+ layers_to_distribute = config.num_layers
1685
+ # Number of pipeline stages left for distributing transformer layers
1686
+ pipeline_stages_left = get_pipeline_model_parallel_world_size()
1687
+
1688
+ # If the uneven first (last) pipeline stage is enabled, remove the specified number
1689
+ # of layers to calculate the number of layers on each middle pipeline stage.
1690
+ if config.num_layers_in_first_pipeline_stage is not None:
1691
+ layers_to_distribute -= config.num_layers_in_first_pipeline_stage
1692
+ pipeline_stages_left -= 1
1693
+
1694
+ if config.num_layers_in_last_pipeline_stage is not None:
1695
+ layers_to_distribute -= config.num_layers_in_last_pipeline_stage
1696
+ pipeline_stages_left -= 1
1697
+
1698
+ assert (
1699
+ layers_to_distribute % pipeline_stages_left == 0
1700
+ ), "With uneven pipelineing the left over layers must be divisible by left over stages"
1701
+ num_layers_per_pipeline_rank = layers_to_distribute // pipeline_stages_left
1702
+
1703
+ # If the uneven first (last) pipeline stage is enabled, return the specified number
1704
+ # of layers for all virtual pipeline parallel stages within the first (last) pipeline
1705
+ # parallel stage.
1706
+ if (
1707
+ is_pipeline_first_stage(ignore_virtual=True)
1708
+ and config.num_layers_in_first_pipeline_stage is not None
1709
+ ):
1710
+ num_layers_per_pipeline_rank = config.num_layers_in_first_pipeline_stage
1711
+
1712
+ if (
1713
+ is_pipeline_last_stage(ignore_virtual=True)
1714
+ and config.num_layers_in_last_pipeline_stage is not None
1715
+ ):
1716
+ num_layers_per_pipeline_rank = config.num_layers_in_last_pipeline_stage
1717
+ else:
1718
+ # Include the embedding layer and loss layer into pipeline parallelism partition
1719
+ num_layers = config.num_layers
1720
+ if config.account_for_embedding_in_pipeline_split:
1721
+ num_layers += 1
1722
+
1723
+ if config.account_for_loss_in_pipeline_split:
1724
+ num_layers += 1
1725
+
1726
+ assert (
1727
+ num_layers % config.pipeline_model_parallel_size == 0
1728
+ ), "num_layers should be divisible by pipeline_model_parallel_size"
1729
+ num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
1730
+
1731
+ # if get_virtual_pipeline_model_parallel_world_size() is not None:
1732
+ # # Interleaved pipeline parallelism:
1733
+ # # Number of layers in each model chunk is the number of layers in the stage,
1734
+ # # divided by the number of model chunks in a stage.
1735
+ # # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
1736
+ # # layers to stages like (each list is a model chunk):
1737
+ # # Stage 0: [0] [2] [4] [6]
1738
+ # # Stage 1: [1] [3] [5] [7]
1739
+ # # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
1740
+ # # layers to stages like (each list is a model chunk):
1741
+ # # Stage 0: [0, 1] [4, 5]
1742
+ # # Stage 1: [2, 3] [6, 7]
1743
+ # vp_size = get_virtual_pipeline_model_parallel_world_size()
1744
+
1745
+ # assert (
1746
+ # num_layers_per_pipeline_rank % vp_size == 0
1747
+ # ), "num_layers_per_pipeline_rank should be divisible by vp_size"
1748
+ # num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
1749
+
1750
+ # num_layers_to_build = num_layers_per_virtual_rank
1751
+
1752
+ # else:
1753
+ # # Non-interleaved pipeline parallelism:
1754
+ # # Each stage gets a contiguous set of layers.
1755
+ # num_layers_to_build = num_layers_per_pipeline_rank
1756
+ num_layers_to_build = num_layers_per_pipeline_rank
1757
+ # The embedding (or loss) layer cannot function as a standalone transformer layer
1758
+ # Reduce the number of layers to construct by 1 on the first (or last) stage if the
1759
+ # embedding (or loss) layer is included in the pipeline parallelism partition and placement.
1760
+ if is_pipeline_first_stage() and config.account_for_embedding_in_pipeline_split:
1761
+ num_layers_to_build -= 1
1762
+ assert (
1763
+ num_layers_to_build >= 0
1764
+ ), "Not enough layers in the first virtual pipeline stage"
1765
+
1766
+ if is_pipeline_last_stage() and config.account_for_loss_in_pipeline_split:
1767
+ num_layers_to_build -= 1
1768
+ assert (
1769
+ num_layers_to_build >= 0
1770
+ ), "Not enough layers in the last virtual pipeline stage"
1771
+
1772
+ return num_layers_to_build
1773
+
1774
+
1775
+ from megatron.core.transformer.enums import LayerType
1776
+
1777
+
1778
+ def get_transformer_layer_offset(
1779
+ config: TransformerConfig, vp_stage: Optional[int] = None
1780
+ ):
1781
+ """Get the index offset of current pipeline stage, given the level of pipelining."""
1782
+ pipeline_rank = get_pipeline_model_parallel_rank()
1783
+
1784
+ if config.pipeline_model_parallel_size > 1:
1785
+
1786
+ if config.pipeline_model_parallel_layout:
1787
+ offset = config.pipeline_model_parallel_layout.get_layer_offset(
1788
+ layer_type=LayerType.decoder, vp_stage=vp_stage
1789
+ )
1790
+ elif (
1791
+ config.num_layers_in_first_pipeline_stage is not None
1792
+ or config.num_layers_in_last_pipeline_stage is not None
1793
+ ):
1794
+ # Calculate number of pipeline stages to distribute the remaining Transformer
1795
+ # layers after deducting the Transformer layers in the first or the last stages
1796
+ middle_pipeline_stages = config.pipeline_model_parallel_size
1797
+ middle_pipeline_stages -= sum(
1798
+ [
1799
+ 1 if x is not None else 0
1800
+ for x in (
1801
+ config.num_layers_in_first_pipeline_stage,
1802
+ config.num_layers_in_last_pipeline_stage,
1803
+ )
1804
+ ]
1805
+ )
1806
+
1807
+ # Calculate layers to distribute in each pipeline stage. If the
1808
+ # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
1809
+ # are not set, we will not enable uneven pipeline. All layers will be treated
1810
+ # as middle layers.
1811
+ num_layers_in_first_pipeline_stage = (
1812
+ 0
1813
+ if config.num_layers_in_first_pipeline_stage is None
1814
+ else config.num_layers_in_first_pipeline_stage
1815
+ )
1816
+ num_layers_in_last_pipeline_stage = (
1817
+ 0
1818
+ if config.num_layers_in_last_pipeline_stage is None
1819
+ else config.num_layers_in_last_pipeline_stage
1820
+ )
1821
+
1822
+ middle_num_layers = (
1823
+ config.num_layers
1824
+ - num_layers_in_first_pipeline_stage
1825
+ - num_layers_in_last_pipeline_stage
1826
+ )
1827
+
1828
+ if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
1829
+ assert (
1830
+ vp_stage is not None
1831
+ ), "vp_stage must be provided if virtual pipeline model parallel size is set"
1832
+
1833
+ # Calculate number of layers in each virtual model chunk
1834
+ # If the num_layers_in_first_pipeline_stage and
1835
+ # num_layers_in_last_pipeline_stage are not set, all pipeline stages
1836
+ # will be treated as middle pipeline stages in the calculation
1837
+ num_layers_per_virtual_model_chunk_in_first_pipeline_stage = (
1838
+ 0
1839
+ if config.num_layers_in_first_pipeline_stage is None
1840
+ else config.num_layers_in_first_pipeline_stage // vp_size
1841
+ )
1842
+
1843
+ num_layers_per_virtual_model_chunk_in_last_pipeline_stage = (
1844
+ 0
1845
+ if config.num_layers_in_last_pipeline_stage is None
1846
+ else config.num_layers_in_last_pipeline_stage // vp_size
1847
+ )
1848
+
1849
+ num_layers_per_vritual_model_chunk_in_middle_pipeline_stage = (
1850
+ middle_num_layers // vp_size
1851
+ )
1852
+
1853
+ # First stage + middle stage + last stage
1854
+ total_virtual_chunks = (
1855
+ num_layers_per_virtual_model_chunk_in_first_pipeline_stage
1856
+ + num_layers_per_vritual_model_chunk_in_middle_pipeline_stage
1857
+ + num_layers_per_virtual_model_chunk_in_last_pipeline_stage
1858
+ )
1859
+
1860
+ # Calculate the layer offset with interleaved uneven pipeline parallelism
1861
+ if pipeline_rank == 0:
1862
+ offset = vp_stage * total_virtual_chunks
1863
+ else:
1864
+ offset = (
1865
+ vp_stage * total_virtual_chunks
1866
+ + num_layers_per_virtual_model_chunk_in_first_pipeline_stage
1867
+ + (pipeline_rank - 1)
1868
+ * (
1869
+ num_layers_per_vritual_model_chunk_in_middle_pipeline_stage
1870
+ // middle_pipeline_stages
1871
+ )
1872
+ )
1873
+ else:
1874
+ if middle_pipeline_stages > 0:
1875
+ num_layers_per_pipeline_rank = (
1876
+ middle_num_layers // middle_pipeline_stages
1877
+ )
1878
+ else:
1879
+ num_layers_per_pipeline_rank = 0
1880
+
1881
+ middle_pipeline_rank = (
1882
+ pipeline_rank
1883
+ if config.num_layers_in_first_pipeline_stage is None
1884
+ else pipeline_rank - 1
1885
+ )
1886
+
1887
+ if pipeline_rank == 0:
1888
+ offset = 0
1889
+ else:
1890
+ offset = (
1891
+ middle_pipeline_rank * num_layers_per_pipeline_rank
1892
+ ) + num_layers_in_first_pipeline_stage
1893
+ else:
1894
+ num_layers = config.num_layers
1895
+
1896
+ # Increase the number of layers by one if we include the embedding (loss)
1897
+ # layer into pipeline parallelism partition and placement
1898
+ if config.account_for_embedding_in_pipeline_split:
1899
+ num_layers += 1
1900
+
1901
+ if config.account_for_loss_in_pipeline_split:
1902
+ num_layers += 1
1903
+
1904
+ num_layers_per_pipeline_rank = (
1905
+ num_layers // config.pipeline_model_parallel_size
1906
+ )
1907
+
1908
+ if (vp_size := config.virtual_pipeline_model_parallel_size) is not None:
1909
+ assert (
1910
+ vp_stage is not None
1911
+ ), "vp_stage must be provided if virtual pipeline model parallel size is set"
1912
+
1913
+ num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
1914
+ total_virtual_chunks = num_layers // vp_size
1915
+ offset = vp_stage * total_virtual_chunks + (
1916
+ pipeline_rank * num_layers_per_virtual_rank
1917
+ )
1918
+
1919
+ # Reduce the offset of embedding layer from the total layer number
1920
+ if (
1921
+ config.account_for_embedding_in_pipeline_split
1922
+ and not is_pipeline_first_stage(
1923
+ ignore_virtual=False, vp_stage=vp_stage
1924
+ )
1925
+ ):
1926
+ offset -= 1
1927
+ else:
1928
+ offset = pipeline_rank * num_layers_per_pipeline_rank
1929
+
1930
+ # Reduce the offset of embedding layer from the total layer number
1931
+ if (
1932
+ config.account_for_embedding_in_pipeline_split
1933
+ and not is_pipeline_first_stage(
1934
+ ignore_virtual=False, vp_stage=vp_stage
1935
+ )
1936
+ ):
1937
+ offset -= 1
1938
+ else:
1939
+ offset = 0
1940
+ return offset
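For the common even-split case handled by the final else-branches, the offset is just `pipeline_rank * num_layers_per_pipeline_rank`. A small worked example with assumed numbers (layer numbers are 1-based, `layer_number = offset + i + 1`):

    # Assumed even split: 32 layers, pp=4, no virtual pipelining.
    num_layers, pp = 32, 4
    per_rank = num_layers // pp                        # 8 layers per pipeline rank
    offsets = [rank * per_rank for rank in range(pp)]  # [0, 8, 16, 24]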
webui/index.html ADDED
@@ -0,0 +1,208 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Megatron Memory Estimator</title>
7
+ <link rel="stylesheet" href="style.css">
8
+ </head>
9
+ <body>
10
+ <div class="container">
11
+ <h1>Megatron Memory Estimator v0.13</h1>
12
+ <div class="disclaimer-banner">
13
+ Note: This estimator only measures the GPU memory directly managed by PyTorch when running Megatron. It does not include extra consumption from NCCL communication buffers, kernel fusion, overlap optimizations, CUDA Graphs, etc. Please use the "Overhead per GPU" option below to account for these additional costs.
14
+ </div>
15
+
16
+ <div class="main-layout">
17
+ <div class="top-section">
18
+ <div class="config-column">
19
+ <form id="config-form">
20
+ <h2>Configuration</h2>
21
+ <p class="config-hint" style="font-size: 0.9em; color: #666; margin-top: -0.5em; margin-bottom: 1em;">
22
+ For detailed explanations of each parameter, please see the&nbsp;<a href="https://github.com/NVIDIA/Megatron-LM/blob/core_r0.13.0/megatron/training/arguments.py#L2266" target="_blank">Megatron-LM arguments documentation</a>.
23
+ </p>
24
+ <div class="form-group">
25
+ <label for="model-select">Select a Local Config:</label>
26
+ <select id="model-select" name="model">
27
+ <option value="">Loading...</option>
28
+ </select>
29
+ </div>
30
+
31
+ <!-- All settings are now in one block -->
32
+ <div class="form-row">
33
+ <div class="form-group">
34
+ <label for="num-gpus">Total GPUs:</label>
35
+ <input type="number" id="num-gpus" name="num_gpus" value="8" step="8" min="8">
36
+ </div>
37
+ <div class="form-group">
38
+ <label for="mbs">micro batch size:</label>
39
+ <input type="number" id="mbs" name="mbs" value="1" min="1">
40
+ </div>
41
+ <div class="form-group">
42
+ <label for="seq-len">SeqLen:</label>
43
+ <input type="number"id="seq-len" name="seq-len" value="4096" min="1">
44
+ </div>
45
+ </div>
46
+
47
+ <div class="form-group">
48
+ <input type="checkbox" id="use-distributed-optimizer" name="use_distributed_optimizer" checked>
49
+ <label for="use-distributed-optimizer" class="inline-label">Use Distributed Optimizer</label>
50
+ </div>
51
+
52
+ <!-- New: Embedding/Loss pipeline split options -->
53
+ <div class="form-group vpp-dependent" style="display: none;">
54
+ <input type="checkbox" id="account_for_embedding_in_pipeline_split" name="account_for_embedding_in_pipeline_split">
55
+ <label for="account_for_embedding_in_pipeline_split" class="inline-label">Account for Embedding in PP Split</label>
56
+ </div>
57
+ <div class="form-group vpp-dependent" style="display: none;">
58
+ <input type="checkbox" id="account_for_loss_in_pipeline_split" name="account_for_loss_in_pipeline_split">
59
+ <label for="account_for_loss_in_pipeline_split" class="inline-label">Account for Loss in PP Split</label>
60
+ </div>
61
+ <!-- End of options -->
62
+
63
+ <div class="form-row">
64
+ <div class="form-group">
65
+ <label for="recompute-granularity">Recomputation:</label>
66
+ <select id="recompute-granularity" name="recompute_granularity">
67
+ <option value="none">None</option>
68
+ <option value="selective">Selective</option>
69
+ <option value="full">Full</option>
70
+ </select>
71
+ </div>
72
+ <div class="form-group recompute-options" style="display: none;">
73
+ <label for="recompute-method">Method:</label>
74
+ <select id="recompute-method" name="recompute_method">
75
+ <option value="uniform">Uniform</option>
76
+ <option value="block">Block</option>
77
+ </select>
78
+ </div>
79
+ <div class="form-group recompute-options" style="display: none;">
80
+ <label for="recompute-num-layers">Layers:</label>
81
+ <input type="number" id="recompute-num-layers" name="recompute_num_layers" value="1" min="1">
82
+ </div>
83
+ </div>
84
+
85
+ <!-- New: selective recompute module selection -->
86
+ <div class="form-row selective-options" style="display: none;">
87
+ <div class="form-group">
88
+ <label><input type="checkbox" name="recompute_modules" value="core_attn"> core_attn</label>
89
+ </div>
90
+ <div class="form-group">
91
+ <label><input type="checkbox" name="recompute_modules" value="moe_act"> moe_act</label>
92
+ </div>
93
+ <div class="form-group">
94
+ <label><input type="checkbox" name="recompute_modules" value="layernorm"> layernorm</label>
95
+ </div>
96
+ <div class="form-group">
97
+ <label><input type="checkbox" name="recompute_modules" value="mla_up_proj"> mla_up_proj</label>
98
+ </div>
99
+ <div class="form-group">
100
+ <label><input type="checkbox" name="recompute_modules" value="mlp"> mlp</label>
101
+ </div>
102
+ <div class="form-group">
103
+ <label><input type="checkbox" name="recompute_modules" value="moe"> moe</label>
104
+ </div>
105
+ </div>
106
+ <!-- End of selective recompute -->
107
+
108
+ <div class="form-row">
109
+ <div class="form-group">
110
+ <label for="tp">TP:</label>
111
+ <select id="tp" name="tp"></select>
112
+ </div>
113
+ <div class="form-group">
114
+ <label for="pp">PP:</label>
115
+ <input type="number" id="pp" name="pp" value="1" min="1">
116
+ </div>
117
+ <div class="form-group">
118
+ <label for="ep">EP:</label>
119
+ <select id="ep" name="ep"></select>
120
+ </div>
121
+ <div class="form-group">
122
+ <label for="cp">CP:</label>
123
+ <select id="cp" name="cp"></select>
124
+ </div>
125
+ </div>
126
+ <div class="form-row">
127
+ <div class="form-group">
128
+ <label for="vpp">VPP:</label>
129
+ <input type="number" id="vpp" name="vpp" placeholder="None" min="1">
130
+ </div>
131
+ <div class="form-group">
132
+ <label for="etp">ETP:</label>
133
+ <input type="number" id="etp" name="etp" placeholder="None" min="1">
134
+ </div>
135
+ </div>
136
+ <div class="form-row">
137
+ <div class="form-group">
138
+ <label for="num_layers_in_first_pipeline_stage">First Stage Layers:</label>
139
+ <input type="number" id="num_layers_in_first_pipeline_stage" name="num_layers_in_first_pipeline_stage" placeholder="None" min="0">
140
+ </div>
141
+ <div class="form-group">
142
+ <label for="num_layers_in_last_pipeline_stage">Last Stage Layers:</label>
143
+ <input type="number" id="num_layers_in_last_pipeline_stage" name="num_layers_in_last_pipeline_stage" placeholder="None" min="0">
144
+ </div>
145
+ </div>
146
+ <div class="form-row">
147
+ <div class="form-group">
148
+ <label for="overhead">Overhead per GPU:</label>
149
+ <select id="overhead" name="overhead">
150
+ <option value="5">5GB</option>
151
+ <option value="10" selected>10GB</option>
152
+ </select>
153
+ </div>
154
+ </div>
155
+ <!-- Pipeline Layout Row Added -->
156
+ <div class="form-row">
157
+ <div class="form-group" style="width: 100%;">
158
+ <label for="pipeline-layout">Pipeline Layout (comma-separated layers per stage):</label>
159
+ <input type="text" id="pipeline-layout" name="pipeline_model_parallel_layout" placeholder="e.g., Et|(tt|)*30L">
160
+ </div>
161
+ </div>
162
+ <!-- End Pipeline Layout Row -->
163
+
164
+ <div id="validation-message" class="error-message" style="display: none;"></div>
165
+ <div class="button-container">
166
+ <button type="submit">Estimate</button>
167
+ </div>
168
+ </form>
169
+ </div>
170
+
171
+ <div class="output-column">
172
+ <div class="config-editor-wrapper">
173
+ <h2>Model Config (Editable)</h2>
174
+ <textarea id="config-editor" rows="20"></textarea>
175
+ </div>
176
+ </div>
177
+ </div>
178
+
179
+ <div class="bottom-section">
180
+ <div id="output-container">
181
+ <div id="loading" style="display: none;">Calculating...</div>
182
+ <div id="history-wrapper">
183
+ <h3>History</h3>
184
+ <table id="history-table">
185
+ <thead>
186
+ <tr>
187
+ <th>Model</th>
188
+ <th>Weights + Gradients + Optimizer (GB)</th>
189
+ <th>Activation (GB)</th>
190
+ <th>Total (GB/GPU)</th>
191
+ <th>Actions</th>
192
+ </tr>
193
+ </thead>
194
+ <tbody>
195
+ </tbody>
196
+ </table>
197
+ <button id="clear-history" style="margin-top: 1em;">Clear History</button>
198
+ </div>
199
+ </div>
200
+ </div>
201
+ </div>
202
+ </div>
203
+ <script src="script.js"></script>
204
+ <footer class="footer">
205
+ <p>&copy; 2025 <a href="https://github.com/ISEEKYAN" target="_blank">ISEEKYAN</a>. Developed at NVIDIA.</p>
206
+ </footer>
207
+ </body>
208
+ </html>
webui/main.py ADDED
@@ -0,0 +1,255 @@
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import os
5
+ import tempfile
6
+ from typing import Optional
7
+
8
+ import requests
9
+ from estimate_013 import estimate_from_config
10
+ from fastapi import Body, FastAPI
11
+ from fastapi.responses import FileResponse
12
+ from fastapi.staticfiles import StaticFiles
13
+ from megatron.core import parallel_state as mpu
14
+ from pydantic import BaseModel, field_validator
15
+
16
+ from mbridge import AutoBridge
17
+
18
+ # The directory of the current script (main.py)
19
+ WEBUI_DIR = os.path.dirname(os.path.abspath(__file__))
20
+
21
+ app = FastAPI()
22
+
23
+ # Mount static files from the webui directory
24
+ app.mount("/static", StaticFiles(directory=WEBUI_DIR), name="static")
25
+
26
+
27
+ @app.get("/")
28
+ async def read_index():
29
+ return FileResponse(os.path.join(WEBUI_DIR, "index.html"))
30
+
31
+
32
+ @app.get("/style.css")
33
+ async def read_css():
34
+ return FileResponse(os.path.join(WEBUI_DIR, "style.css"))
35
+
36
+
37
+ @app.get("/script.js")
38
+ async def read_js():
39
+ return FileResponse(os.path.join(WEBUI_DIR, "script.js"))
40
+
41
+
42
+ SUPPORTED_MODELS = [
43
+ "Qwen/Qwen3-235B-A22B",
44
+ "Qwen/Qwen3-30B-A3B",
45
+ "Qwen/Qwen3-32B",
46
+ "Qwen/Qwen3-14B",
47
+ "Qwen/Qwen3-8B",
48
+ "Qwen/Qwen2.5-7B",
49
+ "Qwen/Qwen2.5-14B",
50
+ "Qwen/Qwen2.5-32B",
51
+ "Qwen/Qwen2.5-72B",
52
+ "moonshotai/Moonlight-16B-A3B",
53
+ "moonshotai/Kimi-K2-Instruct",
54
+ "deepseek-ai/DeepSeek-V3",
55
+ "XiaomiMiMo/MiMo-7B-RL",
56
+ ]
57
+
58
+
59
+ @app.get("/local-hf-configs")
60
+ async def get_supported_models():
61
+ """Return the list of HF model identifiers supported by the UI."""
62
+ return SUPPORTED_MODELS
63
+
64
+
65
+ @app.get("/get-megatron-config/{model_path:path}")
66
+ async def get_remote_hf_config(model_path: str):
67
+ """Fetch the HuggingFace config.json for the given model id."""
68
+ url = f"https://huggingface.co/{model_path}/raw/main/config.json"
69
+ try:
70
+ resp = requests.get(url, timeout=10)
71
+ resp.raise_for_status()
72
+ return resp.json()
73
+ except Exception as e:
74
+ return {"error": f"Failed to fetch config from {url}: {str(e)}"}
75
+
76
+
77
+ class MBridgeEstimateConfig(BaseModel):
78
+ hf_model_path: str
79
+ custom_hf_config: Optional[dict] = None # Renamed for clarity
80
+
81
+ # Hardware & Training
82
+ num_gpus: int = 8
83
+ mbs: int = 1
84
+ seq_len: int = 4096
85
+ use_distributed_optimizer: bool = True
86
+ # Recompute settings are now part of the main config
87
+ recompute_granularity: str = "selective"
88
+ recompute_method: str = "uniform"
89
+ recompute_num_layers: Optional[int] = 1
90
+
91
+ # Selective recompute modules (optional list only used when granularity==selective)
92
+ recompute_modules: Optional[list[str]] = None
93
+
94
+ # New: embedding/loss PP split options
95
+ account_for_embedding_in_pipeline_split: bool = False
96
+ account_for_loss_in_pipeline_split: bool = False
97
+
98
+ # Parallelism
99
+ tp: int = 1
100
+ pp: int = 1
101
+ ep: int = 1
102
+ cp: int = 1
103
+ vpp: Optional[int] = None
104
+ etp: Optional[int] = None
105
+
106
+ # Pipeline stage layer counts
107
+ num_layers_in_first_pipeline_stage: Optional[int] = None
108
+ num_layers_in_last_pipeline_stage: Optional[int] = None
109
+
110
+ # New field: custom pipeline-model-parallel layout
111
+ pipeline_model_parallel_layout: Optional[str] = None # Layout string, e.g. "Et|(tt|)*30L"
112
+
113
+ @field_validator("num_gpus")
114
+ def num_gpus_must_be_multiple_of_8(cls, v):
115
+ if v <= 0 or v % 8 != 0:
116
+ raise ValueError("must be a positive multiple of 8")
117
+ return v
118
+
119
+
120
+ def patch_parallel_states(config: MBridgeEstimateConfig):
121
+ from mbridge.core.parallel_states import ParallelStates
122
+
123
+ ParallelStates.get_default_parallel_states = lambda: ParallelStates(
124
+ tp_size=config.tp,
125
+ pp_size=config.pp,
126
+ ep_size=config.ep,
127
+ cp_size=config.cp,
128
+ vpp_size=config.vpp,
129
+ etp_size=config.etp,
130
+ )
131
+
132
+
133
+ @app.post("/estimate_with_mbridge")
134
+ async def estimate_with_mbridge(config: MBridgeEstimateConfig):
135
+ # Validate Inputs
136
+ if config.num_gpus <= 0 or config.num_gpus % 8 != 0:
137
+ return {"error": "Total number of GPUs must be a positive multiple of 8."}
138
+
139
+ parallel_product = config.tp * config.pp * config.cp
140
+ if parallel_product == 0: # Avoid division by zero
141
+ return {"error": "Parallelism dimensions (TP, PP, CP) cannot be zero."}
142
+
143
+ if config.num_gpus % parallel_product != 0:
144
+ return {
145
+ "error": f"Number of GPUs ({config.num_gpus}) must be divisible by the product of TP*PP*CP ({parallel_product})."
146
+ }
147
+
148
+ patch_parallel_states(config)
149
+
150
+ # If the path is just a filename, assume it's in our local model-configs dir
151
+ hf_model_path = config.hf_model_path
152
+ # The custom config from the UI is an HF config, not a Megatron config.
153
+ # We need to load it via a temporary file.
154
+ if config.custom_hf_config:
155
+ try:
156
+ # Create a temporary file to save the custom HF config
157
+ with tempfile.NamedTemporaryFile(
158
+ mode="w+",
159
+ delete=False,
160
+ suffix=".json",
161
+ dir=os.path.join("/dev/shm"),
162
+ ) as tmp:
163
+ json.dump(config.custom_hf_config, tmp)
164
+ tmp_path = tmp.name
165
+
166
+ # Load the bridge from the temporary config file
167
+ from transformers import AutoConfig
168
+
169
+ AutoConfig.trust_remote_code = True
170
+ bridge = AutoBridge.from_pretrained(tmp_path)
171
+ tf_config = bridge.config
172
+ hf_config = bridge.hf_config
173
+
174
+ finally:
175
+ # Ensure the temporary file is deleted
176
+ if "tmp_path" in locals() and os.path.exists(tmp_path):
177
+ os.remove(tmp_path)
178
+ else:
179
+ # If no custom config, load from the original path
180
+ if not os.path.isabs(hf_model_path) and not hf_model_path.startswith(
181
+ ("http", "./", "../")
182
+ ):
183
+ hf_model_path = os.path.join("/dev/shm", hf_model_path)
184
+ bridge = AutoBridge.from_pretrained(hf_model_path)
185
+ tf_config = bridge.config
186
+ hf_config = bridge.hf_config
187
+
188
+ # --- Configuration Unification ---
189
+ # Update the tf_config with values from the form. This makes tf_config the single source of truth.
190
+ tf_config.tensor_model_parallel_size = config.tp
191
+ tf_config.pipeline_model_parallel_size = config.pp
192
+ tf_config.expert_model_parallel_size = config.ep
193
+ tf_config.context_parallel_size = config.cp
194
+ tf_config.recompute_granularity = config.recompute_granularity
195
+ tf_config.recompute_method = config.recompute_method
196
+ tf_config.recompute_num_layers = config.recompute_num_layers
197
+ # New: module list used only when recompute granularity is "selective"
198
+ tf_config.recompute_modules = config.recompute_modules if config.recompute_modules is not None else []
199
+ # New: embedding/loss pipeline-split accounting
200
+ tf_config.account_for_embedding_in_pipeline_split = config.account_for_embedding_in_pipeline_split
201
+ tf_config.account_for_loss_in_pipeline_split = config.account_for_loss_in_pipeline_split
202
+ tf_config.num_layers_per_virtual_pipeline_stage = (
203
+ config.vpp if config.vpp and config.vpp > 1 else None
204
+ )
205
+
206
+ if config.num_layers_in_first_pipeline_stage is not None:
207
+ tf_config.num_layers_in_first_pipeline_stage = (
208
+ config.num_layers_in_first_pipeline_stage
209
+ )
210
+ if config.num_layers_in_last_pipeline_stage is not None:
211
+ tf_config.num_layers_in_last_pipeline_stage = (
212
+ config.num_layers_in_last_pipeline_stage
213
+ )
214
+
215
+ # Handle custom pipeline layout if provided
216
+ if config.pipeline_model_parallel_layout:
217
+ from megatron.core.transformer.pipeline_parallel_layer_layout import (
218
+ PipelineParallelLayerLayout,
219
+ )
220
+
221
+ tf_config.pipeline_model_parallel_layout = PipelineParallelLayerLayout(
222
+ config.pipeline_model_parallel_layout, config.pp
223
+ )
224
+ # print(tf_config)
225
+
226
+ # Create a minimal 'args' object with parameters not present in TransformerConfig
227
+ args = argparse.Namespace()
228
+ args.micro_batch_size = config.mbs
229
+ args.seq_length = config.seq_len
230
+ args.use_distributed_optimizer = config.use_distributed_optimizer
231
+ args.data_parallel_size = config.num_gpus // parallel_product
232
+ args.expert_tensor_parallel_size = config.etp if config.etp else 1
233
+
234
+ # These are required by the estimator but can be derived or defaulted
235
+ args.transformer_impl = "transformer_engine"
236
+ args.fp8 = False
237
+ args.num_experts = getattr(tf_config, "num_moe_experts", 1) # Needed for layer spec
238
+ args.moe_grouped_gemm = True # Default
239
+ args.qk_layernorm = tf_config.qk_layernorm
240
+ args.multi_latent_attention = "deepseek" in getattr(hf_config, "model_type", "")
241
+ args.padded_vocab_size = getattr(hf_config, "vocab_size")
242
+ args.max_position_embeddings = getattr(hf_config, "max_position_embeddings")
243
+ args.tie_word_embeddings = getattr(hf_config, "tie_word_embeddings", False)
244
+ args.world_size = config.num_gpus
245
+
246
+ # This function now returns (aggregated_pp_reports, raw_chunk_reports)
247
+ aggregated_reports, raw_chunk_reports = estimate_from_config(tf_config, args)
248
+
249
+ processed_reports = []
250
+ for rpt in aggregated_reports:
251
+ p = rpt.copy()
252
+ p.pop("details", None)
253
+ processed_reports.append(p)
254
+
255
+ return {"processed_report": processed_reports, "raw_report": raw_chunk_reports}
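For reference, a minimal client sketch for the /estimate_with_mbridge endpoint defined above. The field names mirror MBridgeEstimateConfig; the host, port, and parallelism values are assumptions for illustration, and omitted fields fall back to the model's defaults.

import requests

# Hypothetical local deployment; adjust the URL to wherever the service runs.
payload = {
    "hf_model_path": "Qwen/Qwen3-30B-A3B",
    "num_gpus": 16,                      # must be a positive multiple of 8
    "mbs": 1,
    "seq_len": 4096,
    "tp": 2,
    "pp": 2,
    "ep": 4,
    "cp": 1,
    "recompute_granularity": "selective",
}
resp = requests.post("http://localhost:8000/estimate_with_mbridge", json=payload, timeout=120)
resp.raise_for_status()
result = resp.json()
print(result["processed_report"])        # aggregated per-PP-rank summary
print(len(result["raw_report"]))         # raw per-chunk reports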
webui/requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ fastapi
2
+ uvicorn[standard]
3
+ mbridge
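The list above only covers the web layer; webui/main.py additionally imports requests, the Megatron/transformers stack, and the estimator module, so those must already be importable in the serving environment. A minimal sketch for serving the app locally under that assumption; the script name and port are arbitrary choices:

# run_local.py -- hypothetical helper, not part of the repository.
import uvicorn

if __name__ == "__main__":
    # webui.main exposes the FastAPI instance as `app` (see webui/main.py above).
    uvicorn.run("webui.main:app", host="0.0.0.0", port=8000)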
webui/script.js ADDED
@@ -0,0 +1,787 @@
1
+ document.addEventListener('DOMContentLoaded', () => {
2
+ // Initial UI setup
3
+ loadLocalConfigs();
4
+ updateHistoryView();
5
+ setupEventListeners();
6
+ updateParallelismOptions();
7
+ validateParallelismLive();
8
+ toggleEpBasedOnConfig(); // Disable EP initially
9
+ toggleVppDependentOptions(); // Initialize visibility of the VPP-dependent checkboxes
10
+ });
11
+
12
+ // Utility: convert ANSI color codes (red 31, green 32) to HTML spans for display
13
+ function ansiToHtml(str) {
14
+ if (!str) return '';
15
+ // Replace known ANSI codes
16
+ return str
17
+ .replace(/\u001b\[31m/g, '<span class="ansi-red">')
18
+ .replace(/\u001b\[32m/g, '<span class="ansi-green">')
19
+ .replace(/\u001b\[33m/g, '<span class="ansi-yellow">')
20
+ .replace(/\u001b\[34m/g, '<span class="ansi-blue">')
21
+ .replace(/\u001b\[35m/g, '<span class="ansi-magenta">')
22
+ .replace(/\u001b\[36m/g, '<span class="ansi-cyan">')
23
+ .replace(/\u001b\[0m/g, '</span>');
24
+ }
25
+
26
+ function setupEventListeners() {
27
+ document.getElementById('config-form').addEventListener('submit', (e) => {
28
+ e.preventDefault();
29
+ submitForm();
30
+ });
31
+
32
+ document.getElementById('model-select').addEventListener('change', loadSelectedModelConfig);
33
+
34
+ document.getElementById('recompute-granularity').addEventListener('change', (e) => {
35
+ const recomputeOptions = document.querySelectorAll('.recompute-options');
36
+ recomputeOptions.forEach(opt => {
37
+ opt.style.display = e.target.value === 'full' ? 'block' : 'none';
38
+ });
39
+
40
+ // New: show the module checkboxes when granularity is "selective"
41
+ const selectiveOptions = document.querySelectorAll('.selective-options');
42
+ selectiveOptions.forEach(opt => {
43
+ opt.style.display = e.target.value === 'selective' ? 'block' : 'none';
44
+ });
45
+ });
46
+
47
+ const liveValidationInputs = ['num-gpus', 'tp', 'pp', 'ep', 'cp', 'etp', 'vpp', 'config-editor', 'pipeline-layout'];
48
+ liveValidationInputs.forEach(id => {
49
+ const input = document.getElementById(id);
50
+ if(input) {
51
+ input.addEventListener('change', validateParallelismLive);
52
+ if (id === 'num-gpus') {
53
+ input.addEventListener('change', updateParallelismOptions);
54
+ }
55
+ if (id === 'vpp') {
56
+ input.addEventListener('change', toggleVppDependentOptions);
57
+ }
58
+ }
59
+ });
60
+
61
+ document.getElementById('config-editor').addEventListener('input', toggleEpBasedOnConfig);
62
+ document.getElementById('history-table').addEventListener('click', handleHistoryAction);
63
+ document.getElementById('clear-history').addEventListener('click', clearHistory);
64
+ }
65
+
66
+
67
+ async function loadLocalConfigs() {
68
+ const modelSelect = document.getElementById('model-select');
69
+ const defaultConfigName = 'Qwen/Qwen3-235B-A22B'; // Updated default model
70
+
71
+ try {
72
+ const response = await fetch('/local-hf-configs');
73
+ const configs = await response.json();
74
+
75
+ modelSelect.innerHTML = '<option value="">Select a model...</option>';
76
+ // Add custom option to allow user supplied configs
77
+ modelSelect.innerHTML += '<option value="__custom__">Custom (paste JSON below)...</option>';
78
+ configs.forEach(config => {
79
+ modelSelect.innerHTML += `<option value="${config}">${config}</option>`;
80
+ });
81
+
82
+ // Check if the default config exists and select it
83
+ if (configs.includes(defaultConfigName)) {
84
+ modelSelect.value = defaultConfigName;
85
+ // Await the loading of the model config to ensure it's ready
86
+ await loadSelectedModelConfig();
87
+ }
88
+
89
+ } catch (error) {
90
+ modelSelect.innerHTML = '<option value="">Error loading configs</option>';
91
+ console.error('Error loading local configs:', error);
92
+ }
93
+ }
94
+
95
+ async function loadSelectedModelConfig() {
96
+ const modelSelect = document.getElementById('model-select');
97
+ const editor = document.getElementById('config-editor');
98
+ const selectedConfig = modelSelect.value;
99
+ const messageDiv = document.getElementById('validation-message'); // move early for use in all branches
100
+ let configData = null; // declare for wider scope
101
+
102
+ if (!selectedConfig) {
103
+ editor.value = '';
104
+ toggleEpBasedOnConfig();
105
+ if (messageDiv) messageDiv.style.display = 'none';
106
+ return;
107
+ } else if (selectedConfig === '__custom__') {
108
+ // Custom config: do not fetch, user must paste JSON
109
+ editor.value = '';
110
+ toggleEpBasedOnConfig();
111
+ if (messageDiv) messageDiv.style.display = 'none';
112
+ return;
113
+ }
114
+
115
+ // Prefer fetching the config file directly from the HuggingFace repository
116
+ const hfUrl = `https://huggingface.co/${selectedConfig}/raw/main/config.json`;
117
+ try {
118
+ const resp = await fetch(hfUrl, { mode: 'cors' });
119
+ if (resp.ok) {
120
+ configData = await resp.json();
121
+ editor.value = JSON.stringify(configData, null, 2);
122
+ } else {
123
+ throw new Error(`HF returned status ${resp.status}`);
124
+ }
125
+ } catch (hfErr) {
126
+ console.warn('Direct HF fetch failed, fallback to backend:', hfErr);
127
+ // Fall back to the backend endpoint (for local deployments without CORS, or private models)
128
+ try {
129
+ const response = await fetch(`/get-megatron-config/${encodeURIComponent(selectedConfig)}`);
130
+ configData = await response.json();
131
+ if (configData.error) {
132
+ editor.value = `Error: ${configData.error}`;
133
+ } else {
134
+ editor.value = JSON.stringify(configData, null, 2);
135
+ }
136
+ } catch (beErr) {
137
+ editor.value = 'Failed to fetch model configuration.';
138
+ console.error('Backend config fetch error:', beErr);
139
+ }
140
+ }
141
+
142
+ // Trigger validation and UI updates after loading new config
143
+ validateParallelismLive();
144
+ toggleEpBasedOnConfig();
145
+
146
+ // Show Kimi-K2-Instruct warning if needed
147
+ if (selectedConfig.includes('Kimi-K2-Instruct') && configData && configData.model_type !== 'deepseek_v3') {
148
+ messageDiv.textContent = 'Notice: For Kimi-K2-Instruct the config field "model_type" must be set to "deepseek_v3" before memory estimation.';
149
+ messageDiv.style.display = 'block';
150
+ } else if (messageDiv) {
151
+ messageDiv.style.display = 'none';
152
+ }
153
+ }
154
+
155
+
156
+ function getFormValues(isSubmission = false) {
157
+ const form = document.getElementById('config-form');
158
+ const formData = new FormData(form);
159
+ const modelSelect = document.getElementById('model-select');
160
+
161
+ const hfPath = modelSelect.value;
162
+ if (!hfPath) {
163
+ // We will now handle this case in the submitForm function instead of an alert.
164
+ return null;
165
+ }
166
+
167
+ const editor = document.getElementById('config-editor');
168
+ let customConfig = null;
169
+ try {
170
+ // Only parse if the editor has content
171
+ if (editor.value) {
172
+ customConfig = JSON.parse(editor.value);
173
+ }
174
+ } catch (e) {
175
+ // Only alert on final submission, not on live validation
176
+ if (isSubmission) {
177
+ // alert('Model Config is not valid JSON.'); // Removing alert
178
+ }
179
+ return null; // Return null if JSON is invalid
180
+ }
181
+
182
+ const vppInput = formData.get('vpp');
183
+ const etpInput = formData.get('etp');
184
+ const pipelineLayoutInput = formData.get('pipeline_model_parallel_layout');
185
+
186
+ // New: collect the modules the user selected in selective mode
187
+ const recomputeModules = formData.getAll('recompute_modules');
188
+
189
+ return {
190
+ hf_model_path: hfPath,
191
+ custom_hf_config: customConfig, // Renamed for clarity
192
+ num_gpus: parseInt(formData.get('num_gpus')),
193
+ mbs: parseInt(formData.get('mbs')),
194
+ seq_len: parseInt(formData.get('seq-len')),
195
+ use_distributed_optimizer: document.getElementById('use-distributed-optimizer').checked,
196
+ recompute_granularity: formData.get('recompute_granularity'),
197
+ recompute_method: formData.get('recompute_method'),
198
+ recompute_num_layers: parseInt(formData.get('recompute_num_layers')),
199
+ // New field
200
+ recompute_modules: recomputeModules,
201
+ tp: parseInt(formData.get('tp')),
202
+ pp: parseInt(formData.get('pp')),
203
+ ep: parseInt(formData.get('ep')) || 1, // Default to 1 if disabled/null
204
+ cp: parseInt(formData.get('cp')),
205
+ vpp: vppInput ? parseInt(vppInput) : null,
206
+ etp: etpInput ? parseInt(etpInput) : null,
207
+ num_layers_in_first_pipeline_stage: formData.get('num_layers_in_first_pipeline_stage') ? parseInt(formData.get('num_layers_in_first_pipeline_stage')) : null,
208
+ num_layers_in_last_pipeline_stage: formData.get('num_layers_in_last_pipeline_stage') ? parseInt(formData.get('num_layers_in_last_pipeline_stage')) : null,
209
+ pipeline_model_parallel_layout: pipelineLayoutInput ? pipelineLayoutInput.trim() : null,
210
+ overhead: parseInt(formData.get('overhead')),
211
+ // New: embedding/loss pipeline-split flags
212
+ account_for_embedding_in_pipeline_split: document.getElementById('account_for_embedding_in_pipeline_split').checked,
213
+ account_for_loss_in_pipeline_split: document.getElementById('account_for_loss_in_pipeline_split').checked,
214
+ };
215
+ }
216
+
217
+ async function submitForm() {
218
+ const messageDiv = document.getElementById('validation-message');
219
+ messageDiv.textContent = '';
220
+ messageDiv.style.display = 'none';
221
+
222
+ // Get all form values first. We use getFormValues(false) to avoid any legacy alerts
223
+ // and handle all validation directly within this function for clarity.
224
+ const formValues = getFormValues(false);
225
+
226
+ // === START SUBMISSION VALIDATION ===
227
+
228
+ // 1. Check if form values could be retrieved. This catches both missing model selection
229
+ // and invalid JSON, as getFormValues returns null in those cases.
230
+ if (!formValues) {
231
+ if (!document.getElementById('model-select').value) {
232
+ messageDiv.textContent = 'Validation Error: Please select a model config.';
233
+ } else {
234
+ messageDiv.textContent = 'Validation Error: Model Config is not valid JSON.';
235
+ }
236
+ messageDiv.style.display = 'block';
237
+ return;
238
+ }
239
+
240
+ // Custom config must have valid JSON
241
+ if (document.getElementById('model-select').value === '__custom__' && !formValues.custom_hf_config) {
242
+ messageDiv.textContent = 'Validation Error: Please paste a valid model configuration JSON for the custom model.';
243
+ messageDiv.style.display = 'block';
244
+ return;
245
+ }
246
+
247
+ // 2. Perform all numeric and parallelism validation.
248
+ const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
249
+ const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
250
+
251
+ let errors = [];
252
+ if (tp * pp * cp > num_gpus) {
253
+ errors.push(`TP*PP*CP (${tp * pp * cp}) > GPUs (${num_gpus}).`);
254
+ }
255
+ if (etp){
256
+ if (etp * pp * cp * ep > num_gpus) {
257
+ errors.push(`ETP*PP*CP*EP (${etp * pp * cp * ep}) > GPUs (${num_gpus}).`);
258
+ }
259
+ } else {
260
+ if (tp * pp * cp * ep > num_gpus) {
261
+ errors.push(`TP*PP*CP*EP (${tp * pp * cp * ep}) > GPUs (${num_gpus}) when ETP is not set.`);
262
+ }
263
+ }
264
+ if (num_kv_heads && tp > num_kv_heads) {
265
+ errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
266
+ }
267
+
268
+ if (errors.length > 0) {
269
+ messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
270
+ messageDiv.style.display = 'block';
271
+ return;
272
+ }
273
+ // === END SUBMISSION VALIDATION ===
274
+
275
+ const loading = document.getElementById('loading');
276
+ const submitBtn = document.querySelector('#config-form button[type="submit"]');
277
+ loading.style.display = 'block';
278
+ if (submitBtn) submitBtn.disabled = true;
279
+
280
+ try {
281
+ const response = await fetch('/estimate_with_mbridge', {
282
+ method: 'POST',
283
+ headers: { 'Content-Type': 'application/json' },
284
+ body: JSON.stringify(formValues) // Send the now fully-validated formValues
285
+ });
286
+
287
+ console.log('Response Status:', response.status);
288
+
289
+ if (response.ok) {
290
+ const data = await response.json();
291
+
292
+ // FIX: Ensure history wrapper is visible before updating and showing details
293
+ document.getElementById('history-wrapper').style.display = 'block';
294
+
295
+ saveToHistory(formValues, data);
296
+ updateHistoryView();
297
+ const newEntryRow = document.querySelector('#history-table tbody tr:first-child');
298
+ if (newEntryRow) {
299
+ const detailBtn = newEntryRow.querySelector('.detail-btn');
300
+ if (detailBtn) {
301
+ // We need to pass the event object structure to handleHistoryAction
302
+ handleHistoryAction({ target: detailBtn });
303
+ }
304
+ }
305
+ } else {
306
+ const error = await response.text();
307
+ console.error('Server error response:', error);
308
+ // Since we removed the main results display, show error in the validation div
309
+ messageDiv.textContent = `Server Error: ${error}`;
310
+ messageDiv.style.display = 'block';
311
+ }
312
+ } catch (error) {
313
+ console.error('Fetch API Error:', error);
314
+ messageDiv.textContent = `Client Error: ${error.message}`;
315
+ messageDiv.style.display = 'block';
316
+ } finally {
317
+ loading.style.display = 'none';
318
+ if (submitBtn) submitBtn.disabled = false;
319
+ }
320
+ }
321
+
322
+ function renderTable(details, rawFullReport) {
323
+ if (!details || details.length === 0) {
324
+ return '<p>No detailed memory breakdown available.</p>';
325
+ }
326
+
327
+ const headers = Object.keys(details[0]);
328
+ headers.push('Breakdown');
329
+
330
+ let table = '<table><thead><tr>';
331
+ headers.forEach(h => table += `<th>${h}</th>`);
332
+ table += '</tr></thead><tbody>';
333
+
334
+ details.forEach(row => {
335
+ const ppRank = row.pp_rank;
336
+ // FIX: Look in the full raw report array passed in.
337
+ const rawDataForRank = rawFullReport ? rawFullReport.find(r => r.pp_rank === ppRank) : null;
338
+
339
+ // FIX: Change to `let` to allow modification for highlighting.
340
+ let modelBreakdown = (rawDataForRank && rawDataForRank.model_breakdown)
341
+ ? rawDataForRank.model_breakdown
342
+ : 'No breakdown available.';
343
+
344
+ // Add syntax-like highlighting for params and activations
345
+ // Basic HTML escaping for safety before inserting spans
346
+ modelBreakdown = modelBreakdown.replace(/&/g, "&amp;").replace(/</g, "&lt;").replace(/>/g, "&gt;");
347
+ modelBreakdown = modelBreakdown
348
+ .replace(/(n_params=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>')
349
+ .replace(/(n_act=[0-9.]+[a-zA-Z]*)/g, '<span class="highlight-red">$1</span>');
350
+
351
+ // Main row with data
352
+ table += `<tr data-pp-rank="${ppRank}">`;
353
+ headers.forEach(h => {
354
+ if (h !== 'Breakdown') {
355
+ table += `<td>${row[h]}</td>`;
356
+ }
357
+ });
358
+ table += `<td><button class="action-btn raw-per-rank-btn" data-pp-rank="${ppRank}">Raw</button></td>`;
359
+ table += '</tr>';
360
+
361
+ // Hidden row for the breakdown
362
+ table += `<tr class="raw-breakdown-row" data-pp-rank="${ppRank}" style="display: none;">
363
+ <td colspan="${headers.length}">
364
+ <pre>${modelBreakdown}</pre>
365
+ </td>
366
+ </tr>`;
367
+ });
368
+
369
+ table += '</tbody></table>';
370
+ return table;
371
+ }
372
+
373
+ function saveToHistory(params, resultData) {
374
+ let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
375
+ const historyEntry = {
376
+ params: params,
377
+ result: resultData, // Store the full result object { processed_report, raw_report }
378
+ id: new Date().getTime()
379
+ };
380
+ history.unshift(historyEntry); // Add to the beginning
381
+ if (history.length > 20) { // Keep history size manageable
382
+ history.pop();
383
+ }
384
+ localStorage.setItem('estimationHistory', JSON.stringify(history));
385
+ }
386
+
387
+ function updateHistoryView() {
388
+ const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
389
+ const historyTableBody = document.querySelector('#history-table tbody');
390
+ const historyWrapper = document.getElementById('history-wrapper');
391
+ historyTableBody.innerHTML = '';
392
+
393
+ if (history.length === 0) {
394
+ historyWrapper.style.display = 'none';
395
+ return;
396
+ }
397
+
398
+ historyWrapper.style.display = 'block';
399
+
400
+ history.forEach(item => {
401
+ const row = document.createElement('tr');
402
+
403
+ const params = item.params;
404
+ const resultData = item.result || {};
405
+
406
+ // FIX: Handle both old and new data structures for compatibility.
407
+ const details = (resultData.report && resultData.report.details) ? resultData.report.details : (resultData.processed_report || []);
408
+ const pp0Result = details.find(r => r.pp_rank === 0) || details[0] || {};
409
+
410
+ const modelName = params.hf_model_path.split('/').pop();
411
+
412
+ // Build parallelism string, e.g., "TP2 PP2 VPP2"
413
+ const parallelismParts = [];
414
+ ['tp', 'pp', 'ep', 'cp', 'vpp', 'etp'].forEach(p => {
415
+ const value = params[p];
416
+ if (value && value > 1) {
417
+ parallelismParts.push(`${p.toUpperCase()}${value}`);
418
+ }
419
+ });
420
+ const parallelismInfo = parallelismParts.join(' ') || 'No Parallelism';
421
+
422
+ const overheadGb = params.overhead ? parseInt(params.overhead) : 0;
423
+ const baseTotal = details.length > 0 ? Math.max(...details.map(r => r.total_gb || 0)) : null;
424
+ const totalGb = baseTotal !== null ? (baseTotal + overheadGb).toFixed(2) : 'N/A';
425
+
426
+ const seqLen = params.seq_len || 0;
427
+ const formattedSeqLen = seqLen >= 1024 ? `${seqLen / 1024}k` : seqLen;
428
+ const sequenceInfo = `${params.mbs || 'N/A'}*${formattedSeqLen}`;
429
+
430
+ // New: build a description of the recompute setting
431
+ let recomputeInfo = '';
432
+ switch (params.recompute_granularity) {
433
+ case 'none':
434
+ recomputeInfo = 'Recompute: None';
435
+ break;
436
+ case 'full':
437
+ const method = params.recompute_method || 'uniform';
438
+ const layers = params.recompute_num_layers ? params.recompute_num_layers : '';
439
+ recomputeInfo = `Recompute: Full (${method}${layers ? ',' + layers + 'L' : ''})`;
440
+ break;
441
+ case 'selective':
442
+ const mods = Array.isArray(params.recompute_modules) && params.recompute_modules.length ? params.recompute_modules.join('+') : '';
443
+ recomputeInfo = `Recompute: Selective${mods ? ' (' + mods + ')' : ''}`;
444
+ break;
445
+ default:
446
+ recomputeInfo = '';
447
+ }
448
+
449
+ row.innerHTML = `
450
+ <td>
451
+ <div>${modelName}</div>
452
+ <div class="model-meta-info">
453
+ <span>GPUs: ${params.num_gpus || 'N/A'}</span>
454
+ <span>${parallelismInfo}</span>
455
+ <span>Sequence: ${sequenceInfo}</span>
456
+ ${recomputeInfo ? `<span>${recomputeInfo}</span>` : ''}
457
+ </div>
458
+ </td>
459
+ <td>${pp0Result.weight_grad_optim_gb || 'N/A'}</td>
460
+ <td>${pp0Result.activation_gb || 'N/A'}</td>
461
+ <td>${totalGb}</td>
462
+ <td>
463
+ <button class="restore-btn" data-id="${item.id}">Restore</button>
464
+ <button class="detail-btn" data-id="${item.id}">Detail</button>
465
+ <button class="delete-btn" data-id="${item.id}">Delete</button>
466
+ </td>
467
+ `;
468
+ historyTableBody.appendChild(row);
469
+ });
470
+ }
471
+
472
+ async function handleHistoryAction(e) {
473
+ const button = e.target.closest('button');
474
+ if (!button) return;
475
+
476
+ // Handle breakdown toggle first
477
+ if (button.classList.contains('breakdown-btn')) {
478
+ const ppRank = button.dataset.ppRank;
479
+ const detailTable = button.closest('table');
480
+ if (!detailTable) return;
481
+
482
+ const breakdownRow = detailTable.querySelector(`tr.breakdown-row[data-pp-rank="${ppRank}"]`);
483
+ if (!breakdownRow) return;
484
+
485
+ const isVisible = breakdownRow.style.display !== 'none';
486
+ breakdownRow.style.display = isVisible ? 'none' : 'table-row';
487
+ button.textContent = isVisible ? 'Breakdown' : 'Hide';
488
+ return; // Do not continue to other handlers
489
+ }
490
+
491
+ if (!button.matches('.detail-btn, .restore-btn, .delete-btn')) return;
492
+
493
+ const id = parseInt(button.dataset.id, 10);
494
+ const history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
495
+ const entry = history.find(item => item.id === id);
496
+
497
+ if (!entry) {
498
+ console.error('History entry not found for id:', id);
499
+ return;
500
+ }
501
+
502
+ const row = button.closest('tr');
503
+
504
+ if (button.classList.contains('detail-btn')) {
505
+ const isDetailsVisible = row.nextElementSibling && row.nextElementSibling.classList.contains('detail-row');
506
+
507
+ document.querySelectorAll('.detail-row').forEach(detailRow => {
508
+ const prevRow = detailRow.previousElementSibling;
509
+ const detailBtn = prevRow.querySelector('.detail-btn');
510
+ if (detailRow !== row.nextElementSibling) {
511
+ detailRow.remove();
512
+ if (detailBtn) detailBtn.textContent = 'Detail';
513
+ }
514
+ });
515
+
516
+ if (isDetailsVisible) {
517
+ row.nextElementSibling.remove();
518
+ button.textContent = 'Detail';
519
+ } else {
520
+ const detailRow = document.createElement('tr');
521
+ detailRow.classList.add('detail-row');
522
+ const detailCell = detailRow.insertCell();
523
+ detailCell.colSpan = row.cells.length;
524
+
525
+ // FIX: Handle both old and new data structures for compatibility.
526
+ const report = entry.result.report;
527
+ const details = (report && report.details) ? report.details : (entry.result.processed_report || []);
528
+ const modelBreakdown = (report && report.model_breakdown) ? report.model_breakdown : null;
529
+
530
+ if (details && details.length > 0) {
531
+ const newTable = document.createElement('table');
532
+ // Determine if breakdown information exists per-row or globally
533
+ let headers = Object.keys(details[0]);
534
+
535
+ // If old-format data, there is a 'model_breakdown' key on each detail row
536
+ const hasRowBreakdown = headers.includes('model_breakdown');
537
+
538
+ // Remove the raw model_breakdown column from headers to keep table compact
539
+ if (hasRowBreakdown) {
540
+ headers = headers.filter(h => h !== 'model_breakdown');
541
+ }
542
+
543
+ // Include global breakdown if provided, or row breakdowns if present
544
+ const includeBreakdown = hasRowBreakdown || (modelBreakdown && typeof modelBreakdown === 'string');
545
+
546
+ if (includeBreakdown) {
547
+ headers.push('Breakdown');
548
+ }
549
+
550
+ const headerRow = newTable.insertRow();
551
+ headers.forEach(h => {
552
+ const th = document.createElement('th');
553
+ th.textContent = h;
554
+ headerRow.appendChild(th);
555
+ });
556
+
557
+ details.forEach(detail => {
558
+ const newRow = newTable.insertRow();
559
+ headers.forEach(header => {
560
+ if (header === 'Breakdown') {
561
+ const cell = newRow.insertCell();
562
+ cell.innerHTML = `<button class="breakdown-btn" data-pp-rank="${detail.pp_rank}">Breakdown</button>`;
563
+ } else {
564
+ const cell = newRow.insertCell();
565
+ let value = detail[header];
566
+ if (typeof value === 'number' && !Number.isInteger(value)) {
567
+ value = value.toFixed(4);
568
+ }
569
+ cell.textContent = value;
570
+ }
571
+ });
572
+
573
+ // Hidden breakdown row
574
+ if (includeBreakdown) {
575
+ const breakdownRow = newTable.insertRow();
576
+ breakdownRow.classList.add('breakdown-row');
577
+ breakdownRow.dataset.ppRank = detail.pp_rank;
578
+ breakdownRow.style.display = 'none';
579
+ const breakdownCell = breakdownRow.insertCell();
580
+ breakdownCell.colSpan = headers.length;
581
+ const rowSpecificBreakdown = hasRowBreakdown ? (detail.model_breakdown || '') : modelBreakdown;
582
+ const htmlBreakdown = ansiToHtml(rowSpecificBreakdown);
583
+ breakdownCell.innerHTML = `<pre class="model-breakdown-view">${htmlBreakdown || 'No breakdown available.'}</pre>`;
584
+ }
585
+ });
586
+
587
+ detailCell.appendChild(newTable);
588
+ } else {
589
+ detailCell.innerHTML = 'No detailed per-rank results available.';
590
+ }
591
+
592
+ row.after(detailRow);
593
+ button.textContent = 'Hide';
594
+ }
595
+ } else if (button.classList.contains('restore-btn')) {
596
+ restoreForm(entry.params);
597
+ } else if (button.classList.contains('delete-btn')) {
598
+ deleteHistoryEntry(id);
599
+ }
600
+ }
601
+
602
+ function deleteHistoryEntry(id) {
603
+ let history = JSON.parse(localStorage.getItem('estimationHistory')) || [];
604
+ const updatedHistory = history.filter(item => item.id != id);
605
+ localStorage.setItem('estimationHistory', JSON.stringify(updatedHistory));
606
+ updateHistoryView();
607
+
608
+ // If history is now empty, hide the whole output container
609
+ if (updatedHistory.length === 0) {
610
+ // document.getElementById('output-container').style.display = 'none';
611
+ }
612
+ }
613
+
614
+ function clearHistory() {
615
+ localStorage.removeItem('estimationHistory');
616
+ updateHistoryView();
617
+ // document.getElementById('output-container').style.display = 'none';
618
+ }
619
+
620
+
621
+ function restoreForm(params) {
622
+ if (!params) return;
623
+
624
+ const setElementValue = (id, value, defaultValue = '') => {
625
+ const element = document.getElementById(id);
626
+ if (element) {
627
+ if (element.type === 'checkbox') {
628
+ element.checked = value ?? defaultValue;
629
+ } else {
630
+ element.value = value ?? defaultValue;
631
+ }
632
+ }
633
+ };
634
+
635
+ setElementValue('num-gpus', params.num_gpus, 8);
636
+ setElementValue('mbs', params.mbs, 1);
637
+ setElementValue('seq-len', params.seq_len, 4096);
638
+ setElementValue('use-distributed-optimizer', params.use_distributed_optimizer, true);
639
+ setElementValue('recompute_granularity', params.recompute_granularity, 'selective');
640
+ setElementValue('recompute_method', params.recompute_method, 'uniform');
641
+ setElementValue('recompute_num_layers', params.recompute_num_layers, 1);
642
+ setElementValue('tp', params.tp, 1);
643
+ setElementValue('pp', params.pp, 1);
644
+ setElementValue('ep', params.ep, 1);
645
+ setElementValue('cp', params.cp, 1);
646
+ setElementValue('vpp', params.vpp);
647
+ // Update VPP-dependent visibility after setting vpp
648
+ toggleVppDependentOptions();
649
+ setElementValue('etp', params.etp);
650
+ setElementValue('num_layers_in_first_pipeline_stage', params.num_layers_in_first_pipeline_stage);
651
+ setElementValue('num_layers_in_last_pipeline_stage', params.num_layers_in_last_pipeline_stage);
652
+ setElementValue('pipeline-layout', params.pipeline_model_parallel_layout);
653
+ setElementValue('overhead', params.overhead, 10);
654
+
655
+ // New: restore the checkbox states
656
+ setElementValue('account_for_embedding_in_pipeline_split', params.account_for_embedding_in_pipeline_split, false);
657
+ setElementValue('account_for_loss_in_pipeline_split', params.account_for_loss_in_pipeline_split, false);
658
+
659
+ const modelSelect = document.getElementById('model-select');
660
+ if (modelSelect && params.hf_model_path) {
661
+ modelSelect.value = params.hf_model_path;
662
+ }
663
+
664
+ // Manually trigger change event for UI updates
665
+ const recomputeSelect = document.getElementById('recompute_granularity');
666
+ if (recomputeSelect) {
667
+ recomputeSelect.dispatchEvent(new Event('change'));
668
+ }
669
+ }
670
+
671
+ function updateParallelismOptions() {
672
+ const numGpusInput = document.getElementById('num-gpus');
673
+ if (!numGpusInput) return;
674
+
675
+ const numGpus = parseInt(numGpusInput.value);
676
+ if (isNaN(numGpus) || numGpus <= 0) {
677
+ return; // Don't update if GPU count is invalid
678
+ }
679
+
680
+ const tpSelect = document.getElementById('tp');
681
+ const epSelect = document.getElementById('ep');
682
+ const cpSelect = document.getElementById('cp');
683
+
684
+ // PP is now a manual input, so we only handle TP, EP, CP here.
685
+ const selects = [tpSelect, epSelect, cpSelect];
686
+
687
+ const powersOfTwo = [1];
688
+ for (let i = 1; (1 << i) <= numGpus; i++) {
689
+ powersOfTwo.push(1 << i);
690
+ }
691
+
692
+ selects.forEach(select => {
693
+ if (!select) return;
694
+ const currentVal = select.value;
695
+ select.innerHTML = ''; // Clear existing options
696
+
697
+ powersOfTwo.forEach(val => {
698
+ const option = document.createElement('option');
699
+ option.value = val;
700
+ option.textContent = val;
701
+ select.appendChild(option);
702
+ });
703
+
704
+ // Try to restore the previous value, otherwise default to 1
705
+ if (powersOfTwo.includes(parseInt(currentVal))) {
706
+ select.value = currentVal;
707
+ } else {
708
+ select.value = 1;
709
+ }
710
+ });
711
+ }
712
+
713
+ function validateParallelismLive() {
714
+ const messageDiv = document.getElementById('validation-message');
715
+ // Pass isSubmission = false to getFormValues to prevent alerts during live validation
716
+ const formValues = getFormValues(false);
717
+
718
+ if (!formValues) {
719
+ messageDiv.textContent = '';
720
+ return true;
721
+ }
722
+
723
+ const { num_gpus, tp, pp, ep, cp, etp, custom_hf_config } = formValues;
724
+ // The key is the same in the HF config, so this logic remains valid.
725
+ const num_kv_heads = custom_hf_config?.num_key_value_heads || null;
726
+
727
+ let errors = [];
728
+ if (tp * pp * cp > num_gpus) {
729
+ errors.push(`TP*PP*CP (${tp*pp*cp}) > GPUs (${num_gpus}).`);
730
+ }
731
+ if (etp) {
732
+ if (etp * pp * cp * ep > num_gpus) {
733
+ errors.push(`ETP*PP*CP*EP (${etp*pp*cp*ep}) > GPUs (${num_gpus}).`);
734
+ }
735
+ } else {
736
+ if (tp * pp * cp * ep > num_gpus) {
737
+ errors.push(`TP*PP*CP*EP (${tp*pp*cp*ep}) > GPUs (${num_gpus}) when ETP is not set.`);
738
+ }
739
+ }
740
+ if (num_kv_heads && tp > num_kv_heads) {
741
+ errors.push(`TP (${tp}) > Num KV Heads (${num_kv_heads}).`);
742
+ }
743
+
744
+ if (errors.length > 0) {
745
+ messageDiv.textContent = 'Validation Error: ' + errors.join(' ');
746
+ messageDiv.style.display = 'block';
747
+ } else {
748
+ messageDiv.textContent = '';
749
+ messageDiv.style.display = 'none';
750
+ }
751
+ return errors.length === 0;
752
+ }
753
+
754
+ function toggleEpBasedOnConfig() {
755
+ const editor = document.getElementById('config-editor');
756
+ const epSelect = document.getElementById('ep');
757
+ if (!editor || !epSelect) return;
758
+
759
+ let config = null;
760
+ try {
761
+ if (editor.value) {
762
+ config = JSON.parse(editor.value);
763
+ }
764
+ } catch (e) {
765
+ // Invalid JSON, disable EP as a safety measure
766
+ epSelect.disabled = true;
767
+ return;
768
+ }
769
+
770
+ if (config && config.num_experts_per_tok) {
771
+ epSelect.disabled = false;
772
+ } else {
773
+ epSelect.disabled = true;
774
+ epSelect.value = 1; // Reset to 1 if disabled
775
+ }
776
+ }
777
+
778
+ // New: show/hide dependent options based on the vpp input
779
+ function toggleVppDependentOptions() {
780
+ const vppInput = document.getElementById('vpp');
781
+ const dependents = document.querySelectorAll('.vpp-dependent');
782
+ if (!vppInput) return;
783
+ const shouldShow = vppInput.value && parseInt(vppInput.value) > 0;
784
+ dependents.forEach(el => {
785
+ el.style.display = shouldShow ? 'block' : 'none';
786
+ });
787
+ }
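Both the live validation in script.js above and the backend endpoint enforce the same parallelism constraints. A minimal Python restatement of those checks, using the form-field names; num_kv_heads is the optional num_key_value_heads from the HF config, and the example values are illustrative only.

# Sketch of the sanity checks mirrored by validateParallelismLive() and the backend.
def check_parallelism(num_gpus, tp, pp, cp, ep=1, etp=None, num_kv_heads=None):
    errors = []
    if num_gpus <= 0 or num_gpus % 8 != 0:
        errors.append("num_gpus must be a positive multiple of 8")
    if tp * pp * cp > num_gpus:
        errors.append(f"TP*PP*CP ({tp * pp * cp}) > GPUs ({num_gpus})")
    # The expert product uses ETP when it is set, otherwise TP (matching the JS check).
    expert_product = (etp or tp) * pp * cp * ep
    if expert_product > num_gpus:
        errors.append(f"(E)TP*PP*CP*EP ({expert_product}) > GPUs ({num_gpus})")
    if num_kv_heads and tp > num_kv_heads:
        errors.append(f"TP ({tp}) > num KV heads ({num_kv_heads})")
    return errors

print(check_parallelism(num_gpus=16, tp=4, pp=2, cp=1, ep=2))  # -> []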
webui/style.css ADDED
@@ -0,0 +1,383 @@
1
+ body {
2
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif;
3
+ line-height: 1.6;
4
+ background-color: #f4f4f4;
5
+ color: #333;
6
+ margin: 0;
7
+ padding: 1em;
8
+ }
9
+
10
+ .container {
11
+ max-width: 1600px;
12
+ margin: auto;
13
+ background: #fff;
14
+ padding: 2em;
15
+ border-radius: 8px;
16
+ box-shadow: 0 0 20px rgba(0, 0, 0, 0.05);
17
+ }
18
+
19
+ .main-layout {
20
+ display: flex;
21
+ flex-direction: column; /* Main axis is vertical */
22
+ gap: 2em;
23
+ }
24
+
25
+ .top-section {
26
+ display: flex;
27
+ flex-direction: row; /* Children are horizontal */
28
+ gap: 2em;
29
+ }
30
+
31
+ .config-column, .output-column {
32
+ flex: 1; /* Each column takes up half the space */
33
+ display: flex;
34
+ flex-direction: column;
35
+ }
36
+
37
+ /* The editor wrapper should grow to fill the space */
38
+ .config-editor-wrapper {
39
+ flex-grow: 1;
40
+ display: flex;
41
+ flex-direction: column;
42
+ }
43
+
44
+ #config-editor {
45
+ flex-grow: 1; /* The textarea itself should grow */
46
+ width: 100%;
47
+ box-sizing: border-box; /* Include padding and border in the element's total width and height */
48
+ resize: vertical; /* Allow vertical resizing */
49
+ }
50
+
51
+
52
+ .bottom-section {
53
+ width: 100%;
54
+ }
55
+
56
+ .form-row {
57
+ display: flex;
58
+ gap: 1em;
59
+ align-items: flex-end;
60
+ }
61
+
62
+ .form-row .form-group {
63
+ flex: 1; /* Allow groups to grow and fill space */
64
+ margin-bottom: 0.8em;
65
+ }
66
+
67
+ .form-group {
68
+ margin-bottom: 0.8em; /* Reduced from default */
69
+ }
70
+
71
+ .form-group label {
72
+ display: block;
73
+ margin-bottom: 0.25em; /* Reduced */
74
+ font-weight: 500;
75
+ }
76
+
77
+ .form-group label.inline-label {
78
+ display: inline-block;
79
+ margin-left: 0.5em;
80
+ font-weight: normal;
81
+ }
82
+
83
+ .form-group input[type="number"],
84
+ .form-group select {
85
+ width: 100%;
86
+ padding: 6px 10px; /* Reduced padding */
87
+ border-radius: 4px;
88
+ border: 1px solid #ccc;
89
+ box-sizing: border-box;
90
+ }
91
+
92
+ button {
93
+ background-color: #3498db;
94
+ color: white;
95
+ padding: 10px 15px;
96
+ border: none;
97
+ border-radius: 4px;
98
+ cursor: pointer;
99
+ font-size: 16px;
100
+ margin-top: 10px;
101
+ }
102
+
103
+ button:hover {
104
+ background-color: #2980b9;
105
+ }
106
+
107
+ #results {
108
+ background-color: #ecf0f1;
109
+ padding: 15px;
110
+ border-radius: 4px;
111
+ white-space: pre-wrap;
112
+ word-wrap: break-word;
113
+ min-height: 100px;
114
+ }
115
+
116
+ .results-container {
117
+ margin-top: 20px;
118
+ }
119
+
120
+ /* New styles for results table */
121
+ table {
122
+ width: 100%;
123
+ border-collapse: collapse;
124
+ margin-top: 20px;
125
+ }
126
+
127
+ th, td {
128
+ border: 1px solid #ddd;
129
+ padding: 12px;
130
+ text-align: left;
131
+ }
132
+
133
+ th {
134
+ background-color: #f2f2f2;
135
+ font-weight: bold;
136
+ }
137
+
138
+ tbody tr:nth-child(even) {
139
+ background-color: #f9f9f9;
140
+ }
141
+
142
+ tbody tr:hover {
143
+ background-color: #f1f1f1;
144
+ }
145
+
146
+ .error {
147
+ color: #e74c3c;
148
+ font-weight: bold;
149
+ }
150
+
151
+ .button-container {
152
+ grid-column: 1 / -1; /* Span across all columns */
153
+ text-align: center;
154
+ margin-top: 20px;
155
+ }
156
+
157
+ /* History Section */
158
+ .history-container {
159
+ margin-top: 40px;
160
+ border-top: 1px solid #e0e0e0;
161
+ padding-top: 20px;
162
+ }
163
+
164
+ .history-container h2 {
165
+ display: flex;
166
+ justify-content: space-between;
167
+ align-items: center;
168
+ }
169
+
170
+ #history-list table {
171
+ margin-top: 10px;
172
+ }
173
+
174
+ .small-button {
175
+ padding: 4px 8px;
176
+ font-size: 0.8em;
177
+ background-color: #e74c3c;
178
+ }
179
+
180
+ .small-button:hover {
181
+ background-color: #c0392b;
182
+ }
183
+
184
+ .history-item-actions {
185
+ display: flex;
186
+ gap: 10px;
187
+ }
188
+
189
+ #output-container {
190
+ margin-top: 2em;
191
+ padding: 1.5em;
192
+ background-color: #f9f9f9;
193
+ border: 1px solid #ddd;
194
+ border-radius: 8px;
195
+ }
196
+
197
+ #results-wrapper h3, #history-wrapper h3 {
198
+ margin-top: 0;
199
+ border-bottom: 2px solid #eee;
200
+ padding-bottom: 0.5em;
201
+ margin-bottom: 1em;
202
+ }
203
+
204
+ #results-display table {
205
+ width: 100%;
206
+ border-collapse: collapse;
207
+ }
208
+
209
+ #results-display th, #results-display td {
210
+ padding: 8px 12px;
211
+ border: 1px solid #ddd;
212
+ text-align: left;
213
+ }
214
+
215
+ #results-display th {
216
+ background-color: #f2f2f2;
217
+ }
218
+
219
+ #history-table {
220
+ width: 100%;
221
+ border-collapse: collapse;
222
+ }
223
+
224
+ #history-table th, #history-table td {
225
+ padding: 8px 12px;
226
+ border: 1px solid #ddd;
227
+ text-align: left;
228
+ }
229
+
230
+ #history-table th {
231
+ background-color: #f2f2f2;
232
+ }
233
+
234
+ #history-table td:last-child {
235
+ text-align: right;
236
+ }
237
+
238
+ #raw-json-output {
239
+ background-color: #2d2d2d;
240
+ color: #f1f1f1;
241
+ padding: 1em;
242
+ border-radius: 5px;
243
+ max-height: 500px;
244
+ overflow-y: auto;
245
+ }
246
+
247
+ #clear-history {
248
+ background-color: #dc3545;
249
+ }
250
+
251
+ #clear-history:hover {
252
+ background-color: #c82333;
253
+ }
254
+
255
+ .error-message {
256
+ color: #dc3545;
257
+ background-color: #f8d7da;
258
+ border: 1px solid #f5c6cb;
259
+ padding: 0.75rem 1.25rem;
260
+ margin-top: 1rem;
261
+ margin-bottom: 1rem;
262
+ border-radius: 0.25rem;
263
+ text-align: center;
264
+ }
265
+
266
+ /* Responsive Design for smaller screens */
267
+ @media (max-width: 992px) {
268
+ .top-section {
269
+ flex-direction: column;
270
+ }
271
+ }
272
+
273
+ .history-detail-row td {
274
+ background-color: #333;
275
+ padding: 15px;
276
+ border-top: 2px solid #555;
277
+ text-align: left; /* Align content to the left */
278
+ }
279
+
280
+ .history-detail-row pre {
281
+ background-color: #1e1e1e;
282
+ color: #d4d4d4;
283
+ padding: 10px;
284
+ border-radius: 4px;
285
+ white-space: pre-wrap;
286
+ word-break: break-all;
287
+ }
288
+
289
+ .history-detail-row table {
290
+ width: 100%;
291
+ border-collapse: collapse;
292
+ margin: 0;
293
+ }
294
+
295
+ .history-detail-row table th {
296
+ background-color: #e0e0e0;
297
+ color: #333;
298
+ padding: 8px 12px;
299
+ border: 1px solid #555;
300
+ }
301
+
302
+ .history-detail-row table td {
303
+ color: #d4d4d4;
304
+ padding: 8px 12px;
305
+ border: 1px solid #555;
306
+ background-color: #2a2a2a;
307
+ }
308
+
309
+ .model-breakdown-view {
310
+ max-height: 400px; /* Or any other suitable height */
311
+ overflow-y: auto;
312
+ overflow-x: auto;
313
+ background-color: #2d2d2d;
314
+ color: #f1f1f1;
315
+ padding: 1em;
316
+ border-radius: 5px;
317
+ white-space: pre-wrap; /* Ensures the pre content wraps */
318
+ margin: 0;
319
+ font-family: monospace;
320
+ font-size: 0.85em;
321
+ }
322
+
323
+ .model-meta-info {
324
+ font-size: 0.9em;
325
+ color: #666;
326
+ margin-top: 4px;
327
+ }
328
+
329
+ .model-meta-info span {
330
+ margin-right: 15px;
331
+ }
332
+
333
+ .action-btn.raw-btn {
334
+ background-color: #555;
335
+ color: white;
336
+ }
337
+
338
+ .highlight-red {
339
+ color: #ff6b6b;
340
+ }
341
+
342
+ .ansi-red { color: #e74c3c; }
343
+ .ansi-green { color: #2ecc71; }
344
+ .ansi-yellow { color: #f1c40f; }
345
+ .ansi-blue { color: #3498db; }
346
+ .ansi-magenta { color: #9b59b6; }
347
+ .ansi-cyan { color: #1abc9c; }
348
+
349
+ .breakdown-row td {
350
+ text-align: left !important;
351
+ }
352
+
353
+ .footer {
354
+ margin-top: 2em;
355
+ font-size: 0.85em;
356
+ color: #555;
357
+ text-align: center;
358
+ }
359
+
360
+ .footer a {
361
+ color: #2a77d4;
362
+ text-decoration: none;
363
+ }
364
+
365
+ .footer a:hover {
366
+ text-decoration: underline;
367
+ }
368
+
369
+ .disclaimer {
370
+ margin-top: 0.5em;
371
+ font-style: italic;
372
+ }
373
+
374
+ .disclaimer-banner {
375
+ background-color: #fff3cd;
376
+ color: #856404;
377
+ border: 1px solid #ffeeba;
378
+ padding: 10px 15px;
379
+ border-radius: 4px;
380
+ margin: 15px 0;
381
+ font-weight: bold;
382
+ text-align: center;
383
+ }