megatron_memory_estimator / estimate_013.py
Yan Bai
0.13
d2006b6
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Estimate theoretical per-GPU memory of Megatron-Core GPT / MoE models.

Reports weight+gradient, optimizer-state and activation memory for every
pipeline-parallel rank without requiring GPUs.
"""
import warnings

# Silence all warnings. The final blanket "ignore" filter is inserted at the
# front of the filter list and already covers the two category-specific
# filters installed before it.
warnings.filterwarnings(action="ignore", category=DeprecationWarning)
warnings.filterwarnings(action="ignore", category=FutureWarning)
warnings.filterwarnings(action="ignore")
import inspect
import os
from contextlib import nullcontext
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
from megatron.core import mpu
from megatron.core.datasets.blended_megatron_dataset_builder import (
BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import (
GPTDataset,
GPTDatasetConfig,
MockGPTDataset,
)
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.enums import ModelType
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core.utils import StragglerDetector
from megatron.training import (
get_args,
get_timers,
get_tokenizer,
pretrain,
print_rank_0,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from moe_mem_estimator.base import (
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
set_global_config,
set_pipeline_model_parallel_rank,
)
from moe_mem_estimator.gpt_model import GPTModel
from moe_mem_estimator.layers import MLASelfAttention, MoELayer
# The estimator runs without an initialized process group or a GPU, so two
# torch queries that Megatron code may make are stubbed. The stubs accept (and
# ignore) the real APIs' optional arguments (``group`` for ``get_rank``,
# ``device`` for ``get_device_capability``). The capability stub returns a
# ``(major, minor)`` tuple like the real API, so callers that unpack or compare
# it keep working. ``[0]`` still yields 8 as before.
def _fake_get_rank(*args, **kwargs):
    """Report global rank 0 for any (ignored) process group."""
    return 0


def _fake_get_device_capability(*args, **kwargs):
    """Report an SM 8.0 device as ``(major, minor)`` for any (ignored) device."""
    return (8, 0)


torch.distributed.get_rank = _fake_get_rank
torch.cuda.get_device_capability = _fake_get_device_capability
def estimate_from_config(config, args):
    """Estimate memory usage for every pipeline-parallel rank of ``config``.

    Virtual-pipeline chunks are built and accounted for inside
    ``report_memory_usage_one_pp_rank``.

    Side effects: sets ``args.moe_grouped_gemm = True``, patches Megatron's
    ``parallel_state`` pipeline getters, installs ``config`` as the
    estimator's global config, and prints the config plus per-rank details.

    Args:
        config: A ``TransformerConfig``, or ``None`` to build one from
            ``args.yaml_cfg`` (its ``"language_model"`` section) when that is
            set, otherwise from the command-line ``args``.
        args: Megatron argument namespace.

    Returns:
        A 2-tuple ``(aggregated_reports, cli_reports)``. Both slots are the
        same list object holding one report dict per pipeline-parallel rank.
    """
    args.moe_grouped_gemm = True
    patch_parallel_states()
    if config is None:
        config = (
            core_transformer_config_from_yaml(args, "language_model")
            if args.yaml_cfg is not None
            else core_transformer_config_from_args(args)
        )
    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)
    print(config)

    pp_size = config.pipeline_model_parallel_size
    reports: list[dict] = []
    if pp_size <= 1:
        set_pipeline_model_parallel_rank(0)
        _, single_report = report_memory_usage_one_pp_rank(input_shape, args, config)
        reports.append(single_report)
    else:
        for pp_rank in range(pp_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            # Each rank consumes the activation shape produced by the previous one.
            input_shape, rank_report = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, pp_size
            )
            reports.append(rank_report)
    # Return (per-PP-rank report list, the very same list).
    return reports, reports
def _get_transformer_layer_spec(use_te, config):
    """Return the per-layer GPT spec matching ``config``.

    Args:
        use_te (bool): Select the Transformer Engine spec when true, otherwise
            the local (non-TE) spec.
        config: Model configuration. Reads ``num_moe_experts``,
            ``moe_grouped_gemm``, ``qk_layernorm``, ``multi_latent_attention``
            and, for the Transformer Engine spec only, ``fp8``.

    Returns:
        The transformer layer spec built by Megatron-Core.
    """
    # Leading positional arguments shared by both spec builders.
    shared_args = (
        config.num_moe_experts,
        config.moe_grouped_gemm,
        config.qk_layernorm,
        config.multi_latent_attention,
    )
    if not use_te:
        return get_gpt_layer_local_spec(*shared_args)
    return get_gpt_layer_with_transformer_engine_spec(*shared_args, config.fp8)
def model_provider(
    args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> GPTModel:
    """Instantiate one estimator ``GPTModel`` chunk.

    Args:
        args: Megatron argument namespace. Reads ``num_experts``,
            ``padded_vocab_size``, ``max_position_embeddings`` and the optional
            ``rotary_percent`` / ``rotary_base``.
        config: Transformer config used for the layer specs and the model.
        pre_process: First-pipeline-stage flag forwarded to ``GPTModel``.
        post_process: Last-pipeline-stage flag forwarded to ``GPTModel``.
        vp_stage: Virtual-pipeline stage index, or ``None`` without interleaving.

    Returns:
        The estimator's ``GPTModel`` for this chunk.
    """
    # The estimator always models the Transformer Engine layer specs.
    use_te = True
    if not args.num_experts:
        # Dense model: a single per-layer spec.
        transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
    else:
        # MoE model: a decoder-block spec that depends on the virtual stage.
        transformer_layer_spec = get_gpt_decoder_block_spec(
            config,
            use_transformer_engine=use_te,
            normalization="LayerNorm",
            qk_l2_norm=False,
            vp_stage=vp_stage,
        )
    # Multi-token-prediction blocks are not modeled (mtp_block_spec=None).
    # TODO: fp8.
    return GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False),
        parallel_output=True,
        share_embeddings_and_output_weights=False,
        position_embedding_type="rope",
        rotary_percent=getattr(args, "rotary_percent", 1.0),
        rotary_base=getattr(args, "rotary_base", 10000),
        rope_scaling=getattr(config, "use_rope_scaling", False),
        mtp_block_spec=None,
        vp_stage=vp_stage,
    )
def get_model(
    model_provider_func, args, config, model_type=ModelType.encoder_or_decoder
):
    """Build the estimator model chunk(s) hosted by the current pipeline rank.

    Only the pipeline-stage flags and virtual-pipeline chunking are handled
    here. Nothing is wrapped or moved to a device.

    Side effects: normalizes ``args.virtual_pipeline_model_parallel_size``
    (a falsy value becomes ``None``). When ``config.pipeline_model_parallel_layout``
    is set, its virtual pipeline size is copied onto both ``args`` and ``config``.

    Args:
        model_provider_func: Callable
            ``(args, config, pre_process=..., post_process=..., [vp_stage=...])``
            returning one model chunk.
        args: Megatron argument namespace.
        config: Transformer config.
        model_type: ``ModelType`` recorded on every chunk.

    Returns:
        A list of chunks: one per virtual stage when interleaving is active on
        a multi-stage pipeline, otherwise a single-element list.
    """
    # Treat a missing or zero virtual-pipeline size as "no interleaving".
    if not getattr(args, "virtual_pipeline_model_parallel_size", None):
        args.virtual_pipeline_model_parallel_size = None
    layout = config.pipeline_model_parallel_layout
    if layout:
        # An explicit pipeline layout decides the number of virtual stages.
        args.virtual_pipeline_model_parallel_size = (
            layout.virtual_pipeline_model_parallel_size
        )
        config.virtual_pipeline_model_parallel_size = (
            layout.virtual_pipeline_model_parallel_size
        )

    num_virtual_stages = args.virtual_pipeline_model_parallel_size
    if get_pipeline_model_parallel_world_size() > 1 and num_virtual_stages is not None:
        if model_type == ModelType.encoder_and_decoder:
            assert (
                config.encoder_pipeline_model_parallel_size == 0
            ), "Interleaved schedule not supported for model with encoder on separate PP rank"
        chunks = []
        for stage in range(num_virtual_stages):
            # First/last-stage flags depend on the virtual stage of each chunk.
            is_first = is_pipeline_first_stage(ignore_virtual=False, vp_stage=stage)
            is_last = is_pipeline_last_stage(ignore_virtual=False, vp_stage=stage)
            chunk = model_provider_func(
                args,
                config,
                pre_process=is_first,
                post_process=is_last,
                vp_stage=stage,
            )
            chunk.model_type = model_type
            chunk.vp_stage = stage
            chunks.append(chunk)
        return chunks

    pre_process = is_pipeline_first_stage()
    post_process = is_pipeline_last_stage()
    if (
        model_type == ModelType.encoder_and_decoder
        and get_pipeline_model_parallel_world_size() > 1
    ):
        # Encoder and decoder each get their own first/last stage.
        rank = get_pipeline_model_parallel_rank()
        first_decoder_rank = config.encoder_pipeline_model_parallel_size
        world_size = get_pipeline_model_parallel_world_size()
        pre_process = rank in (0, first_decoder_rank)
        post_process = rank in (first_decoder_rank - 1, world_size - 1)
    single = model_provider_func(
        args, config, pre_process=pre_process, post_process=post_process
    )
    single.model_type = model_type
    return single if isinstance(single, list) else [single]
# Binary (1024-based) units used when converting byte counts for reporting.
NUM_BYTES_IN_MEGABYTE = 1024**2
NUM_BYTES_IN_GIGABYTE = 1024 * NUM_BYTES_IN_MEGABYTE
def patch_parallel_states():
    """Point Megatron-Core ``parallel_state`` pipeline queries at the estimator.

    Replaces the pipeline rank / world-size / first-stage / last-stage getters
    with the ``moe_mem_estimator.base`` versions, whose rank the report
    functions set via ``set_pipeline_model_parallel_rank``. Also stubs the two
    encoder helpers (never inside an encoder; decoder starts at stage 0).
    """
    from megatron.core import parallel_state

    # NOTE(review): modules that already did ``from ...parallel_state import X``
    # keep their original reference; only attribute lookups see these patches.
    replacements = {
        "is_pipeline_first_stage": is_pipeline_first_stage,
        "is_pipeline_last_stage": is_pipeline_last_stage,
        "get_pipeline_model_parallel_rank": get_pipeline_model_parallel_rank,
        "get_pipeline_model_parallel_world_size": get_pipeline_model_parallel_world_size,
        "get_virtual_pipeline_model_parallel_world_size": (
            get_virtual_pipeline_model_parallel_world_size
        ),
        "is_inside_encoder": lambda: False,
        "get_pipeline_model_parallel_decoder_start": lambda: 0,
    }
    for attr_name, replacement in replacements.items():
        setattr(parallel_state, attr_name, replacement)
def report_memory_usage(args, config=None):
    """Print a memory estimate for every pipeline-parallel rank plus a summary.

    Side effects: sets ``args.moe_grouped_gemm = True``, patches Megatron's
    ``parallel_state`` pipeline getters and installs ``config`` as the
    estimator's global config.

    Args:
        args: Megatron argument namespace.
        config: A ``TransformerConfig``, or ``None`` to build one from
            ``args.yaml_cfg`` (its ``"language_model"`` section) when that is
            set, otherwise from the command-line ``args``.

    Returns:
        list[dict]: One report per pipeline-parallel rank (the dicts shown in
        the printed summary). The function previously returned ``None``;
        callers that ignore the result are unaffected.
    """
    args.moe_grouped_gemm = True
    patch_parallel_states()
    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)
    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)
    cli_reports = []
    if config.pipeline_model_parallel_size > 1:
        for pp_rank in range(config.pipeline_model_parallel_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            # Each rank consumes the activation shape produced by the previous one.
            input_shape, rpt = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
            )
            cli_reports.append(rpt)
    else:
        set_pipeline_model_parallel_rank(0)
        _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
        cli_reports.append(rpt)
    print("\n===== Summary (per PP rank) =====")
    for r in cli_reports:
        print(
            f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)"
        )
    return cli_reports
def _count_sparse_parameters(chunk):
    """Count the routed-expert parameters of one model chunk.

    Sums ``num_parameter()`` of every ``MoELayer`` MLP and subtracts its
    shared experts when present. Shared experts are therefore billed at the
    dense bytes-per-parameter rate in ``_weight_memory_gb``.
    """
    num_sparse = 0
    for layer in chunk.decoder.layers.modules:
        if not isinstance(layer.mlp, MoELayer):
            continue
        num_sparse += layer.mlp.num_parameter()
        if (
            "shared_experts" in layer.mlp.__dir__()
            and layer.mlp.shared_experts is not None
        ):
            num_sparse -= layer.mlp.shared_experts.num_parameter()
    return num_sparse


def _num_microbatches_in_flight(vpp_rank, num_chunks, pp_rank, pp_size):
    """Return how many micro-batches' activations a chunk holds at peak.

    Without virtual pipelining (a single chunk) this is ``pp_size - pp_rank``,
    the 1F1B warm-up depth. With interleaving:
      - the first chunk holds ``pp_size`` plus extra warm-up micro-batches on
        early ranks;
      - middle chunks hold ``pp_size``;
      - the last chunk holds ``min(2 * (pp_size - pp_rank) + 1, pp_size)``.
    """
    if num_chunks <= 1:
        return pp_size - pp_rank
    if vpp_rank == 0:
        return pp_size + max((pp_size - pp_rank) * 2 - 1 - pp_size, 0)
    if vpp_rank == num_chunks - 1:
        return min((pp_size - pp_rank) * 2 + 1, pp_size)
    return pp_size


def _activation_after_recompute(
    config, chunk, vpp_rank, num_microbatches, num_activation
):
    """Return one chunk's peak activation element count (before the CP split).

    Args:
        config: Transformer config (reads the ``recompute_*`` settings).
        chunk: Estimator ``GPTModel`` chunk.
        vpp_rank: Index of the chunk among this rank's virtual stages.
        num_microbatches: Micro-batches whose activations the chunk holds.
        num_activation: The chunk's activation count for one micro-batch.
    """
    if config.recompute_granularity != "full":
        # Every in-flight micro-batch keeps its activations; the post-process
        # part is counted once.
        return (
            num_activation - chunk.num_act_post
        ) * num_microbatches + chunk.num_act_post

    recompute_num_layers = config.recompute_num_layers
    num_layers = chunk.num_layers
    # Pre-process activations plus layer-boundary activations kept for every
    # layer and every in-flight micro-batch (recompute with pipeline parallel).
    common_act = (
        chunk.num_act_pre
        + chunk.num_act_between_layers * num_layers * num_microbatches
    )
    info = "With this recomputing setting, the number of activation achieve peak when "
    if config.recompute_method == "block":
        # Megatron's "block" method recomputes the first
        # min(recompute_num_layers, num_layers) layers of each chunk. A chunk
        # with fewer layers than recompute_num_layers therefore has no
        # non-recomputed layers. Clamp at zero: a negative count used to
        # *subtract* activation memory.
        num_layers_with_loss = max(num_layers - recompute_num_layers, 0)
        if num_layers_with_loss == 0:
            peak1 = common_act + chunk.num_act_post
            peak2 = common_act + chunk.num_act_per_layer
            if peak1 > peak2:
                info += "calculating loss"
            else:
                info += "back-propogating loss"
            num_activation = max(peak1, peak2)
        else:
            info += f"calculating loss with {num_layers_with_loss} non-recompute layers"
            num_activation = (
                common_act
                + chunk.num_act_post
                + chunk.num_act_per_layer * num_layers_with_loss * num_microbatches
            )
    elif config.recompute_method == "uniform":
        peak1 = common_act + chunk.num_act_post
        peak2 = (
            (common_act + chunk.num_act_per_layer) if vpp_rank == 0 else common_act
        )
        if peak1 > peak2:
            info += "calculating loss"
        else:
            info += f"back-propogating loss recomputing every {recompute_num_layers} layers"
        num_activation = max(peak1, peak2)
    layers = chunk.decoder.layers.modules
    if len(layers) > 0 and isinstance(layers[0].self_attention, MLASelfAttention):
        # MLA recompute reaches its peak during backward: add one core-attention
        # activation on top.
        num_activation += layers[0].self_attention.core_attention.num_activation()
    print(info)
    return num_activation


def _weight_memory_gb(args, config, num_parameters, num_sparse_parameters):
    """Return ``(weight_grad_gb, weight_grad_optim_gb)`` for one pipeline rank.

    ``weight_grad_gb`` counts 6 bytes per parameter.

    ``weight_grad_optim_gb`` uses 18 bytes per parameter without the
    distributed optimizer. With it, 6 bytes stay per rank and the other 12 are
    divided by the sharding group:
      - data-parallel x context-parallel size for dense weights;
      - when expert parallelism is enabled, world / PP / EP / ETP for
        routed-expert weights.

    Prints the bytes-per-parameter figures it uses.
    """
    num_bytes_per_parameter = (
        18
        if not args.use_distributed_optimizer
        else 6 + (12 / args.data_parallel_size / config.context_parallel_size)
    )
    weight_grad_memory = num_parameters * 6 / NUM_BYTES_IN_GIGABYTE
    if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1:
        num_bytes_per_parameter_dense = num_bytes_per_parameter
        if args.use_distributed_optimizer:
            expert_data_parallel_size = (
                args.world_size
                / config.pipeline_model_parallel_size
                / config.expert_model_parallel_size
                / config.expert_tensor_parallel_size
            )
            num_bytes_per_parameter_moe = 6 + (12 / expert_data_parallel_size)
        else:
            num_bytes_per_parameter_moe = 18
        # The f-string "=" specifier prints these variable names verbatim.
        print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
        weight_grad_optim_memory = (
            (num_parameters - num_sparse_parameters) * num_bytes_per_parameter_dense
            + num_sparse_parameters * num_bytes_per_parameter_moe
        ) / NUM_BYTES_IN_GIGABYTE
    else:
        print(f"{num_bytes_per_parameter=}")
        weight_grad_optim_memory = (
            num_parameters * num_bytes_per_parameter / NUM_BYTES_IN_GIGABYTE
        )
    return weight_grad_memory, weight_grad_optim_memory


def report_memory_usage_one_pp_rank(
    input_shape: list[int], args, config, pp_rank=0, pp_size=1
) -> tuple[list[int], dict]:
    """Estimate weight, optimizer and activation memory for one pipeline rank.

    Builds the model chunk(s) of the current pipeline rank with ``get_model``,
    walks them in virtual-stage order and prints per-chunk and per-rank figures.

    Args:
        input_shape: Activation shape entering this rank. On the first rank the
            callers pass ``[micro_batch_size, seq_length]``; later ranks get the
            previous rank's output shape.
        args: Megatron argument namespace.
        config: Transformer config.
        pp_rank: Pipeline-parallel rank being estimated.
        pp_size: Pipeline-parallel world size.

    Returns:
        ``(output_shape, report)``. ``output_shape`` is the shape leaving the
        last chunk (from ``mock_forward``). ``report`` is a dict with keys
        ``pp_rank``, ``parameters_b``, ``activation_b``, ``weight_grad_gb``,
        ``weight_grad_optim_gb``, ``activation_gb``, ``total_gb``,
        ``model_breakdown`` and ``details``.
    """
    print(f"{input_shape=}")
    model: list[GPTModel] = get_model(model_provider, args, config)
    num_parameter_this_shard_all = 0
    num_parameter_this_shard_sparse_all = 0
    num_activation_all = 0
    output_shape = input_shape
    for vpp_rank, one_chunk in enumerate(model):
        num_parameter_this_shard = one_chunk.num_parameter()
        num_activation = one_chunk.num_activation(output_shape)
        output_shape = one_chunk.mock_forward(output_shape)
        print(f"{output_shape=}")
        num_parameter_this_shard_sparse = _count_sparse_parameters(one_chunk)
        num_activation_this_shard_mlp = sum(
            m.mlp.num_activation() for m in one_chunk.decoder.layers.modules
        )
        # Name kept: it is printed via the f-string "=" specifier below.
        num_microbatch_this_pp_rank = _num_microbatches_in_flight(
            vpp_rank, len(model), pp_rank, pp_size
        )
        # NOTE(review): the result is discarded; kept in case __repr__ has
        # side effects in the estimator classes. Confirm before removing.
        one_chunk.__repr__()
        print(one_chunk)
        print(
            f"Number of parameters in every GPU in billions: "
            f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
        )
        num_parameter_this_shard_all += num_parameter_this_shard
        num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse

        num_activation = _activation_after_recompute(
            config, one_chunk, vpp_rank, num_microbatch_this_pp_rank, num_activation
        )
        # Context parallelism splits activations across CP ranks.
        num_activation = num_activation / config.context_parallel_size
        if pp_size == 1:
            print(
                f"Number of activation in every GPU in billions: "
                f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
            )
        else:
            print(
                f"Number of activation per microbatch in every GPU in billions: "
                f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
                f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
            )
        num_activation_all += num_activation

    weight_grad_memory, weight_grad_optim_memory = _weight_memory_gb(
        args, config, num_parameter_this_shard_all, num_parameter_this_shard_sparse_all
    )
    # 2 bytes per activation element: only fp16/bf16 activations are supported.
    activation_memory = num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
    total_memory = weight_grad_optim_memory + activation_memory
    print(
        f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, "
        f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
    )
    # Build an aggregated report in the same format estimate_from_config returns.
    model_breakdown_concat = "\n\n".join(
        [f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)]
    )
    report = {
        "pp_rank": pp_rank,
        "parameters_b": num_parameter_this_shard_all / 1e9,
        "activation_b": num_activation_all / 1e9,
        "weight_grad_gb": round(weight_grad_memory, 2),
        "weight_grad_optim_gb": round(weight_grad_optim_memory, 2),
        "activation_gb": round(activation_memory, 2),
        "total_gb": round(total_memory, 2),
        "model_breakdown": model_breakdown_concat,
        "details": None,
    }
    return output_shape, report
if __name__ == "__main__":
    initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
    # ipdb is only a debugging aid (post-mortem shell on an exception). Run
    # without it instead of failing with ImportError when it is not installed.
    try:
        import ipdb
    except ImportError:
        ipdb = None
    debug_context = (
        ipdb.launch_ipdb_on_exception() if ipdb is not None else nullcontext()
    )
    with debug_context:
        args = get_args()
        report_memory_usage(args)