Patrick Haller commited on
Commit
8ac80ca
·
1 Parent(s): 20f7607
Files changed (1) hide show
  1. modeling_hf_alibaba_nlp_gte.py +3 -0
modeling_hf_alibaba_nlp_gte.py CHANGED
@@ -965,6 +965,8 @@ class GteModel(GtePreTrainedModel):
965
 
966
  class GteForSequenceClassification(GtePreTrainedModel):
967
 
 
 
968
  def __init__(self, config: GteConfig):
969
  super().__init__(config)
970
  self.config = config
@@ -973,6 +975,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
973
 
974
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
975
  self.loss_function = nn.MSELoss()
 
976
 
977
  def get_input_embeddings(self):
978
  return self.model.embed_tokens
 
965
 
966
  class GteForSequenceClassification(GtePreTrainedModel):
967
 
968
+ base_model_prefix = "model"
969
+
970
  def __init__(self, config: GteConfig):
971
  super().__init__(config)
972
  self.config = config
 
975
 
976
  self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
977
  self.loss_function = nn.MSELoss()
978
+ self.post_init()
979
 
980
  def get_input_embeddings(self):
981
  return self.model.embed_tokens