Warn about megablocks more clearly and less often (#20)
warn about megablocks more clearly and less often (8a4d4d9a7f96bf4ffe71c72251432824ebfd90d4)
Co-authored-by: Cebtenzzre <[email protected]>
modeling_hf_nomic_bert.py CHANGED (+11 -6)
@@ -3,13 +3,15 @@
 # https://github.com/mlcommons/training_results_v2.0/blob/main/HazyResearch/benchmarks/bert/implementations/pytorch/modeling.py
 # https://github.com/mlcommons/training_results_v2.1/blob/main/Azure-HazyResearch/benchmarks/bert/implementations/ND96amsr_A100_v4/modeling.py
 
+# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
+
 import collections
+import inspect
 import logging
-
-# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py
 import math
 import os
 import re
+import warnings
 from collections import OrderedDict
 from functools import partial
 from typing import List, Optional, Tuple, Union
@@ -54,8 +56,9 @@ try:
     from megablocks.layers import dmoe
     from megablocks.layers.arguments import Arguments
 except ImportError:
-    logger.warning("!!!!!!!!!!!!megablocks not available, using torch.matmul instead")
     dmoe = None
+else:
+    dmoe_is_nomic = 'attention_mask' in inspect.signature(dmoe.dMoE.forward).parameters
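The `dmoe_is_nomic` probe above is ordinary signature introspection: Nomic's megablocks fork adds an `attention_mask` parameter to `dMoE.forward`, while upstream megablocks does not. A minimal, self-contained sketch of the same pattern (the two stand-in classes are hypothetical, used so the example runs without megablocks installed):

```python
import inspect

class UpstreamMoE:                                   # stand-in for upstream megablocks
    def forward(self, x):
        return x

class NomicMoE:                                      # stand-in for Nomic's fork
    def forward(self, x, attention_mask=None):
        return x

def accepts_attention_mask(layer_cls):
    # True iff forward() declares an `attention_mask` parameter.
    return 'attention_mask' in inspect.signature(layer_cls.forward).parameters

assert not accepts_attention_mask(UpstreamMoE)
assert accepts_attention_mask(NomicMoE)
```

Note that `dmoe_is_nomic` is only bound when the import succeeds; the constructor's `dmoe is not None and dmoe_is_nomic` test below short-circuits on the first operand, so the name is never evaluated when megablocks is absent.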
@@ -1612,7 +1615,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
         )
         self.moe = moe
         if moe:
-            if dmoe is not None:
+            if dmoe is not None and dmoe_is_nomic:
                 megablocks_args = Arguments(
                     moe_num_experts=config.num_experts,
                     moe_top_k=config.moe_top_k,
@@ -1628,6 +1631,8 @@ class NomicBertBlock(NomicBertPreTrainedModel):
                 )
                 self.mlp = dmoe.dMoE(megablocks_args)
             else:
+                warnings.warn("Install Nomic's megablocks fork for better speed: " +
+                              "`pip install git+https://github.com/nomic-ai/megablocks.git`")
                 self.mlp = NomicMoELayer(
                     config
                 )
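Two things make this warning fire less often than the old import-time `logger.warning`: it is only reached when a MoE block is actually constructed without a usable megablocks, and `warnings.warn` is deduplicated by Python's default warning filter, which reports a given warning once per call site. A quick sketch of that behavior:

```python
import warnings

def build_block():
    # Same message from the same call site: the default warnings filter
    # ("default" action) reports it only once per Python session.
    warnings.warn("Install Nomic's megablocks fork for better speed: "
                  "`pip install git+https://github.com/nomic-ai/megablocks.git`")

for _ in range(12):
    build_block()  # the UserWarning prints a single time
```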
@@ -1698,7 +1703,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             residual = (dropped + residual) if residual is not None else dropped
             hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
             if self.moe:
-                hidden_states = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                hidden_states = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 hidden_states = self.mlp(hidden_states)
@@ -1715,7 +1720,7 @@ class NomicBertBlock(NomicBertPreTrainedModel):
             )
             hidden_states = self.norm1((self.dropout1(attn_outputs) + hidden_states).to(dtype=self.norm1.weight.dtype))
             if self.moe:
-                mlp_out = self.mlp(hidden_states, torch.where(attention_mask.squeeze() == 0, 1, 0))
+                mlp_out = self.mlp(hidden_states, attention_mask=torch.where(attention_mask.squeeze() == 0, 1, 0))
             else:
                 mlp_out = self.mlp(hidden_states)
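In both call sites the mask is now passed by keyword, matching the `attention_mask` parameter that the `dmoe_is_nomic` probe checked for. The `torch.where` expression itself just builds a 0/1 indicator of the masked positions; a toy illustration (the 1 = token / 0 = padding convention shown here is an assumption about what reaches the block):

```python
import torch

# Hypothetical batch of two sequences of length 5, assuming the HF-style
# key-padding convention: 1 = real token, 0 = padding.
attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                               [1, 1, 1, 1, 1]])

# 1 where the mask is 0, else 0 -- i.e. an indicator of padded positions.
padding_indicator = torch.where(attention_mask.squeeze() == 0, 1, 0)
print(padding_indicator)
# tensor([[0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0]])
```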