Patrick Haller committed on
Commit
fac4c2e
·
1 Parent(s): d4d56eb

Fix loss function

Browse files
Files changed (1) hide show
  1. modeling_hf_alibaba_nlp_gte.py +1 -2
modeling_hf_alibaba_nlp_gte.py CHANGED
@@ -975,7 +975,6 @@ class GteForSequenceClassification(GtePreTrainedModel):
975
  self.model = GteModel(config, add_pooling_layer=config.add_pooling_layer)
976
 
977
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
978
- self.loss_function = nn.MSELoss()
979
  self.post_init()
980
 
981
  def get_input_embeddings(self):
@@ -1020,7 +1019,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
1020
 
1021
  loss = None
1022
  if labels is not None:
1023
- loss = self.loss_function(logits.squeeze(-1), labels.squeeze(-1))
1024
 
1025
  return SequenceClassifierOutputWithPast(
1026
  loss=loss,
 
975
  self.model = GteModel(config, add_pooling_layer=config.add_pooling_layer)
976
 
977
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
 
978
  self.post_init()
979
 
980
  def get_input_embeddings(self):
 
1019
 
1020
  loss = None
1021
  if labels is not None:
1022
+ loss = self.loss_function(labels, logits, self.config)
1023
 
1024
  return SequenceClassifierOutputWithPast(
1025
  loss=loss,