Patrick Haller
commited on
Commit
·
8ac80ca
1
Parent(s):
20f7607
Fix
Browse files
modeling_hf_alibaba_nlp_gte.py
CHANGED
@@ -965,6 +965,8 @@ class GteModel(GtePreTrainedModel):
|
|
965 |
|
966 |
class GteForSequenceClassification(GtePreTrainedModel):
|
967 |
|
|
|
|
|
968 |
def __init__(self, config: GteConfig):
|
969 |
super().__init__(config)
|
970 |
self.config = config
|
@@ -973,6 +975,7 @@ class GteForSequenceClassification(GtePreTrainedModel):
|
|
973 |
|
974 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
975 |
self.loss_function = nn.MSELoss()
|
|
|
976 |
|
977 |
def get_input_embeddings(self):
|
978 |
return self.model.embed_tokens
|
|
|
965 |
|
966 |
class GteForSequenceClassification(GtePreTrainedModel):
|
967 |
|
968 |
+
base_model_prefix = "model"
|
969 |
+
|
970 |
def __init__(self, config: GteConfig):
|
971 |
super().__init__(config)
|
972 |
self.config = config
|
|
|
975 |
|
976 |
self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
|
977 |
self.loss_function = nn.MSELoss()
|
978 |
+
self.post_init()
|
979 |
|
980 |
def get_input_embeddings(self):
|
981 |
return self.model.embed_tokens
|