teowu committed
Commit 376de5d · verified · 1 Parent(s): f289fe9

Update kimi_vl/serve/inference.py

Files changed (1):
  1. kimi_vl/serve/inference.py +32 -19
kimi_vl/serve/inference.py CHANGED

@@ -4,7 +4,7 @@ from threading import Thread
 from typing import List, Optional
 
 import torch
-import spaces
+
 from transformers import (
     AutoModelForCausalLM,
     AutoProcessor,
@@ -73,6 +73,7 @@ def preprocess(
     messages: list[dict],
     processor,
     sft_format: Optional[str] = "kimi-vl",
+    override_system_prompt = "",
 ):
     """
     Build messages from the conversations and images.
@@ -83,28 +84,38 @@ def preprocess(
 
     # get texts from conversations
     converstion = get_conv_template(sft_format)
-    # only use the last 3 round of messages
-    latest_messages = messages[-3:]
+    # only use the last 10 round of messages
+    latest_messages = messages[-10:]
+    results.append(
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": override_system_prompt if override_system_prompt else converstion.system_message,
+                }
+            ],
+        }
+    )
+    print("The actual system prompt for generation:", override_system_prompt if override_system_prompt else converstion.system_message)
     for mid, message in enumerate(latest_messages):
         if message["role"] == converstion.roles[0] or message["role"] == "user":
             record = {
                 "role": message["role"],
                 "content": [],
             }
-            if "images" in message:
+            if "timestamps" in message and "images" in message and message["timestamps"] is not None:
+                per_round_images, per_round_timestamps = message["images"], message["timestamps"]
+                for image, timestamp in zip(per_round_images, per_round_timestamps):
+                    images.append(image)
+                    record["content"].append({"type": "text", "text": f"{int(timestamp)//3600:02d}:{(int(timestamp)//60-60*(int(timestamp)//3600)):02d}:{int(timestamp)%60:02d}"})
+                    record["content"].append({"type": "image", "image": image})
+            elif "images" in message:
                 per_round_images = message["images"]
-                if len(per_round_images) > 2:
-                    per_round_images = per_round_images[-2:]
-                    print(f"Only use the last 2 images in the {mid}-th round")
-
-                images.extend(per_round_images)
                 for image in per_round_images:
-                    record["content"].append(
-                        {
-                            "type": "image",
-                            "image": image,
-                        }
-                    )
+                    images.append(image)
+                    record["content"].append({"type": "image", "image": image})
+
             if 'content' in message:
                 record["content"].append(
                     {
@@ -113,6 +124,7 @@ def preprocess(
                     }
                 )
             results.append(record)
+
         elif message["role"] == converstion.roles[1] or message["role"] == "assistant":
             formatted_answer = message["content"].strip()
             # ◁think▷用户说了“你好”,这是一个非常简单的问候,通常用于开启对话。我需要判断用户的意图。可能性一:用户只是礼貌性地打招呼,想要开启一段对话;可能性二:用户可能有更具体的需求,比如询问我的功能、功能或者需要帮助。由于用户没有提供更多信息,我需要保持开放,同时引导用户进一步说明他们的需求。
@@ -137,7 +149,7 @@ def preprocess(
                 formatted_answer.count(processor.image_token) == 0
             ), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
             converstion.append_message(converstion.roles[1], formatted_answer)
-
+
     text = processor.apply_chat_template(results, add_generation_prompt=True)
     print(f"raw text = {text}")
    if len(images) == 0:
@@ -153,11 +165,13 @@ def preprocess(
     return inputs
 
 
-
+@torch.no_grad()
+@torch.inference_mode()
 def kimi_vl_generate(
     model: torch.nn.Module,
     processor: AutoProcessor,
     conversations: list[Conversation],
+    override_system_prompt,
     stop_words: list,
     max_length: int = 256,
     temperature: float = 1.0,
@@ -166,7 +180,7 @@ def kimi_vl_generate(
 ):
     # convert conversation to inputs
     print(f"conversations = {conversations}")
-    inputs = preprocess(conversations, processor=processor)
+    inputs = preprocess(conversations, processor=processor, override_system_prompt=override_system_prompt)
     inputs = inputs.to(model.device)
 
     return generate(
@@ -180,7 +194,6 @@ def kimi_vl_generate(
         chunk_size=chunk_size,
     )
 
-
 def generate(
     model,
     processor,
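
The densest added line is the per-image timestamp formatting in preprocess, which packs a seconds-to-HH:MM:SS conversion into a single f-string. As an aside (not part of the commit), the same arithmetic can be written with divmod; the helper name below is hypothetical and only illustrates what the inline expression computes:

def _format_timestamp(seconds: float) -> str:
    # Equivalent to the inlined f-string in preprocess:
    #   hours   = s // 3600
    #   minutes = s // 60 - 60 * (s // 3600)   (i.e. (s % 3600) // 60)
    #   secs    = s % 60
    s = int(seconds)
    hours, rem = divmod(s, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

# _format_timestamp(65) -> "00:01:05", matching the inline expression for 65 seconds.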
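
For context, a minimal sketch of how a caller might exercise the updated signatures after this change: a user turn carrying per-image timestamps plus an override_system_prompt threaded through kimi_vl_generate into preprocess. Everything outside the call is assumed rather than taken from the commit: model, processor, and the frame images are expected to be loaded elsewhere (e.g. by the serve app), and the stop words and prompt strings are placeholders.

conversation = [
    {
        "role": "user",
        "content": "What changes between these two frames?",
        "images": [frame_a, frame_b],   # hypothetical, already-loaded PIL images
        "timestamps": [5, 65],          # seconds; preprocess renders them as 00:00:05 and 00:01:05
    },
]

outputs = kimi_vl_generate(
    model=model,                        # assumed: model loaded by the serve app
    processor=processor,                # assumed: AutoProcessor loaded by the serve app
    conversations=conversation,
    override_system_prompt="You are a concise video-analysis assistant.",  # example override
    stop_words=stop_words,              # assumed: stop words defined by the serve app
    max_length=512,
    temperature=0.6,
)

When override_system_prompt is an empty string, preprocess falls back to the conversation template's system_message, as the added truthiness check in the system-message block shows.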