duzx16
commited on
Commit
·
096f3de
1
Parent(s):
4a9b711
Fix context length in get_position_ids
Browse files- modeling_chatglm.py +1 -1
modeling_chatglm.py
CHANGED
|
@@ -769,7 +769,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 769 |
return attention_mask
|
| 770 |
|
| 771 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
| 772 |
-
context_length = seq
|
| 773 |
if self.position_encoding_2d:
|
| 774 |
seq_length = seq.index(self.config.bos_token_id)
|
| 775 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
|
|
|
| 769 |
return attention_mask
|
| 770 |
|
| 771 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
| 772 |
+
context_length = len(seq)
|
| 773 |
if self.position_encoding_2d:
|
| 774 |
seq_length = seq.index(self.config.bos_token_id)
|
| 775 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|