Support batched chat
Add a chat_batch method that can run several multi-turn conversations at once.

- modeling_chatglm.py +30 -13
modeling_chatglm.py
CHANGED
@@ -721,7 +721,6 @@ CHATGLM_6B_START_DOCSTRING = r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) sub-class.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
    usage and behavior.
-
    Parameters:
        config ([`~ChatGLM6BConfig`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.

@@ -732,37 +731,28 @@ CHATGLM_6B_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `({0})`):
            Indices of input sequence tokens in the vocabulary.
-
            Indices can be obtained using [`ChatGLM6BTokenizer`].
            See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
-
            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-
            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
-
            [What are attention masks?](../glossary#attention-mask)
        token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, 1]`:
-
            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.
-
            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range `[0, config.max_position_embeddings - 1]`.
-
            [What are position IDs?](../glossary#position-ids)
        head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
-
            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
-
        inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert *input_ids* indices into associated vectors

@@ -784,13 +774,11 @@ CHATGLM_6B_INPUTS_DOCSTRING = r"""
)
class ChatGLMModel(ChatGLMPreTrainedModel):
    """
-
    The model can behave as an encoder (with only self-attention) as well
    as a decoder, in which case a layer of cross-attention is added between
    the self-attention layers, following the architecture described in [Attention is
    all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani,
    Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
-
    To behave as a decoder the model needs to be initialized with the
    `is_decoder` argument of the configuration set to `True`.
    To be used in a Seq2Seq model, the model needs to be initialized with both `is_decoder`

@@ -1237,7 +1225,6 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
        This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
        [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
-
        Output shares the same memory storage as `past`.
        """
        return tuple(

@@ -1288,6 +1275,36 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
        response = self.process_response(response)
        history = history + [(query, response)]
        return response, history
+
+    @torch.no_grad()
+    def chat_batch(self, tokenizer, querys: List[str], historys: List[List[Tuple[str, str]]], max_length: int = 2048,
+                   num_beams=1, do_sample=True, top_p=0.7, temperature=0.95, logits_processor=None, **kwargs):
+        if logits_processor is None:
+            logits_processor = LogitsProcessorList()
+        logits_processor.append(InvalidScoreLogitsProcessor())
+        gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
+                      "temperature": temperature, "logits_processor": logits_processor, **kwargs}
+        prompts = []
+        for query, history in zip(querys, historys):
+            if history is None:
+                history = []
+            if not history:
+                prompt = query
+            else:
+                prompt = ""
+                for i, (old_query, response) in enumerate(history):
+                    prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response)
+                prompt += "[Round {}]\n问:{}\n答:".format(len(history), query)
+            prompts.append(prompt)
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True)
+        inputs = inputs.to(self.device)
+        outputs = self.generate(**inputs, **gen_kwargs)
+        outputs = outputs.tolist()
+        # Every padded row shares one prompt length, so slice it off uniformly.
+        outputs = [x[len(inputs["input_ids"][0]):] for x in outputs]
+        responses = [tokenizer.decode(output) for output in outputs]
+        responses = [self.process_response(response) for response in responses]
+        return responses

    @torch.no_grad()
    def stream_chat(self, tokenizer, query: str, history: List[Tuple[str, str]] = None, max_length: int = 2048,