import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig
)
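# Each rule maps a regex over model.named_modules() names to a 2-tuple of mesh
# axis names: the first entry shards dim 0 of the module's weight, the second
# shards dim 1 (e.g. nn.Linear weights have shape (out_features, in_features)).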
# rule patterns end with $ so that LoRA adapter parameters are not sharded
GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)
T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),
    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),
    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),
    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)
LLAMA_RULES = (
    ("model\\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\\.o_proj", ("mp", "fsdp")),
    ("mlp\\.gate_proj", ("fsdp", "mp")),
    ("mlp\\.down_proj", ("mp", "fsdp")),
    ("mlp\\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)
ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
]
def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception("unsupported model for partitioning")
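# mesh axis name -> mesh dimension index; the mesh passed to partition_module
# is expected to expose three dimensions in this (dp, fsdp, mp) order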
strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2,
}
def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    """Move every module to the XLA device and shard Embedding/Linear weights
    according to the first matching rule for the model's config class."""
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]
    # print(rule)
    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)
                    xs.mark_sharding(module.weight, mesh, spec)
                    break
def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    """Shard every Embedding/Linear weight with the same (fsdp, mp) spec,
    ignoring the per-architecture rules."""
    spec = (1, 2)
    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)
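# --- Example usage (illustrative sketch, not part of the original module) ---
# Assumes a torch_xla SPMD setup; the model name and mesh shape below are only
# examples. Depending on the torch_xla version, SPMD is enabled either with
# torch_xla.runtime.use_spmd() or the XLA_USE_SPMD=1 environment variable, and
# some of these runtime helpers may live under different names in older releases.
if __name__ == "__main__":
    import numpy as np
    import torch_xla.runtime as xr
    from transformers import AutoModelForCausalLM

    num_devices = xr.global_runtime_device_count()
    # Axis order must match strkey2id: (dp, fsdp, mp).
    mesh = xs.Mesh(np.arange(num_devices), (1, num_devices, 1), ("dp", "fsdp", "mp"))
    # Example GPT-NeoX checkpoint; any model covered by ALL_RULES works.
    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m")
    partition_module(model, mesh, verbose=True)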