duzx16 committed
Commit 373fd6b · 1 Parent(s): e22cddf
Fix attention_mask and position_ids
Files changed: tokenization_chatglm.py (+23 -21)
tokenization_chatglm.py
CHANGED
@@ -340,7 +340,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         token_ids_0 += [self.sp_tokenizer[self.bos_token]]
 
         if token_ids_1 is not None:
-            if token_ids_1[-1] != eop_id:
+            if not token_ids_1 or token_ids_1[-1] != eop_id:
                 token_ids_1 += [eop_id]
             token_ids_0 += token_ids_1
 
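This first hunk fixes a crash when a second segment is passed but empty: token_ids_1[-1] raises IndexError on an empty list, while the new condition simply appends eop_id. A minimal standalone sketch of the before/after behavior; the eop_id value and the helper names are placeholders for illustration, not part of the tokenizer:

# Sketch only: eop_id is a placeholder value and these helpers are illustrative,
# not functions from tokenization_chatglm.py.
eop_id = 130005

def append_second_segment_old(token_ids_0, token_ids_1):
    # Old condition: IndexError when token_ids_1 is an empty list.
    if token_ids_1[-1] != eop_id:
        token_ids_1 += [eop_id]
    return token_ids_0 + token_ids_1

def append_second_segment_new(token_ids_0, token_ids_1):
    # New condition: an empty (but not None) second segment still gets eop_id.
    if not token_ids_1 or token_ids_1[-1] != eop_id:
        token_ids_1 += [eop_id]
    return token_ids_0 + token_ids_1

print(append_second_segment_new([1, 2, 3], []))   # [1, 2, 3, 130005]
# append_second_segment_old([1, 2, 3], [])        # raises IndexError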
@@ -397,26 +397,28 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         needs_to_be_padded = padding_strategy != PaddingStrategy.DO_NOT_PAD and len(required_input) != max_length
 
         # Initialize attention mask if not present.
-        [20 removed lines: previous attention_mask / position_ids initialization, not recoverable from this view]
+        if max_length is not None:
+            if "attention_mask" not in encoded_inputs:
+                if bos_token_id in required_input:
+                    context_length = required_input.index(bos_token_id)
+                else:
+                    context_length = seq_length
+                attention_mask = np.ones((1, seq_length, seq_length))
+                attention_mask = np.tril(attention_mask)
+                attention_mask[:, :, :context_length] = 1
+                attention_mask = np.bool_(attention_mask < 0.5)
+                encoded_inputs["attention_mask"] = attention_mask
+
+            if "position_ids" not in encoded_inputs:
+                position_ids = np.arange(seq_length, dtype=np.int64)
+                mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
+                if mask_token in required_input:
+                    mask_position = required_input.index(mask_token)
+                    position_ids[context_length:] = mask_position
+                block_position_ids = np.concatenate(
+                    [np.zeros(context_length, dtype=np.int64),
+                     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])
+                encoded_inputs["position_ids"] = np.stack([position_ids, block_position_ids], axis=0)
 
         if needs_to_be_padded:
             difference = max_length - len(required_input)
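For context, the added attention_mask block builds ChatGLM's prefix-LM mask: every position may attend to the whole prefix before the bos token, positions after it attend causally, and after the final comparison True marks a masked pair. A standalone NumPy sketch of the same arithmetic, with a made-up token sequence and a placeholder bos_token_id:

import numpy as np

# Sketch only: the token ids are made up and 150004 merely stands in for
# bos_token_id; the real ids come from the tokenizer's vocabulary.
bos_token_id = 150004
required_input = [5, 6, 7, 150004, 8, 9]           # prompt ..., bos, generated ...
seq_length = len(required_input)

context_length = (required_input.index(bos_token_id)
                  if bos_token_id in required_input else seq_length)

attention_mask = np.ones((1, seq_length, seq_length))
attention_mask = np.tril(attention_mask)           # causal (lower-triangular) part
attention_mask[:, :, :context_length] = 1          # prefix columns visible to everyone
attention_mask = np.bool_(attention_mask < 0.5)    # True = position is masked out

print(attention_mask[0].astype(int))
# 1 (masked) appears only above the diagonal and only for columns past the
# prefix, i.e. columns >= context_length.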
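The position_ids block produces ChatGLM's 2D positional encoding: channel 0 gives ordinary positions over the prefix and then repeats the [MASK]/[gMASK] position for every generated token, while channel 1 (the block position ids) is zero over the prefix and counts 1, 2, 3, ... inside the generated block. A standalone sketch with placeholder ids for [MASK], [gMASK] and bos:

import numpy as np

# Sketch only: these values are placeholders for mask_token_id, gmask_token_id
# and bos_token_id, not the tokenizer's real vocabulary ids.
mask_token_id, gmask_token_id, bos_token_id = 150000, 150001, 150004
required_input = [5, 6, 150001, 7, 150004, 8, 9]   # ..., gMASK, ..., bos, generated ...
seq_length = len(required_input)
context_length = required_input.index(bos_token_id)

position_ids = np.arange(seq_length, dtype=np.int64)
mask_token = mask_token_id if mask_token_id in required_input else gmask_token_id
if mask_token in required_input:
    mask_position = required_input.index(mask_token)
    position_ids[context_length:] = mask_position   # generated tokens reuse the mask slot

block_position_ids = np.concatenate(
    [np.zeros(context_length, dtype=np.int64),
     np.arange(1, seq_length - context_length + 1, dtype=np.int64)])

print(np.stack([position_ids, block_position_ids], axis=0))
# [[0 1 2 3 2 2 2]
#  [0 0 0 0 1 2 3]]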