Merge branch 'main' into dev_pt
Browse files- README.md +4 -0
- modeling_chatglm.py +154 -63
- tokenization_chatglm.py +1 -1
README.md
CHANGED
|
@@ -11,6 +11,8 @@ tags:
|
|
| 11 |
## 介绍
|
| 12 |
ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
|
| 13 |
|
|
|
|
|
|
|
| 14 |
## 软件依赖
|
| 15 |
|
| 16 |
```shell
|
|
@@ -44,6 +46,8 @@ pip install protobuf==3.20.0 transformers==4.26.1 icetk cpm_kernels
|
|
| 44 |
|
| 45 |
关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
|
| 46 |
|
|
|
|
|
|
|
| 47 |
## 协议
|
| 48 |
|
| 49 |
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
|
|
|
|
| 11 |
## 介绍
|
| 12 |
ChatGLM-6B 是一个开源的、支持中英双语问答的对话语言模型,基于 [General Language Model (GLM)](https://github.com/THUDM/GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4 量化级别下最低只需 6GB 显存)。ChatGLM-6B 使用了和 [ChatGLM](https://chatglm.cn) 相同的技术,针对中文问答和对话进行了优化。经过约 1T 标识符的中英双语训练,辅以监督微调、反馈自助、人类反馈强化学习等技术的加持,62 亿参数的 ChatGLM-6B 已经能生成相当符合人类偏好的回答。
|
| 13 |
|
| 14 |
+
ChatGLM-6B is an open bilingual language model based on [General Language Model (GLM)](https://github.com/THUDM/GLM) framework, with 6.2 billion parameters. With the quantization technique, users can deploy locally on consumer-grade graphics cards (only 6GB of GPU memory is required at the INT4 quantization level). ChatGLM-6B uses technology similar to ChatGPT, optimized for Chinese QA and dialogue. The model is trained for about 1T tokens of Chinese and English corpus, supplemented by supervised fine-tuning, feedback bootstrap, and reinforcement learning wit human feedback. With only about 6.2 billion parameters, the model is able to generate answers that are in line with human preference.
|
| 15 |
+
|
| 16 |
## 软件依赖
|
| 17 |
|
| 18 |
```shell
|
|
|
|
| 46 |
|
| 47 |
关于更多的使用说明,包括如何运行命令行和网页版本的 DEMO,以及使用模型量化以节省显存,请参考我们的 [Github Repo](https://github.com/THUDM/ChatGLM-6B)。
|
| 48 |
|
| 49 |
+
For more instructions, including how to run CLI and web demos, and model quantization, please refer to our [Github Repo](https://github.com/THUDM/ChatGLM-6B).
|
| 50 |
+
|
| 51 |
## 协议
|
| 52 |
|
| 53 |
本仓库的代码依照 [Apache-2.0](LICENSE) 协议开源,ChatGLM-6B 模型的权重的使用则需要遵循 [Model License](MODEL_LICENSE)。
|
modeling_chatglm.py
CHANGED
|
@@ -3,7 +3,9 @@
|
|
| 3 |
import math
|
| 4 |
import copy
|
| 5 |
import os
|
| 6 |
-
import
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import torch
|
| 9 |
import torch.utils.checkpoint
|
|
@@ -11,7 +13,7 @@ import torch.nn.functional as F
|
|
| 11 |
from torch import nn
|
| 12 |
from torch.nn import CrossEntropyLoss, LayerNorm
|
| 13 |
from torch.nn.utils import skip_init
|
| 14 |
-
from typing import Optional, Tuple, Union, List
|
| 15 |
|
| 16 |
from transformers.utils import (
|
| 17 |
add_code_sample_docstrings,
|
|
@@ -26,15 +28,17 @@ from transformers.modeling_outputs import (
|
|
| 26 |
from transformers.modeling_utils import PreTrainedModel
|
| 27 |
from transformers.utils import logging
|
| 28 |
from transformers.generation.logits_process import LogitsProcessor
|
| 29 |
-
from transformers.generation.utils import LogitsProcessorList
|
| 30 |
|
| 31 |
from .configuration_chatglm import ChatGLMConfig
|
| 32 |
|
| 33 |
# flags required to enable jit fusion kernels
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
torch._C.
|
| 37 |
-
torch._C.
|
|
|
|
|
|
|
| 38 |
|
| 39 |
logger = logging.get_logger(__name__)
|
| 40 |
|
|
@@ -294,7 +298,7 @@ def attention_fn(
|
|
| 294 |
if not (attention_mask == 0).all():
|
| 295 |
# if auto-regressive, skip
|
| 296 |
attention_scores.masked_fill_(attention_mask, -10000.0)
|
| 297 |
-
dtype = attention_scores.
|
| 298 |
attention_scores = attention_scores.float()
|
| 299 |
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
| 300 |
|
|
@@ -814,8 +818,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 814 |
return past_key_values
|
| 815 |
|
| 816 |
@staticmethod
|
| 817 |
-
def get_masks(seq, device):
|
| 818 |
-
context_length = seq.index(
|
| 819 |
|
| 820 |
attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
|
| 821 |
attention_mask.tril_()
|
|
@@ -826,9 +830,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 826 |
return attention_mask
|
| 827 |
|
| 828 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
| 829 |
-
context_length = seq
|
| 830 |
if self.position_encoding_2d:
|
| 831 |
-
seq_length = seq.index(
|
| 832 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
| 833 |
if not gmask:
|
| 834 |
position_ids[seq_length:] = mask_position
|
|
@@ -886,14 +890,8 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 886 |
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
|
| 887 |
else:
|
| 888 |
past_key_values = tuple([None] * len(self.layers))
|
| 889 |
-
|
| 890 |
-
MASK, gMASK = 150000, 150001
|
| 891 |
-
mask_token = MASK if MASK in input_ids else gMASK
|
| 892 |
-
use_gmask = False if MASK in input_ids else gMASK
|
| 893 |
seq = input_ids[0].tolist()
|
| 894 |
|
| 895 |
-
mask_position = seq.index(mask_token)
|
| 896 |
-
|
| 897 |
if attention_mask is None:
|
| 898 |
attention_mask = self.get_masks(
|
| 899 |
seq=seq,
|
|
@@ -906,6 +904,11 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 906 |
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
|
| 907 |
|
| 908 |
if position_ids is None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 909 |
position_ids = self.get_position_ids(
|
| 910 |
seq=seq,
|
| 911 |
mask_position=mask_position,
|
|
@@ -1009,7 +1012,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1009 |
attention_mask = (attention_mask < 0.5).bool()
|
| 1010 |
|
| 1011 |
if self.position_encoding_2d:
|
| 1012 |
-
seq_length = seq.index(
|
| 1013 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
| 1014 |
if not gmask:
|
| 1015 |
position_ids[seq_length:] = mask_position
|
|
@@ -1047,7 +1050,7 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1047 |
|
| 1048 |
# only last token for input_ids if past is not None
|
| 1049 |
if past is not None or past_key_values is not None:
|
| 1050 |
-
context_length = seq.index(
|
| 1051 |
last_token = input_ids[:, -1].unsqueeze(-1)
|
| 1052 |
if self.position_encoding_2d:
|
| 1053 |
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
|
|
@@ -1155,6 +1158,21 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1155 |
for layer_past in past
|
| 1156 |
)
|
| 1157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1158 |
@torch.no_grad()
|
| 1159 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
| 1160 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
|
@@ -1175,66 +1193,139 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
|
|
| 1175 |
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
| 1176 |
input_ids = input_ids.to(self.device)
|
| 1177 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
| 1178 |
-
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0])
|
| 1179 |
response = tokenizer.decode(outputs)
|
| 1180 |
-
response =
|
| 1181 |
-
response = response.replace("[[训练时间]]", "2023年")
|
| 1182 |
history = history + [(query, response)]
|
| 1183 |
return response, history
|
| 1184 |
|
| 1185 |
@torch.no_grad()
|
| 1186 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1187 |
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1188 |
**kwargs,
|
| 1189 |
):
|
| 1190 |
-
|
| 1191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1192 |
|
| 1193 |
-
if
|
| 1194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1195 |
|
| 1196 |
-
|
|
|
|
|
|
|
| 1197 |
|
| 1198 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1199 |
|
| 1200 |
-
|
|
|
|
|
|
|
|
|
|
| 1201 |
|
|
|
|
|
|
|
| 1202 |
while True:
|
| 1203 |
-
|
| 1204 |
-
|
| 1205 |
-
|
| 1206 |
-
|
| 1207 |
-
|
| 1208 |
-
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
mask_token = MASK if MASK in output_seq else gMASK
|
| 1212 |
-
mask_position = output_seq.index(mask_token)
|
| 1213 |
-
bos_position = output_seq.index(bos)
|
| 1214 |
-
if eos in output_seq:
|
| 1215 |
-
eos_position = output_seq.index(eos)
|
| 1216 |
-
else:
|
| 1217 |
-
eos_position = len(output_seq)
|
| 1218 |
-
|
| 1219 |
-
return_seq = output_seq[:mask_position] + output_seq[bos_position + 1:eos_position] + output_seq[
|
| 1220 |
-
mask_position + 1:bos_position]
|
| 1221 |
-
max_length = max(max_length, len(return_seq))
|
| 1222 |
-
return_seqs.append(return_seq)
|
| 1223 |
-
|
| 1224 |
-
for i in range(output_ids.shape[0]):
|
| 1225 |
-
return_seqs[i] = [0] * (max_length - len(return_seqs[i])) + return_seqs[i] # padding
|
| 1226 |
-
if mask_token not in return_seqs[i]:
|
| 1227 |
-
stop = True
|
| 1228 |
-
|
| 1229 |
-
if stop:
|
| 1230 |
-
break
|
| 1231 |
|
| 1232 |
-
|
| 1233 |
-
return_seq += [bos]
|
| 1234 |
|
| 1235 |
-
|
|
|
|
|
|
|
| 1236 |
|
| 1237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1238 |
|
| 1239 |
def quantize(self, bits: int):
|
| 1240 |
from .quantization import quantize
|
|
|
|
| 3 |
import math
|
| 4 |
import copy
|
| 5 |
import os
|
| 6 |
+
import warnings
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
|
| 10 |
import torch
|
| 11 |
import torch.utils.checkpoint
|
|
|
|
| 13 |
from torch import nn
|
| 14 |
from torch.nn import CrossEntropyLoss, LayerNorm
|
| 15 |
from torch.nn.utils import skip_init
|
| 16 |
+
from typing import Optional, Tuple, Union, List, Callable
|
| 17 |
|
| 18 |
from transformers.utils import (
|
| 19 |
add_code_sample_docstrings,
|
|
|
|
| 28 |
from transformers.modeling_utils import PreTrainedModel
|
| 29 |
from transformers.utils import logging
|
| 30 |
from transformers.generation.logits_process import LogitsProcessor
|
| 31 |
+
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig
|
| 32 |
|
| 33 |
from .configuration_chatglm import ChatGLMConfig
|
| 34 |
|
| 35 |
# flags required to enable jit fusion kernels
|
| 36 |
+
|
| 37 |
+
if sys.platform != 'darwin':
|
| 38 |
+
torch._C._jit_set_profiling_mode(False)
|
| 39 |
+
torch._C._jit_set_profiling_executor(False)
|
| 40 |
+
torch._C._jit_override_can_fuse_on_cpu(True)
|
| 41 |
+
torch._C._jit_override_can_fuse_on_gpu(True)
|
| 42 |
|
| 43 |
logger = logging.get_logger(__name__)
|
| 44 |
|
|
|
|
| 298 |
if not (attention_mask == 0).all():
|
| 299 |
# if auto-regressive, skip
|
| 300 |
attention_scores.masked_fill_(attention_mask, -10000.0)
|
| 301 |
+
dtype = attention_scores.dtype
|
| 302 |
attention_scores = attention_scores.float()
|
| 303 |
attention_scores = attention_scores * query_key_layer_scaling_coeff
|
| 304 |
|
|
|
|
| 818 |
return past_key_values
|
| 819 |
|
| 820 |
@staticmethod
|
| 821 |
+
def get_masks(self, seq, device):
|
| 822 |
+
context_length = seq.index(self.config.bos_token_id) + 1
|
| 823 |
|
| 824 |
attention_mask = torch.ones((1, len(seq), len(seq)), device=device)
|
| 825 |
attention_mask.tril_()
|
|
|
|
| 830 |
return attention_mask
|
| 831 |
|
| 832 |
def get_position_ids(self, seq, mask_position, device, gmask=False):
|
| 833 |
+
context_length = len(seq)
|
| 834 |
if self.position_encoding_2d:
|
| 835 |
+
seq_length = seq.index(self.config.bos_token_id)
|
| 836 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
| 837 |
if not gmask:
|
| 838 |
position_ids[seq_length:] = mask_position
|
|
|
|
| 890 |
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
|
| 891 |
else:
|
| 892 |
past_key_values = tuple([None] * len(self.layers))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 893 |
seq = input_ids[0].tolist()
|
| 894 |
|
|
|
|
|
|
|
| 895 |
if attention_mask is None:
|
| 896 |
attention_mask = self.get_masks(
|
| 897 |
seq=seq,
|
|
|
|
| 904 |
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=3)
|
| 905 |
|
| 906 |
if position_ids is None:
|
| 907 |
+
MASK, gMASK = 150000, 150001
|
| 908 |
+
mask_token = MASK if MASK in input_ids else gMASK
|
| 909 |
+
use_gmask = False if MASK in input_ids else gMASK
|
| 910 |
+
|
| 911 |
+
mask_position = seq.index(mask_token)
|
| 912 |
position_ids = self.get_position_ids(
|
| 913 |
seq=seq,
|
| 914 |
mask_position=mask_position,
|
|
|
|
| 1012 |
attention_mask = (attention_mask < 0.5).bool()
|
| 1013 |
|
| 1014 |
if self.position_encoding_2d:
|
| 1015 |
+
seq_length = seq.index(self.config.bos_token_id)
|
| 1016 |
position_ids = torch.arange(context_length, dtype=torch.long, device=device)
|
| 1017 |
if not gmask:
|
| 1018 |
position_ids[seq_length:] = mask_position
|
|
|
|
| 1050 |
|
| 1051 |
# only last token for input_ids if past is not None
|
| 1052 |
if past is not None or past_key_values is not None:
|
| 1053 |
+
context_length = seq.index(self.config.bos_token_id)
|
| 1054 |
last_token = input_ids[:, -1].unsqueeze(-1)
|
| 1055 |
if self.position_encoding_2d:
|
| 1056 |
position_ids = torch.tensor([[[mask_position], [len(seq) - context_length]]], dtype=torch.long,
|
|
|
|
| 1158 |
for layer_past in past
|
| 1159 |
)
|
| 1160 |
|
| 1161 |
+
def process_response(self, response):
|
| 1162 |
+
response = response.strip()
|
| 1163 |
+
response = response.replace("[[训练时间]]", "2023年")
|
| 1164 |
+
punkts = [
|
| 1165 |
+
[",", ","],
|
| 1166 |
+
["!", "!"],
|
| 1167 |
+
[":", ":"],
|
| 1168 |
+
[";", ";"],
|
| 1169 |
+
["\?", "?"],
|
| 1170 |
+
]
|
| 1171 |
+
for item in punkts:
|
| 1172 |
+
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
|
| 1173 |
+
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
|
| 1174 |
+
return response
|
| 1175 |
+
|
| 1176 |
@torch.no_grad()
|
| 1177 |
def chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048, num_beams=1,
|
| 1178 |
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
|
|
|
| 1193 |
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
| 1194 |
input_ids = input_ids.to(self.device)
|
| 1195 |
outputs = self.generate(**input_ids, **gen_kwargs)
|
| 1196 |
+
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1197 |
response = tokenizer.decode(outputs)
|
| 1198 |
+
response = self.process_response(response)
|
|
|
|
| 1199 |
history = history + [(query, response)]
|
| 1200 |
return response, history
|
| 1201 |
|
| 1202 |
@torch.no_grad()
|
| 1203 |
+
def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,
|
| 1204 |
+
do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
|
| 1205 |
+
if history is None:
|
| 1206 |
+
history = []
|
| 1207 |
+
if logits_processor is None:
|
| 1208 |
+
logits_processor = LogitsProcessorList()
|
| 1209 |
+
logits_processor.append(InvalidScoreLogitsProcessor())
|
| 1210 |
+
gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
|
| 1211 |
+
"temperature": temperature, "logits_processor": logits_processor, **kwargs}
|
| 1212 |
+
if not history:
|
| 1213 |
+
prompt = query
|
| 1214 |
+
else:
|
| 1215 |
+
prompt = ""
|
| 1216 |
+
for i, (old_query, response) in enumerate(history):
|
| 1217 |
+
prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
|
| 1218 |
+
prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
|
| 1219 |
+
input_ids = tokenizer([prompt], return_tensors="pt", padding=True)
|
| 1220 |
+
input_ids = input_ids.to(self.device)
|
| 1221 |
+
for outputs in self.stream_generate(**input_ids, **gen_kwargs):
|
| 1222 |
+
outputs = outputs.tolist()[0][len(input_ids["input_ids"][0]):]
|
| 1223 |
+
response = tokenizer.decode(outputs)
|
| 1224 |
+
response = self.process_response(response)
|
| 1225 |
+
new_history = history + [(query, response)]
|
| 1226 |
+
yield response, new_history
|
| 1227 |
+
|
| 1228 |
+
@torch.no_grad()
|
| 1229 |
+
def stream_generate(
|
| 1230 |
self,
|
| 1231 |
+
input_ids,
|
| 1232 |
+
generation_config: Optional[GenerationConfig] = None,
|
| 1233 |
+
logits_processor: Optional[LogitsProcessorList] = None,
|
| 1234 |
+
stopping_criteria: Optional[StoppingCriteriaList] = None,
|
| 1235 |
+
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
|
| 1236 |
**kwargs,
|
| 1237 |
):
|
| 1238 |
+
batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
|
| 1239 |
+
|
| 1240 |
+
if generation_config is None:
|
| 1241 |
+
generation_config = self.generation_config
|
| 1242 |
+
generation_config = copy.deepcopy(generation_config)
|
| 1243 |
+
model_kwargs = generation_config.update(**kwargs)
|
| 1244 |
+
bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
|
| 1245 |
+
|
| 1246 |
+
if isinstance(eos_token_id, int):
|
| 1247 |
+
eos_token_id = [eos_token_id]
|
| 1248 |
+
|
| 1249 |
+
has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
|
| 1250 |
+
if has_default_max_length and generation_config.max_new_tokens is None:
|
| 1251 |
+
warnings.warn(
|
| 1252 |
+
f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
|
| 1253 |
+
"This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
|
| 1254 |
+
" recommend using `max_new_tokens` to control the maximum length of the generation.",
|
| 1255 |
+
UserWarning,
|
| 1256 |
+
)
|
| 1257 |
+
elif generation_config.max_new_tokens is not None:
|
| 1258 |
+
generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
|
| 1259 |
+
if not has_default_max_length:
|
| 1260 |
+
logger.warn(
|
| 1261 |
+
f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
|
| 1262 |
+
f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
|
| 1263 |
+
"Please refer to the documentation for more information. "
|
| 1264 |
+
"(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
|
| 1265 |
+
UserWarning,
|
| 1266 |
+
)
|
| 1267 |
|
| 1268 |
+
if input_ids_seq_length >= generation_config.max_length:
|
| 1269 |
+
input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
|
| 1270 |
+
logger.warning(
|
| 1271 |
+
f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
|
| 1272 |
+
f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
|
| 1273 |
+
" increasing `max_new_tokens`."
|
| 1274 |
+
)
|
| 1275 |
|
| 1276 |
+
# 2. Set generation parameters if not already defined
|
| 1277 |
+
logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
|
| 1278 |
+
stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
|
| 1279 |
|
| 1280 |
+
logits_processor = self._get_logits_processor(
|
| 1281 |
+
generation_config=generation_config,
|
| 1282 |
+
input_ids_seq_length=input_ids_seq_length,
|
| 1283 |
+
encoder_input_ids=input_ids,
|
| 1284 |
+
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
|
| 1285 |
+
logits_processor=logits_processor,
|
| 1286 |
+
)
|
| 1287 |
|
| 1288 |
+
stopping_criteria = self._get_stopping_criteria(
|
| 1289 |
+
generation_config=generation_config, stopping_criteria=stopping_criteria
|
| 1290 |
+
)
|
| 1291 |
+
logits_warper = self._get_logits_warper(generation_config)
|
| 1292 |
|
| 1293 |
+
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
| 1294 |
+
scores = None
|
| 1295 |
while True:
|
| 1296 |
+
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
|
| 1297 |
+
# forward pass to get next token
|
| 1298 |
+
outputs = self(
|
| 1299 |
+
**model_inputs,
|
| 1300 |
+
return_dict=True,
|
| 1301 |
+
output_attentions=False,
|
| 1302 |
+
output_hidden_states=False,
|
| 1303 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1304 |
|
| 1305 |
+
next_token_logits = outputs.logits[:, -1, :]
|
|
|
|
| 1306 |
|
| 1307 |
+
# pre-process distribution
|
| 1308 |
+
next_token_scores = logits_processor(input_ids, next_token_logits)
|
| 1309 |
+
next_token_scores = logits_warper(input_ids, next_token_scores)
|
| 1310 |
|
| 1311 |
+
# sample
|
| 1312 |
+
probs = nn.functional.softmax(next_token_scores, dim=-1)
|
| 1313 |
+
if generation_config.do_sample:
|
| 1314 |
+
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
| 1315 |
+
else:
|
| 1316 |
+
next_tokens = torch.argmax(probs, dim=-1)
|
| 1317 |
+
|
| 1318 |
+
# update generated ids, model inputs, and length for next step
|
| 1319 |
+
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
| 1320 |
+
model_kwargs = self._update_model_kwargs_for_generation(
|
| 1321 |
+
outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
|
| 1322 |
+
)
|
| 1323 |
+
unfinished_sequences = unfinished_sequences.mul((sum(next_tokens != i for i in eos_token_id)).long())
|
| 1324 |
+
|
| 1325 |
+
# stop when each sentence is finished, or if we exceed the maximum length
|
| 1326 |
+
if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
|
| 1327 |
+
break
|
| 1328 |
+
yield input_ids
|
| 1329 |
|
| 1330 |
def quantize(self, bits: int):
|
| 1331 |
from .quantization import quantize
|
tokenization_chatglm.py
CHANGED
|
@@ -299,7 +299,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
|
|
| 299 |
"""
|
| 300 |
if os.path.isdir(save_directory):
|
| 301 |
vocab_file = os.path.join(
|
| 302 |
-
save_directory,
|
| 303 |
)
|
| 304 |
else:
|
| 305 |
vocab_file = save_directory
|
|
|
|
| 299 |
"""
|
| 300 |
if os.path.isdir(save_directory):
|
| 301 |
vocab_file = os.path.join(
|
| 302 |
+
save_directory, self.vocab_files_names["vocab_file"]
|
| 303 |
)
|
| 304 |
else:
|
| 305 |
vocab_file = save_directory
|