|
|
|
"""Pretrain GPT.""" |
|
import warnings |
|
|
|
warnings.filterwarnings("ignore", category=DeprecationWarning) |
|
warnings.filterwarnings("ignore", category=FutureWarning) |
|
warnings.filterwarnings("ignore") |
|
from contextlib import nullcontext
from typing import Optional
|
|
|
import torch |
|
|
from megatron.core.enums import ModelType |
|
from megatron.core.models.gpt.gpt_layer_specs import ( |
|
get_gpt_decoder_block_spec, |
|
get_gpt_layer_local_spec, |
|
get_gpt_layer_with_transformer_engine_spec, |
|
get_gpt_mtp_block_spec, |
|
) |
|
|
from megatron.training import get_args
|
from megatron.training.arguments import core_transformer_config_from_args |
|
from megatron.training.initialize import initialize_megatron |
|
|
from megatron.training.yaml_arguments import core_transformer_config_from_yaml |
|
from moe_mem_estimator.base import ( |
|
get_pipeline_model_parallel_rank, |
|
get_pipeline_model_parallel_world_size, |
|
get_virtual_pipeline_model_parallel_world_size, |
|
is_pipeline_first_stage, |
|
is_pipeline_last_stage, |
|
set_global_config, |
|
set_pipeline_model_parallel_rank, |
|
) |
|
from moe_mem_estimator.gpt_model import GPTModel |
|
from moe_mem_estimator.layers import MLASelfAttention, MoELayer |
|
|
|
# Run as if this were rank 0 of a single process: no torch.distributed
# initialization and no GPU are needed for the analytical estimate.
torch.distributed.get_rank = lambda *args, **kwargs: 0
# Pretend we are on an SM80-class (Ampere or newer) device so capability checks pass.
torch.cuda.get_device_capability = lambda *args, **kwargs: (8, 0)
|
|
|
def estimate_from_config(config, args): |
|
""" |
|
Estimate memory usage from a given config and args, instead of global state. |
|
Now supports virtual pipeline model parallelism for more accurate results. |
|
""" |
|
|
|
args.moe_grouped_gemm = True |
|
patch_parallel_states() |
|
if config is None: |
|
if args.yaml_cfg is not None: |
|
config = core_transformer_config_from_yaml(args, "language_model") |
|
else: |
|
config = core_transformer_config_from_args(args) |
|
|
|
input_shape = [args.micro_batch_size, args.seq_length] |
|
|
|
set_global_config(config) |
|
print(config) |
|
|
|
cli_reports = [] |
|
|
|
if config.pipeline_model_parallel_size > 1: |
|
for pp_rank in range(config.pipeline_model_parallel_size): |
|
set_pipeline_model_parallel_rank(pp_rank) |
|
print( |
|
f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------" |
|
) |
|
input_shape, rpt = report_memory_usage_one_pp_rank( |
|
input_shape, args, config, pp_rank, config.pipeline_model_parallel_size |
|
) |
|
cli_reports.append(rpt) |
|
else: |
|
set_pipeline_model_parallel_rank(0) |
|
_, rpt = report_memory_usage_one_pp_rank(input_shape, args, config) |
|
cli_reports.append(rpt) |
|
|
|
    # No cross-rank aggregation is performed yet; the per-PP-rank reports are
    # returned under both names.
    aggregated_reports: list[dict] = cli_reports
|
return aggregated_reports, cli_reports |
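# Hypothetical programmatic use, assuming Megatron-LM args were already parsed
# (e.g. via initialize_megatron() followed by get_args()):
#
#     _, per_rank_reports = estimate_from_config(None, get_args())
#     for r in per_rank_reports:
#         print(r["pp_rank"], r["total_gb"])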
|
|
|
|
|
def _get_transformer_layer_spec(use_te, config): |
|
"""Get transformer layer specification based on configuration. |
|
|
|
Args: |
|
use_te (bool): Whether to use Transformer Engine |
|
args: Training arguments |
|
config: Model configuration |
|
|
|
Returns: |
|
transformer_layer_spec: The transformer layer specification |
|
""" |
|
if use_te: |
|
return get_gpt_layer_with_transformer_engine_spec( |
|
config.num_moe_experts, |
|
config.moe_grouped_gemm, |
|
config.qk_layernorm, |
|
config.multi_latent_attention, |
|
config.fp8, |
|
) |
|
else: |
|
return get_gpt_layer_local_spec( |
|
config.num_moe_experts, |
|
config.moe_grouped_gemm, |
|
config.qk_layernorm, |
|
config.multi_latent_attention, |
|
) |
|
|
|
|
|
def model_provider( |
|
args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None |
|
) -> GPTModel: |
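    """Build one estimator GPTModel chunk for the given (virtual) pipeline stage."""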
|
use_te = True |
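    # MoE models use the full decoder-block spec (which may mix dense and MoE layers);
    # dense models fall back to a single repeated transformer-layer spec.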
|
if args.num_experts: |
|
|
|
transformer_layer_spec = get_gpt_decoder_block_spec( |
|
config, |
|
use_transformer_engine=use_te, |
|
normalization="LayerNorm", |
|
qk_l2_norm=False, |
|
vp_stage=vp_stage, |
|
) |
|
else: |
|
|
|
transformer_layer_spec = _get_transformer_layer_spec(use_te, config) |
|
    # Multi-token prediction (MTP) blocks are not modeled by this estimator.
    mtp_block_spec = None
|
|
|
model = GPTModel( |
|
config=config, |
|
transformer_layer_spec=transformer_layer_spec, |
|
vocab_size=args.padded_vocab_size, |
|
max_sequence_length=args.max_position_embeddings, |
|
pre_process=pre_process, |
|
post_process=post_process, |
|
fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False), |
|
parallel_output=True, |
|
share_embeddings_and_output_weights=False, |
|
position_embedding_type="rope", |
|
rotary_percent=getattr(args, "rotary_percent", 1.0), |
|
rotary_base=getattr(args, "rotary_base", 10000), |
|
rope_scaling=getattr(config, "use_rope_scaling", False), |
|
mtp_block_spec=mtp_block_spec, |
|
vp_stage=vp_stage, |
|
) |
|
|
|
return model |
|
|
|
|
|
def get_model( |
|
model_provider_func, args, config, model_type=ModelType.encoder_or_decoder |
|
): |
|
"""Build the model.""" |
|
|
|
|
|
|
|
|
|
if not getattr(args, "virtual_pipeline_model_parallel_size", None): |
|
args.virtual_pipeline_model_parallel_size = None |
|
if config.pipeline_model_parallel_layout: |
|
args.virtual_pipeline_model_parallel_size = ( |
|
config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size |
|
) |
|
config.virtual_pipeline_model_parallel_size = ( |
|
config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size |
|
) |
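    # With the interleaved (virtual pipeline) schedule, each PP rank owns several
    # non-contiguous model chunks; build one estimator chunk per virtual stage.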
|
|
|
def build_model(): |
|
if ( |
|
get_pipeline_model_parallel_world_size() > 1 |
|
and args.virtual_pipeline_model_parallel_size is not None |
|
): |
|
if model_type == ModelType.encoder_and_decoder: |
|
assert ( |
|
config.encoder_pipeline_model_parallel_size == 0 |
|
), "Interleaved schedule not supported for model with encoder on separate PP rank" |
|
model = [] |
|
for i in range(args.virtual_pipeline_model_parallel_size): |
|
|
|
pre_process = is_pipeline_first_stage(ignore_virtual=False, vp_stage=i) |
|
post_process = is_pipeline_last_stage(ignore_virtual=False, vp_stage=i) |
|
|
|
this_model = model_provider_func( |
|
args, |
|
config, |
|
pre_process=pre_process, |
|
post_process=post_process, |
|
vp_stage=i, |
|
) |
|
this_model.model_type = model_type |
|
this_model.vp_stage = i |
|
model.append(this_model) |
|
else: |
|
pre_process = is_pipeline_first_stage() |
|
post_process = is_pipeline_last_stage() |
|
if model_type == ModelType.encoder_and_decoder: |
|
if get_pipeline_model_parallel_world_size() > 1: |
|
rank = get_pipeline_model_parallel_rank() |
|
first_decoder_rank = config.encoder_pipeline_model_parallel_size |
|
world_size = get_pipeline_model_parallel_world_size() |
|
pre_process = rank == 0 or rank == first_decoder_rank |
|
post_process = (rank == (first_decoder_rank - 1)) or ( |
|
rank == (world_size - 1) |
|
) |
|
model = model_provider_func( |
|
args, |
|
config, |
|
pre_process=pre_process, |
|
post_process=post_process, |
|
) |
|
else: |
|
model = model_provider_func( |
|
args, config, pre_process=pre_process, post_process=post_process |
|
) |
|
model.model_type = model_type |
|
return model |
|
|
|
model = build_model() |
|
|
|
if not isinstance(model, list): |
|
model = [model] |
|
return model |
|
|
|
|
|
NUM_BYTES_IN_MEGABYTE = 1024 * 1024 |
|
NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024 |
|
|
|
|
|
def patch_parallel_states():
    """Point megatron.core.parallel_state at the estimator's mocked pipeline-parallel
    helpers so that no torch.distributed initialization is required."""
    from megatron.core import parallel_state
|
|
|
parallel_state.is_pipeline_first_stage = is_pipeline_first_stage |
|
parallel_state.is_pipeline_last_stage = is_pipeline_last_stage |
|
parallel_state.get_pipeline_model_parallel_rank = get_pipeline_model_parallel_rank |
|
parallel_state.get_pipeline_model_parallel_world_size = ( |
|
get_pipeline_model_parallel_world_size |
|
) |
|
parallel_state.get_virtual_pipeline_model_parallel_world_size = ( |
|
get_virtual_pipeline_model_parallel_world_size |
|
) |
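    # Decoder-only GPT: the estimator assumes no separate encoder pipeline stages.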
|
parallel_state.is_inside_encoder = lambda: False |
|
parallel_state.get_pipeline_model_parallel_decoder_start = lambda: 0 |
|
|
|
|
|
def report_memory_usage(args, config=None):
    """Estimate and print the theoretical memory footprint for every pipeline-parallel rank."""
    args.moe_grouped_gemm = True
|
patch_parallel_states() |
|
if config is None: |
|
if args.yaml_cfg is not None: |
|
config = core_transformer_config_from_yaml(args, "language_model") |
|
else: |
|
config = core_transformer_config_from_args(args) |
|
|
|
input_shape = [args.micro_batch_size, args.seq_length] |
|
|
|
set_global_config(config) |
|
|
|
cli_reports = [] |
|
|
|
if config.pipeline_model_parallel_size > 1: |
|
for pp_rank in range(config.pipeline_model_parallel_size): |
|
set_pipeline_model_parallel_rank(pp_rank) |
|
print( |
|
f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------" |
|
) |
|
input_shape, rpt = report_memory_usage_one_pp_rank( |
|
input_shape, args, config, pp_rank, config.pipeline_model_parallel_size |
|
) |
|
cli_reports.append(rpt) |
|
else: |
|
set_pipeline_model_parallel_rank(0) |
|
_, rpt = report_memory_usage_one_pp_rank(input_shape, args, config) |
|
cli_reports.append(rpt) |
|
|
|
|
|
print("\n===== Summary (per PP rank) =====") |
|
for r in cli_reports: |
|
print( |
|
f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)" |
|
) |
|
|
|
|
|
def report_memory_usage_one_pp_rank(
    input_shape: list[int], args, config, pp_rank=0, pp_size=1
) -> tuple[list[int], dict]:
    """Estimate parameters and activations for a single pipeline-parallel rank.

    Returns the output shape of this stage (i.e. the next stage's input shape) and a
    report dict for this rank.
    """
    print(f"{input_shape=}")
|
model: list[GPTModel] = get_model(model_provider, args, config) |
|
num_parameter_this_shard_all = 0 |
|
num_parameter_this_shard_sparse_all = 0 |
|
num_activation_all = 0 |
|
output_shape = input_shape |
|
for vpp_rank, one_chunk in enumerate(model): |
|
num_parameter_this_shard = one_chunk.num_parameter() |
|
num_activation = one_chunk.num_activation(output_shape) |
|
output_shape = one_chunk.mock_forward(output_shape) |
|
print(f"{output_shape=}") |
|
num_parameter_this_shard_sparse = 0 |
|
for layer in one_chunk.decoder.layers.modules: |
|
if isinstance(layer.mlp, MoELayer): |
|
num_parameter_this_shard_sparse += layer.mlp.num_parameter() |
|
if ( |
|
"shared_experts" in layer.mlp.__dir__() |
|
and layer.mlp.shared_experts is not None |
|
): |
|
num_parameter_this_shard_sparse -= ( |
|
layer.mlp.shared_experts.num_parameter() |
|
) |
|
num_activation_this_shard_mlp = sum( |
|
[m.mlp.num_activation() for m in one_chunk.decoder.layers.modules] |
|
) |
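        # Approximate number of in-flight microbatches whose activations are alive on
        # this PP rank. With the interleaved (virtual pipeline) schedule the first and
        # last chunks see extra warm-up microbatches; without it, a 1F1B schedule keeps
        # at most (pp_size - pp_rank) microbatches alive on this rank.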
|
if len(model) > 1: |
|
if vpp_rank >= 1 and vpp_rank < len(model) - 1: |
|
num_microbatch_this_pp_rank = pp_size |
|
elif vpp_rank == 0: |
|
num_microbatch_this_pp_rank = pp_size + max( |
|
(pp_size - pp_rank) * 2 - 1 - pp_size, 0 |
|
) |
|
elif vpp_rank == len(model) - 1: |
|
num_microbatch_this_pp_rank = min((pp_size - pp_rank) * 2 + 1, pp_size) |
|
else: |
|
num_microbatch_this_pp_rank = pp_size - pp_rank |
|
|
|
|
|
|
        print(one_chunk)
        print(
            f"Number of parameters per GPU in billions: "
            f"{num_parameter_this_shard / 10**9: .2f}, of which the MoE part is "
            f"{num_parameter_this_shard_sparse / 10**9: .2f}"
        )
|
num_parameter_this_shard_all += num_parameter_this_shard |
|
num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse |
|
|
|
if config.recompute_granularity == "full": |
|
recompute_num_layers = config.recompute_num_layers |
|
num_layers = one_chunk.num_layers |
|
common_act = ( |
|
one_chunk.num_act_pre |
|
+ one_chunk.num_act_between_layers |
|
* num_layers |
|
* num_microbatch_this_pp_rank |
|
) |
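            # Two candidate activation peaks: (1) the forward pass while the loss /
            # post-processing activations are live, and (2) the backward pass while a
            # recomputed layer's activations are live; report whichever is larger.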
|
info = "With this recomputing setting, the number of activation achieve peak when " |
|
if config.recompute_method == "block": |
|
num_layers_with_loss = num_layers - recompute_num_layers |
|
if num_layers_with_loss == 0: |
|
peak1 = common_act + one_chunk.num_act_post |
|
peak2 = common_act + one_chunk.num_act_per_layer |
|
                    if peak1 > peak2:
                        info += "calculating the loss"
                    else:
                        info += "back-propagating the loss"
                    num_activation = max(peak1, peak2)
|
else: |
|
info += f"calculating loss with {num_layers_with_loss} non-recompute layers" |
|
num_activation = ( |
|
common_act |
|
+ one_chunk.num_act_post |
|
+ one_chunk.num_act_per_layer |
|
* num_layers_with_loss |
|
* num_microbatch_this_pp_rank |
|
) |
|
elif config.recompute_method == "uniform": |
|
peak1 = common_act + one_chunk.num_act_post |
|
peak2 = ( |
|
(common_act + one_chunk.num_act_per_layer) |
|
if vpp_rank == 0 |
|
else (common_act) |
|
) |
|
                if peak1 > peak2:
                    info += "calculating the loss"
                else:
                    info += f"back-propagating the loss while recomputing every {recompute_num_layers} layers"
                num_activation = max(peak1, peak2)
|
if len(one_chunk.decoder.layers.modules) > 0 and isinstance( |
|
one_chunk.decoder.layers.modules[0].self_attention, MLASelfAttention |
|
): |
|
num_activation += one_chunk.decoder.layers.modules[ |
|
0 |
|
].self_attention.core_attention.num_activation() |
|
print(info) |
|
|
|
else: |
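            # Without full recompute, per-layer activations are held for every in-flight
            # microbatch, while the loss/post-processing activations exist only once.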
|
num_activation = ( |
|
num_activation - one_chunk.num_act_post |
|
) * num_microbatch_this_pp_rank + one_chunk.num_act_post |
|
|
|
|
|
        # Context parallelism shards activations along the sequence dimension.
        num_activation = num_activation / config.context_parallel_size
        if pp_size == 1:
            print(
                f"Number of activations per GPU in billions: "
                f"{num_activation / 10**9: .2f}, of which the MLP part is "
                f"{num_activation_this_shard_mlp / 10**9: .2f}"
            )
        else:
            print(
                f"Number of activations per microbatch per GPU in billions: "
                f"{num_activation / 10**9: .2f}, of which the MLP part is "
                f"{num_activation_this_shard_mlp / 10**9: .2f}"
                f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
            )
|
num_activation_all += num_activation |
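    # Bytes of static memory per parameter, following Megatron-LM's theoretical
    # accounting for BF16 training with Adam: 18 = 2 (BF16 weight) + 4 (FP32 grad)
    # + 12 (FP32 main weight + two Adam moments). With the distributed optimizer the
    # 12 optimizer bytes are sharded across the data-parallel (x context-parallel)
    # group, leaving 6 resident bytes per parameter plus the sharded remainder.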
|
num_bytes_per_parameter = ( |
|
18 |
|
if not args.use_distributed_optimizer |
|
else 6 + (12 / args.data_parallel_size / config.context_parallel_size) |
|
) |
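    # Expert (MoE) weights are replicated over a smaller data-parallel group,
    # world_size / (PP * EP * expert-TP), so their optimizer shards are larger per
    # rank than those of the dense weights.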
|
if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1: |
|
num_bytes_per_parameter_dense = num_bytes_per_parameter |
|
num_bytes_per_parameter_moe = ( |
|
18 |
|
if not args.use_distributed_optimizer |
|
else 6 |
|
+ ( |
|
12 |
|
/ ( |
|
args.world_size |
|
/ config.pipeline_model_parallel_size |
|
/ config.expert_model_parallel_size |
|
/ config.expert_tensor_parallel_size |
|
) |
|
) |
|
) |
|
print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}") |
|
weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE |
|
weight_grad_optim_memory = ( |
|
(num_parameter_this_shard_all - num_parameter_this_shard_sparse_all) |
|
* num_bytes_per_parameter_dense |
|
+ num_parameter_this_shard_sparse_all * num_bytes_per_parameter_moe |
|
) / NUM_BYTES_IN_GIGABYTE |
|
else: |
|
print(f"{num_bytes_per_parameter=}") |
|
weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE |
|
weight_grad_optim_memory = ( |
|
num_parameter_this_shard_all |
|
* num_bytes_per_parameter |
|
/ NUM_BYTES_IN_GIGABYTE |
|
) |
|
|
|
    # Activations are assumed to be stored in 2-byte (BF16/FP16) precision.
    activation_memory = num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
    total_memory = weight_grad_optim_memory + activation_memory
|
|
|
print( |
|
f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, " |
|
f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n" |
|
) |
|
|
|
|
|
model_breakdown_concat = "\n\n".join( |
|
[f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)] |
|
) |
|
|
|
report = { |
|
"pp_rank": pp_rank, |
|
"parameters_b": num_parameter_this_shard_all / 1e9, |
|
"activation_b": num_activation_all / 1e9, |
|
"weight_grad_gb": round(weight_grad_memory, 2), |
|
"weight_grad_optim_gb": round(weight_grad_optim_memory, 2), |
|
"activation_gb": round(activation_memory, 2), |
|
"total_gb": round(total_memory, 2), |
|
"model_breakdown": model_breakdown_concat, |
|
"details": None, |
|
} |
|
|
|
return output_shape, report |
|
|
|
|
|
if __name__ == "__main__": |
|
initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True) |
|
|
|
import ipdb |
|
|
|
with ipdb.launch_ipdb_on_exception(): |
|
args = get_args() |
|
report_memory_usage(args) |
|
|