Spaces:
Runtime error
Runtime error
Update kimi_vl/serve/inference.py
Browse files- kimi_vl/serve/inference.py +32 -19
kimi_vl/serve/inference.py
CHANGED
|
@@ -4,7 +4,7 @@ from threading import Thread
|
|
| 4 |
from typing import List, Optional
|
| 5 |
|
| 6 |
import torch
|
| 7 |
-
|
| 8 |
from transformers import (
|
| 9 |
AutoModelForCausalLM,
|
| 10 |
AutoProcessor,
|
|
@@ -73,6 +73,7 @@ def preprocess(
|
|
| 73 |
messages: list[dict],
|
| 74 |
processor,
|
| 75 |
sft_format: Optional[str] = "kimi-vl",
|
|
|
|
| 76 |
):
|
| 77 |
"""
|
| 78 |
Build messages from the conversations and images.
|
|
@@ -83,28 +84,38 @@ def preprocess(
|
|
| 83 |
|
| 84 |
# get texts from conversations
|
| 85 |
converstion = get_conv_template(sft_format)
|
| 86 |
-
# only use the last
|
| 87 |
-
latest_messages = messages[-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
for mid, message in enumerate(latest_messages):
|
| 89 |
if message["role"] == converstion.roles[0] or message["role"] == "user":
|
| 90 |
record = {
|
| 91 |
"role": message["role"],
|
| 92 |
"content": [],
|
| 93 |
}
|
| 94 |
-
if "images" in message:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
per_round_images = message["images"]
|
| 96 |
-
if len(per_round_images) > 2:
|
| 97 |
-
per_round_images = per_round_images[-2:]
|
| 98 |
-
print(f"Only use the last 2 images in the {mid}-th round")
|
| 99 |
-
|
| 100 |
-
images.extend(per_round_images)
|
| 101 |
for image in per_round_images:
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
"image": image,
|
| 106 |
-
}
|
| 107 |
-
)
|
| 108 |
if 'content' in message:
|
| 109 |
record["content"].append(
|
| 110 |
{
|
|
@@ -113,6 +124,7 @@ def preprocess(
|
|
| 113 |
}
|
| 114 |
)
|
| 115 |
results.append(record)
|
|
|
|
| 116 |
elif message["role"] == converstion.roles[1] or message["role"] == "assistant":
|
| 117 |
formatted_answer = message["content"].strip()
|
| 118 |
# ◁think▷用户说了“你好”,这是一个非常简单的问候,通常用于开启对话。我需要判断用户的意图。可能性一:用户只是礼貌性地打招呼,想要开启一段对话;可能性二:用户可能有更具体的需求,比如询问我的功能、功能或者需要帮助。由于用户没有提供更多信息,我需要保持开放,同时引导用户进一步说明他们的需求。
|
|
@@ -137,7 +149,7 @@ def preprocess(
|
|
| 137 |
formatted_answer.count(processor.image_token) == 0
|
| 138 |
), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
|
| 139 |
converstion.append_message(converstion.roles[1], formatted_answer)
|
| 140 |
-
|
| 141 |
text = processor.apply_chat_template(results, add_generation_prompt=True)
|
| 142 |
print(f"raw text = {text}")
|
| 143 |
if len(images) == 0:
|
|
@@ -153,11 +165,13 @@ def preprocess(
|
|
| 153 |
return inputs
|
| 154 |
|
| 155 |
|
| 156 |
-
|
|
|
|
| 157 |
def kimi_vl_generate(
|
| 158 |
model: torch.nn.Module,
|
| 159 |
processor: AutoProcessor,
|
| 160 |
conversations: list[Conversation],
|
|
|
|
| 161 |
stop_words: list,
|
| 162 |
max_length: int = 256,
|
| 163 |
temperature: float = 1.0,
|
|
@@ -166,7 +180,7 @@ def kimi_vl_generate(
|
|
| 166 |
):
|
| 167 |
# convert conversation to inputs
|
| 168 |
print(f"conversations = {conversations}")
|
| 169 |
-
inputs = preprocess(conversations, processor=processor)
|
| 170 |
inputs = inputs.to(model.device)
|
| 171 |
|
| 172 |
return generate(
|
|
@@ -180,7 +194,6 @@ def kimi_vl_generate(
|
|
| 180 |
chunk_size=chunk_size,
|
| 181 |
)
|
| 182 |
|
| 183 |
-
|
| 184 |
def generate(
|
| 185 |
model,
|
| 186 |
processor,
|
|
|
|
| 4 |
from typing import List, Optional
|
| 5 |
|
| 6 |
import torch
|
| 7 |
+
|
| 8 |
from transformers import (
|
| 9 |
AutoModelForCausalLM,
|
| 10 |
AutoProcessor,
|
|
|
|
| 73 |
messages: list[dict],
|
| 74 |
processor,
|
| 75 |
sft_format: Optional[str] = "kimi-vl",
|
| 76 |
+
override_system_prompt = "",
|
| 77 |
):
|
| 78 |
"""
|
| 79 |
Build messages from the conversations and images.
|
|
|
|
| 84 |
|
| 85 |
# get texts from conversations
|
| 86 |
converstion = get_conv_template(sft_format)
|
| 87 |
+
# only keep the last 10 messages of the conversation
|
| 88 |
+
latest_messages = messages[-10:]
|
| 89 |
+
results.append(
|
| 90 |
+
{
|
| 91 |
+
"role": "system",
|
| 92 |
+
"content": [
|
| 93 |
+
{
|
| 94 |
+
"type": "text",
|
| 95 |
+
"text": override_system_prompt if override_system_prompt else converstion.system_message,
|
| 96 |
+
}
|
| 97 |
+
],
|
| 98 |
+
}
|
| 99 |
+
)
|
| 100 |
+
print("The actual system prompt for generation:", override_system_prompt if override_system_prompt else converstion.system_message)
|
| 101 |
for mid, message in enumerate(latest_messages):
|
| 102 |
if message["role"] == converstion.roles[0] or message["role"] == "user":
|
| 103 |
record = {
|
| 104 |
"role": message["role"],
|
| 105 |
"content": [],
|
| 106 |
}
|
| 107 |
+
if "timestamps" in message and "images" in message and message["timestamps"] is not None:
|
| 108 |
+
per_round_images, per_round_timestamps = message["images"], message["timestamps"]
|
| 109 |
+
for image, timestamp in zip(per_round_images, per_round_timestamps):
|
| 110 |
+
images.append(image)
|
| 111 |
+
record["content"].append({"type": "text", "text": f"{int(timestamp)//3600:02d}:{(int(timestamp)//60-60*(int(timestamp)//3600)):02d}:{int(timestamp)%60:02d}"})
|
| 112 |
+
record["content"].append({"type": "image", "image": image})
|
| 113 |
+
elif "images" in message:
|
| 114 |
per_round_images = message["images"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
for image in per_round_images:
|
| 116 |
+
images.append(image)
|
| 117 |
+
record["content"].append({"type": "image", "image": image})
|
| 118 |
+
|
|
|
|
|
|
|
|
|
|
| 119 |
if 'content' in message:
|
| 120 |
record["content"].append(
|
| 121 |
{
|
|
|
|
| 124 |
}
|
| 125 |
)
|
| 126 |
results.append(record)
|
| 127 |
+
|
| 128 |
elif message["role"] == converstion.roles[1] or message["role"] == "assistant":
|
| 129 |
formatted_answer = message["content"].strip()
|
| 130 |
# ◁think▷用户说了“你好”,这是一个非常简单的问候,通常用于开启对话。我需要判断用户的意图。可能性一:用户只是礼貌性地打招呼,想要开启一段对话;可能性二:用户可能有更具体的需求,比如询问我的功能、功能或者需要帮助。由于用户没有提供更多信息,我需要保持开放,同时引导用户进一步说明他们的需求。
|
|
|
|
| 149 |
formatted_answer.count(processor.image_token) == 0
|
| 150 |
), f"there should be no {processor.image_token} in the assistant's reply, but got {messages}"
|
| 151 |
converstion.append_message(converstion.roles[1], formatted_answer)
|
| 152 |
+
|
| 153 |
text = processor.apply_chat_template(results, add_generation_prompt=True)
|
| 154 |
print(f"raw text = {text}")
|
| 155 |
if len(images) == 0:
|
|
|
|
| 165 |
return inputs
|
| 166 |
|
| 167 |
|
| 168 |
+
@torch.no_grad()
|
| 169 |
+
@torch.inference_mode()
|
| 170 |
def kimi_vl_generate(
|
| 171 |
model: torch.nn.Module,
|
| 172 |
processor: AutoProcessor,
|
| 173 |
conversations: list[Conversation],
|
| 174 |
+
override_system_prompt,
|
| 175 |
stop_words: list,
|
| 176 |
max_length: int = 256,
|
| 177 |
temperature: float = 1.0,
|
|
|
|
| 180 |
):
|
| 181 |
# convert conversation to inputs
|
| 182 |
print(f"conversations = {conversations}")
|
| 183 |
+
inputs = preprocess(conversations, processor=processor, override_system_prompt=override_system_prompt)
|
| 184 |
inputs = inputs.to(model.device)
|
| 185 |
|
| 186 |
return generate(
|
|
|
|
| 194 |
chunk_size=chunk_size,
|
| 195 |
)
|
| 196 |
|
|
|
|
| 197 |
def generate(
|
| 198 |
model,
|
| 199 |
processor,
|