"""Regex-based SPMD sharding rules for GPT-NeoX, T5, and LLaMA models on PyTorch/XLA."""

import re

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig
)

# Each rule maps a module-name regex to the mesh axis names used to shard the
# module's weight: (axis for weight dim 0, axis for weight dim 1).
GPTNEOX_RULES = (
    # embeddings
    (r"gpt_neox\.embed_in", ("mp", "fsdp")),
    # attention
    (r"attention\.query_key_value$", ("fsdp", "mp")),
    (r"attention\.dense$", ("mp", "fsdp")),
    # MLP
    (r"mlp\.dense_h_to_4h$", ("fsdp", "mp")),
    (r"mlp\.dense_4h_to_h$", ("mp", "fsdp")),
    # output head
    ("embed_out", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    (r"shared$", ("mp", "fsdp")),
    (r"embed_tokens$", ("mp", "fsdp")),
    # attention
    (r"q$", ("fsdp", "mp")),
    (r"k$", ("fsdp", "mp")),
    (r"v$", ("fsdp", "mp")),
    (r"o$", ("mp", "fsdp")),
    # MLP
    (r"w$", ("fsdp", "mp")),
    (r"wi_0$", ("fsdp", "mp")),
    (r"wi_1$", ("fsdp", "mp")),
    (r"wo$", ("mp", "fsdp")),
    # output head
    ("lm_head", ("fsdp", "mp")),
)

LLAMA_RULES = (
    # embeddings
    (r"model\.embed_tokens", ("mp", "fsdp")),
    # attention
    (r"self_attn\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    (r"self_attn\.o_proj", ("mp", "fsdp")),
    # MLP
    (r"mlp\.gate_proj", ("fsdp", "mp")),
    (r"mlp\.down_proj", ("mp", "fsdp")),
    (r"mlp\.up_proj", ("fsdp", "mp")),
    # output head
    ("lm_head", ("fsdp", "mp")),
)

ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
]


def find_rule(model):
    """Return the sharding rules matching the model's config class."""
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise ValueError(
        f"Unsupported model for partitioning: {model.config.__class__.__name__}"
    )


# Mesh axis names mapped to their positions in the (dp, fsdp, mp) device mesh.
strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2,
}


def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    """Move the model to the XLA device and annotate the weights of matching
    Embedding/Linear modules with SPMD sharding specs."""
    partition_specs = find_rule(model)
    # Translate axis names into mesh axis indices, e.g. ("fsdp", "mp") -> (1, 2).
    rule = [(k, tuple(strkey2id[x] for x in v)) for k, v in partition_specs]

    for name, module in model.named_modules():
        module.to(device)

        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.search(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    xs.mark_sharding(module.weight, mesh, spec)
                    break


def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    """Move the model to the XLA device and shard every Embedding/Linear weight
    the same way, without per-model rules."""
    # Shard weight dim 0 on mesh axis 1 ("fsdp") and dim 1 on axis 2 ("mp").
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)
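

# --- Usage sketch ----------------------------------------------------------
# A minimal, hypothetical example of how these helpers could be wired up: it
# builds a 3-axis ("dp", "fsdp", "mp") device mesh and shards a Hugging Face
# model with the rules above. The axis sizes, the model checkpoint, and the
# torch_xla runtime helpers (xr.use_spmd, xr.global_runtime_device_count)
# are assumptions about the surrounding setup, not part of this module.
if __name__ == "__main__":
    import numpy as np
    import torch_xla.runtime as xr
    from transformers import AutoModelForCausalLM

    xr.use_spmd()  # enable SPMD execution mode (torch_xla >= 2.1)

    num_devices = xr.global_runtime_device_count()
    # Put all devices on the "fsdp" axis; "dp" and "mp" are left at size 1.
    mesh = xs.Mesh(
        np.arange(num_devices),
        mesh_shape=(1, num_devices, 1),
        axis_names=("dp", "fsdp", "mp"),
    )

    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1b")
    partition_module(model, mesh, verbose=True)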