duzx16 committed
Commit 0564795 · Parent(s): 2200e2b
Fix bugs
modeling_chatglm.py +1 -2
modeling_chatglm.py CHANGED

@@ -817,7 +817,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         # past_key_values = [(v1,v2) for v1, v2 in zip(past_key_values[0], past_key_values[1])]
         return past_key_values
 
-    @staticmethod
     def get_masks(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
@@ -900,7 +899,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         )
 
         if self.pre_seq_len is not None:
-            prefix_attention_mask = torch.ones(
+            prefix_attention_mask = torch.ones(batch_size, 1, input_ids.size(-1), self.pre_seq_len).to(attention_mask.device)
             prefix_attention_mask = (prefix_attention_mask < 0.5).bool()
             attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
 
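Context on the first hunk: get_masks takes self and reads self.config.bos_token_id, so the stray @staticmethod decorator breaks instance calls. A minimal standalone sketch of the failure mode (class and names are illustrative, not from the repo):

class Demo:
    config = "cfg"

    @staticmethod
    def broken(self, x):
        # @staticmethod suppresses binding, so an instance call shifts every
        # argument left: x lands in `self` and x itself goes missing.
        return self.config, x

    def fixed(self, x):
        return self.config, x

d = Demo()
try:
    d.broken(42)
except TypeError as e:
    print(e)  # broken() missing 1 required positional argument: 'x'
print(d.fixed(42))  # ('cfg', 42)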
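Context on the second hunk: the rebuilt line sizes the P-tuning prefix mask per batch and per input length before it is flipped to booleans and concatenated. A minimal sketch with toy shapes, assuming ChatGLM's convention that True marks masked-out positions (the dimension values are illustrative, not from the commit):

import torch

# Toy dimensions standing in for a real batch (illustrative values).
batch_size, seq_len, pre_seq_len = 2, 4, 3

# Stand-in for the [batch, 1, seq, seq] boolean mask from get_masks();
# all False here, i.e. nothing masked.
attention_mask = torch.zeros(batch_size, 1, seq_len, seq_len).bool()

# The fixed line: a ones tensor sized to the batch, then `< 0.5` flips it
# to all False, so every learned prefix position stays visible.
prefix_attention_mask = torch.ones(batch_size, 1, seq_len, pre_seq_len).to(attention_mask.device)
prefix_attention_mask = (prefix_attention_mask < 0.5).bool()

# Concatenation on dim=3 prepends pre_seq_len visible key positions:
# [2, 1, 4, 3] + [2, 1, 4, 4] -> [2, 1, 4, 7]
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
print(attention_mask.shape)  # torch.Size([2, 1, 4, 7])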