Update kimi_vl/serve/inference.py
kimi_vl/serve/inference.py CHANGED (+32 -19)
@@ -4,7 +4,7 @@ from threading import Thread
 from typing import List, Optional
 
 import torch
-
+
 from transformers import (
     AutoModelForCausalLM,
     AutoProcessor,
@@ -73,6 +73,7 @@ def preprocess(
     messages: list[dict],
     processor,
     sft_format: Optional[str] = "kimi-vl",
+    override_system_prompt = "",
 ):
     """
     Build messages from the conversations and images.
@@ -83,28 +84,38 @@ def preprocess(
 
     # get texts from conversations
     converstion = get_conv_template(sft_format)
-    # only use the last
-    latest_messages = messages[-
+    # only use the last 10 round of messages
+    latest_messages = messages[-10:]
+    results.append(
+        {
+            "role": "system",
+            "content": [
+                {
+                    "type": "text",
+                    "text": override_system_prompt if override_system_prompt else converstion.system_message,
+                }
+            ],
+        }
+    )
+    print("The actual system prompt for generation:", override_system_prompt if override_system_prompt else converstion.system_message)
     for mid, message in enumerate(latest_messages):
         if message["role"] == converstion.roles[0] or message["role"] == "user":
             record = {
                 "role": message["role"],
                 "content": [],
             }
-            if "images" in message:
+            if "timestamps" in message and "images" in message and message["timestamps"] is not None:
+                per_round_images, per_round_timestamps = message["images"], message["timestamps"]
+                for image, timestamp in zip(per_round_images, per_round_timestamps):
+                    images.append(image)
+                    record["content"].append({"type": "text", "text": f"{int(timestamp)//3600:02d}:{(int(timestamp)//60-60*(int(timestamp)//3600)):02d}:{int(timestamp)%60:02d}"})
+                    record["content"].append({"type": "image", "image": image})
+            elif "images" in message:
                 per_round_images = message["images"]
-                if len(per_round_images) > 2:
-                    per_round_images = per_round_images[-2:]
-                    print(f"Only use the last 2 images in the {mid}-th round")
-
-                images.extend(per_round_images)
                 for image in per_round_images:
-                    record["content"].append(
-                        {
-                            "type": "image",
-                            "image": image,
-                        }
-                    )
+                    images.append(image)
+                    record["content"].append({"type": "image", "image": image})
+
             if 'content' in message:
                 record["content"].append(
                     {
@@ -113,6 +124,7 @@ def preprocess(
                     }
                 )
            results.append(record)
+
        elif message["role"] == converstion.roles[1] or message["role"] == "assistant":
            formatted_answer = message["content"].strip()
            # ◁think▷The user said "Hello", a very simple greeting that usually opens a conversation. I need to judge the user's intent. Possibility one: the user is just greeting politely and wants to start a conversation; possibility two: the user may have a more specific need, such as asking about my capabilities or needing help. Since the user has not provided more information, I should stay open while guiding them to clarify what they need.
@@ -137,7 +149,7 @@ def preprocess(
                formatted_answer.count(processor.image_token) == 0
            ), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
            converstion.append_message(converstion.roles[1], formatted_answer)
-
+
    text = processor.apply_chat_template(results, add_generation_prompt=True)
    print(f"raw text = {text}")
    if len(images) == 0:
@@ -153,11 +165,13 @@ def preprocess(
     return inputs
 
 
-
+@torch.no_grad()
+@torch.inference_mode()
 def kimi_vl_generate(
     model: torch.nn.Module,
     processor: AutoProcessor,
     conversations: list[Conversation],
+    override_system_prompt,
     stop_words: list,
     max_length: int = 256,
     temperature: float = 1.0,
@@ -166,7 +180,7 @@ def kimi_vl_generate(
 ):
     # convert conversation to inputs
     print(f"conversations = {conversations}")
-    inputs = preprocess(conversations, processor=processor)
+    inputs = preprocess(conversations, processor=processor, override_system_prompt=override_system_prompt)
     inputs = inputs.to(model.device)
 
     return generate(
@@ -180,7 +194,6 @@ def kimi_vl_generate(
         chunk_size=chunk_size,
     )
 
-
 def generate(
     model,
     processor,
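For reference, the commit threads a new override_system_prompt argument from kimi_vl_generate() into preprocess(); an empty string falls back to the conversation template's default system message, and since the argument sits before stop_words with no default, existing call sites have to be updated. It also renders each entry of an optional per-message "timestamps" list as HH:MM:SS text placed before its image. The snippet below is a minimal sketch, not part of the commit, that reproduces that timestamp formatting with divmod and shows an assumed message layout; the helper name format_timestamp and the file names are hypothetical.

def format_timestamp(seconds: float) -> str:
    # Same result as the f-string added in preprocess(): zero-padded HH:MM:SS.
    t = int(seconds)
    hours, remainder = divmod(t, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"

# Assumed shape of a user turn carrying two video frames sampled at 5 s and 65 s;
# only the "role", "content", "images", and "timestamps" keys are read by preprocess().
example_message = {
    "role": "user",
    "content": "What changes between these two frames?",
    "images": ["frame_0005.png", "frame_0065.png"],  # hypothetical file names
    "timestamps": [5, 65],
}

assert format_timestamp(65) == "00:01:05"
assert format_timestamp(3725) == "01:02:05"

The generation entry point is additionally wrapped in @torch.no_grad() and @torch.inference_mode(), so no gradient state is tracked while the Space runs inference.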