Shawn Tan commited on
Commit
fb9d7e3
·
1 Parent(s): 192f087

Change signature to match transformers v5.

Browse files
Files changed (1) hide show
  1. torch-ext/scattermoe/layers.py +1 -1
torch-ext/scattermoe/layers.py CHANGED
@@ -48,5 +48,5 @@ class ScatterMoEGatedMLP(nn.Module):
48
  gates=routing_weights
49
  )
50
  layer_output = layer_output.view(bsz, length, emb_size)
51
- return layer_output, router_logits
52
 
 
48
  gates=routing_weights
49
  )
50
  layer_output = layer_output.view(bsz, length, emb_size)
51
+ return layer_output
52