LLM-foundry update (#63), opened by daking on June 16, 2023 at 22:55:57

Files changed:
- custom_embedding.py +1 -2
- modeling_mpt.py +11 -1
custom_embedding.py CHANGED

@@ -3,10 +3,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 
-
 class SharedEmbedding(nn.Embedding):
 
-    def forward(self, input: Tensor, unembed: bool = False) -> Tensor:
+    def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
         if unembed:
             return F.linear(input, self.weight)
         return super().forward(input)
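For context on the file being touched: SharedEmbedding reuses one weight matrix for both the token-embedding lookup and the output projection (the "unembedding"), which is how MPT ties its word embeddings. A small self-contained usage sketch follows; the vocabulary size and model width are made up for illustration and are not taken from the repo.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class SharedEmbedding(nn.Embedding):
    # Same class as in the diff: with unembed=True the embedding weight is
    # reused as the output projection, producing logits over the vocabulary.
    def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
        if unembed:
            return F.linear(input, self.weight)
        return super().forward(input)

wte = SharedEmbedding(num_embeddings=100, embedding_dim=16)
tokens = torch.randint(0, 100, (2, 8))     # (batch, seq) of token ids
hidden = wte(tokens)                       # (2, 8, 16): normal embedding lookup
logits = wte(hidden, unembed=True)         # (2, 8, 100): tied unembedding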
modeling_mpt.py CHANGED

@@ -40,6 +40,11 @@ class MPTModel(MPTPreTrainedModel):
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
         self.alibi = config.attn_config['alibi']
         self.alibi_bias_max = config.attn_config['alibi_bias_max']
+        if config.init_device == 'mixed':
+            if dist.get_local_rank() == 0:
+                config.init_device = 'cpu'
+            else:
+                config.init_device = 'meta'
         if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
             norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
             raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
@@ -47,7 +52,7 @@ class MPTModel(MPTPreTrainedModel):
         self.embedding_fraction = config.embedding_fraction
         self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
         if not self.alibi:
-            self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
+            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
@@ -221,6 +226,11 @@ class MPTForCausalLM(MPTPreTrainedModel):
         if not config.tie_word_embeddings:
             raise ValueError('MPTForCausalLM only supports tied word embeddings')
         self.transformer = MPTModel(config)
+        for child in self.transformer.children():
+            if isinstance(child, torch.nn.ModuleList):
+                continue
+            if isinstance(child, torch.nn.Module):
+                child._fsdp_wrap = True
         self.logit_scale = None
         if config.logit_scale is not None:
             logit_scale = config.logit_scale
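A note on the first modeling_mpt.py hunk: with init_device='mixed', only local rank 0 materializes real parameters (on CPU), while every other rank builds the module on PyTorch's meta device, so a large model only has to exist once in host memory per node before it is sharded. The sketch below mirrors that rank-dependent choice with a toy module standing in for MPTModel; the helper and module names are illustrative, not from the repo.

import torch.nn as nn

def resolve_init_device(requested: str, local_rank: int) -> str:
    # Same branch structure as the diff: 'mixed' means real CPU weights on
    # local rank 0 and parameter-free meta tensors on every other rank.
    if requested == 'mixed':
        return 'cpu' if local_rank == 0 else 'meta'
    return requested

class TinyModel(nn.Module):
    def __init__(self, device: str):
        super().__init__()
        self.wte = nn.Embedding(100, 16, device=device)
        self.norm_f = nn.LayerNorm(16, device=device)

model = TinyModel(resolve_init_device('mixed', local_rank=1))  # pretend this is rank 1
print(next(model.parameters()).is_meta)  # True: this rank allocates no weight memory

In the diff itself the rank comes from dist.get_local_rank() and config.init_device is overwritten in place before any submodule is constructed, so self.wte, self.wpe, the blocks, and self.norm_f all pick up the per-rank device.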
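The last hunk tags every direct child of the transformer, except the ModuleList of blocks, with an _fsdp_wrap = True attribute. The modeling file only sets the flag; the trainer that builds the FSDP model is expected to read it when deciding where to place wrapping boundaries. Below is a rough sketch of a wrap policy that consumes the flag, assuming torch's FSDP callable auto_wrap_policy interface; the policy function is illustrative, not llm-foundry's actual one.

import torch.nn as nn

def fsdp_wrap_policy(module: nn.Module, recurse: bool, nonwrapped_numel: int) -> bool:
    # torch.distributed.fsdp accepts a callable with this signature: while
    # recursing, returning True keeps walking the children; at a leaf call,
    # returning True wraps that module in its own FSDP unit.
    if recurse:
        return True
    return bool(getattr(module, '_fsdp_wrap', False))

Usage would look like FSDP(model, auto_wrap_policy=fsdp_wrap_policy). Skipping the ModuleList in the diff keeps the flag off the container of blocks, presumably so that the blocks can be wrapped individually rather than as one large unit.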