Commit
·
5310008
1
Parent(s):
66875f9
Update modeling_baichuan.py
Browse files- modeling_baichuan.py +3 -1
modeling_baichuan.py
CHANGED
|
@@ -704,9 +704,11 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
|
|
| 704 |
loss_fct = CrossEntropyLoss()
|
| 705 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 706 |
shift_labels = shift_labels.view(-1)
|
|
|
|
|
|
|
| 707 |
# Enable model parallelism
|
| 708 |
shift_labels = shift_labels.to(shift_logits.device)
|
| 709 |
-
loss = loss_fct(shift_logits, shift_labels)
|
| 710 |
|
| 711 |
if not return_dict:
|
| 712 |
output = (logits,) + outputs[1:]
|
|
|
|
| 704 |
loss_fct = CrossEntropyLoss()
|
| 705 |
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
| 706 |
shift_labels = shift_labels.view(-1)
|
| 707 |
+
softmax_normalizer = shift_logits.max(-1).values ** 2
|
| 708 |
+
z_loss = self.config.z_loss_weight * softmax_normalizer.mean()
|
| 709 |
# Enable model parallelism
|
| 710 |
shift_labels = shift_labels.to(shift_logits.device)
|
| 711 |
+
loss = loss_fct(shift_logits, shift_labels) + z_loss
|
| 712 |
|
| 713 |
if not return_dict:
|
| 714 |
output = (logits,) + outputs[1:]
|