Update tokenization_chatglm.py
tokenization_chatglm.py  (+30 -5)
@@ -1,11 +1,13 @@
+import json
 import os
-import torch
+import re
 from typing import List, Optional, Union, Dict
 from sentencepiece import SentencePieceProcessor
 from transformers import PreTrainedTokenizer
 from transformers.utils import logging, PaddingStrategy
 from transformers.tokenization_utils_base import EncodedInput, BatchEncoding
 
+
 class SPTokenizer:
     def __init__(self, model_path: str):
         # reload tokenizer
@@ -30,6 +32,7 @@ class SPTokenizer:
     def tokenize(self, s: str):
         return self.sp_model.EncodeAsPieces(s)
 
+
     def encode(self, s: str, bos: bool = False, eos: bool = False) -> List[int]:
         assert type(s) is str
         t = self.sp_model.encode(s)
@@ -40,7 +43,18 @@
         return t
 
     def decode(self, t: List[int]) -> str:
-        return self.sp_model.decode(t)
+        text, buffer = "", []
+        for token in t:
+            if token in self.index_special_tokens:
+                if buffer:
+                    text += self.sp_model.decode(buffer)
+                    buffer = []
+                text += self.index_special_tokens[token]
+            else:
+                buffer.append(token)
+        if buffer:
+            text += self.sp_model.decode(buffer)
+        return text
 
     def decode_tokens(self, tokens: List[str]) -> str:
         text = self.sp_model.DecodePieces(tokens)
@@ -54,7 +68,9 @@ class SPTokenizer:
 
     def convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index in self.index_special_tokens or index in [self.eos_id, self.bos_id, self.pad_id] or index < 0:
+        if index in self.index_special_tokens:
+            return self.index_special_tokens[index]
+        if index in [self.eos_id, self.bos_id, self.pad_id] or index < 0 or index > self.sp_model.vocab_size():
             return ""
         return self.sp_model.IdToPiece(index)
 
@@ -64,8 +80,8 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
 
     model_input_names = ["input_ids", "attention_mask", "position_ids"]
 
-    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, **kwargs):
-        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces, **kwargs)
+    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False, encode_special_tokens=False,
+                 **kwargs):
         self.name = "GLMTokenizer"
 
         self.vocab_file = vocab_file
@@ -75,6 +91,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             "<eos>": self.tokenizer.eos_id,
             "<pad>": self.tokenizer.pad_id
         }
+        self.encode_special_tokens = encode_special_tokens
+        super().__init__(padding_side=padding_side, clean_up_tokenization_spaces=clean_up_tokenization_spaces,
+                         encode_special_tokens=encode_special_tokens,
+                         **kwargs)
 
     def get_command(self, token):
         if token in self.special_tokens:
@@ -82,6 +102,10 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         assert token in self.tokenizer.special_tokens, f"{token} is not a special token for {self.name}"
         return self.tokenizer.special_tokens[token]
 
+    @property
+    def unk_token(self) -> str:
+        return "<unk>"
+
     @property
     def pad_token(self) -> str:
         return "<unk>"
@@ -163,6 +187,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         prompt += "[Round {}]\n\n问：{}\n\n答：".format(len(history) + 1, query)
         return prompt
 
+
     def build_inputs_with_special_tokens(
             self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
     ) -> List[int]:
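The substantive change is in SPTokenizer.decode: ids that appear in index_special_tokens are no longer fed to SentencePiece, which does not know them, but are emitted as their literal token strings, while runs of ordinary ids are buffered and decoded in one call. Below is a minimal standalone sketch of the same buffering pattern; the function name, the toy id values, and the decode_pieces stand-in for SentencePieceProcessor.decode are illustrative and not part of the repo.

from typing import Dict, List

def decode_with_specials(ids: List[int],
                         index_special_tokens: Dict[int, str],
                         decode_pieces) -> str:
    # Buffer ordinary ids; flush them through the SentencePiece-style decoder
    # whenever a special-token id is hit, and emit the special token verbatim.
    text, buffer = "", []
    for token in ids:
        if token in index_special_tokens:
            if buffer:
                text += decode_pieces(buffer)
                buffer = []
            text += index_special_tokens[token]
        else:
            buffer.append(token)
    if buffer:
        text += decode_pieces(buffer)
    return text

# Toy check with a fake decoder, so the control flow is visible without a model file.
print(decode_with_specials([11, 12, 64790, 13],
                           {64790: "[gMASK]"},
                           lambda chunk: "|".join(str(i) for i in chunk)))
# -> 11|12[gMASK]13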
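The ChatGLMTokenizer constructor is reworked in two ways: it accepts an encode_special_tokens flag, stores it, and forwards it to the base class, and the super().__init__ call moves from the top of __init__ to after self.tokenizer and self.special_tokens exist. A plausible motivation, not stated in the commit, is that newer transformers releases consult the subclass vocabulary from inside PreTrainedTokenizer.__init__, so anything that lookup depends on has to be set up first. A stripped-down sketch of that ordering follows; it is illustrative only and not a drop-in replacement for the real class.

from sentencepiece import SentencePieceProcessor
from transformers import PreTrainedTokenizer

class OrderingSketch(PreTrainedTokenizer):
    model_input_names = ["input_ids", "attention_mask", "position_ids"]

    def __init__(self, vocab_file, padding_side="left", clean_up_tokenization_spaces=False,
                 encode_special_tokens=False, **kwargs):
        # Prepare everything the base class may look at (requires a real
        # SentencePiece .model file on disk) ...
        self.vocab_file = vocab_file
        self.sp_model = SentencePieceProcessor(model_file=vocab_file)
        self.encode_special_tokens = encode_special_tokens
        # ... and only then run the base-class constructor, mirroring the move
        # of super().__init__() to the end of ChatGLMTokenizer.__init__.
        super().__init__(padding_side=padding_side,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         encode_special_tokens=encode_special_tokens,
                         **kwargs)

    @property
    def vocab_size(self):
        return self.sp_model.vocab_size()

    def get_vocab(self):
        return {self.sp_model.IdToPiece(i): i for i in range(self.vocab_size)}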
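Finally, a usage sketch of the updated tokenizer. The checkpoint id is an assumption (any ChatGLM repository that ships this tokenization_chatglm.py behaves the same way); trust_remote_code=True is what makes transformers load the custom tokenizer file.

from transformers import AutoTokenizer

# Assumed checkpoint; substitute the repository this commit actually belongs to.
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

ids = tokenizer.encode("你好")                     # "hello"
print(tokenizer.decode(ids))                       # back to text
print(tokenizer.unk_token, tokenizer.pad_token)    # both resolve to "<unk>" after this change
print(tokenizer.convert_ids_to_tokens(ids))        # uses the updated convert_id_to_token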