Update modeling_phi.py
Browse files- modeling_phi.py +2 -2
modeling_phi.py
CHANGED
|
@@ -296,8 +296,8 @@ class MoE(nn.Module):
|
|
| 296 |
config: PretrainedConfig,
|
| 297 |
):
|
| 298 |
super().__init__()
|
| 299 |
-
self.mlp = nn.ModuleList([MLP(config) for i in range(config.
|
| 300 |
-
self.gate = nn.Linear(config.n_embd, config.
|
| 301 |
self.num_experts_per_tok = config.num_experts_per_tok
|
| 302 |
|
| 303 |
def forward(self, x):
|
|
|
|
| 296 |
config: PretrainedConfig,
|
| 297 |
):
|
| 298 |
super().__init__()
|
| 299 |
+
self.mlp = nn.ModuleList([MLP(config) for i in range(config.num_local_experts)])
|
| 300 |
+
self.gate = nn.Linear(config.n_embd, config.num_local_experts, bias=False)
|
| 301 |
self.num_experts_per_tok = config.num_experts_per_tok
|
| 302 |
|
| 303 |
def forward(self, x):
|