upload model folder to repo
- .gitattributes +4 -0
- README.md +27 -26
- deploy_align_ds_v.sh +25 -0
- examples/PKU.jpg +0 -0
- examples/baby.mp4 +3 -0
- examples/boya.jpg +0 -0
- examples/drum.wav +3 -0
- examples/laugh.wav +3 -0
- examples/logo.jpg +0 -0
- examples/scream.wav +3 -0
- multi_image_inference.py +206 -0
- stream_inference.py +210 -0
.gitattributes
CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/baby.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/drum.wav filter=lfs diff=lfs merge=lfs -text
+examples/laugh.wav filter=lfs diff=lfs merge=lfs -text
+examples/scream.wav filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,26 +1,27 @@
# Deployment Scripts for Align-DS-V (Built with Gradio)

This document provides instructions for deploying the Align-DS-V model for inference with Gradio.

1. **Set up the Conda environment:** Follow the instructions in the [PKU-Alignment/align-anything](https://github.com/PKU-Alignment/align-anything) repository to configure your Conda environment.
2. **Configure the model path:** After setting up the environment, update the `BASE_MODEL_PATH` variable in `deploy_align_ds_v.sh` to point to your local Align-DS-V model directory.
3. **Verify the inference script parameters:** Check the following three parameters in both `multi_image_inference.py` and `stream_inference.py`:
   ```python
   openai_api_key = "pku"  # or your specific API key, if you changed --api-key in deploy_align_ds_v.sh
   openai_api_base = "http://0.0.0.0:8231/v1"  # must match the deployment port
   # NOTE: set this to the same model path used as BASE_MODEL_PATH in deploy_align_ds_v.sh
   model = ''
   ```
   These scripts use an OpenAI-compatible server: `deploy_align_ds_v.sh` launches the Align-DS-V model locally with vLLM and exposes it on port 8231 at the API base URL above (a minimal request sketch follows this README).
4. **Run inference:**
   * **Streamed output:**
     ```bash
     bash deploy_align_ds_v.sh
     python stream_inference.py
     ```
   * **Multi-image inference:**
     ```bash
     bash deploy_align_ds_v.sh
     python multi_image_inference.py
     ```
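For reference, here is a minimal client-side sketch of the kind of request both inference scripts send to the deployed server. It assumes `deploy_align_ds_v.sh` is already running with its defaults (API key `pku`, port 8231); `MODEL_PATH` and `IMAGE_PATH` are placeholders to replace with your own served model path and a local image.

```python
import base64

from openai import OpenAI

# Placeholders -- point these at your own served model path and a local image.
MODEL_PATH = "/path/to/Align-DS-V"
IMAGE_PATH = "examples/PKU.jpg"

# Same credentials and endpoint as in multi_image_inference.py / stream_inference.py.
client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")

with open(IMAGE_PATH, "rb") as f:
    image_base64 = base64.b64encode(f.read()).decode("utf-8")

response = client.chat.completions.create(
    model=MODEL_PATH,  # must match the path passed to `vllm serve`
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
                {"type": "text", "text": "图片中有什么?"},
            ],
        }
    ],
)
print(response.choices[0].message.content)
```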
deploy_align_ds_v.sh
ADDED
@@ -0,0 +1,25 @@

# NOTE: replace with your own model path
export BASE_MODEL_PATH=''
export BASE_PORT=8231
echo $BASE_MODEL_PATH
echo $BASE_PORT

lsof -i :$BASE_PORT

# Kill any process already listening on the port
kill -9 $(lsof -t -i:$BASE_PORT)

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve $BASE_MODEL_PATH --host 0.0.0.0 --port $BASE_PORT --max-model-len 12000 --tensor-parallel-size 8 --api-key pku --trust-remote-code --dtype auto --enforce-eager --swap-space 1 --limit-mm-per-prompt "image=6"

# NOTE: --limit-mm-per-prompt must be set to allow multiple images per prompt (6 here)

echo 'Base Port:' $BASE_PORT

lsof -i :$BASE_PORT

# Kill the process listening on the port once the server has exited
kill -9 $(lsof -t -i:$BASE_PORT)

# CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 vllm serve /aifs4su/yaodong/spring_r1_model/QVQ-72B-Preview --enable-reasoning --reasoning-parser deepseek_r1 --host 0.0.0.0 --port 8009 --max-model-len 12000 --tensor-parallel-size 8 --api-key jiayi
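Before starting either Gradio client, it can be worth confirming that the vLLM server launched above is actually up and serving the model. A minimal readiness check, assuming the defaults from `deploy_align_ds_v.sh` (port 8231, API key `pku`):

```python
from openai import OpenAI

# Assumes deploy_align_ds_v.sh is running with --port 8231 and --api-key pku.
client = OpenAI(api_key="pku", base_url="http://0.0.0.0:8231/v1")

# vLLM's OpenAI-compatible server reports the served model, named after its path by default.
for served_model in client.models.list():
    print(served_model.id)
```

The id printed here is the value the `model` variable in the two inference scripts should be set to.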
examples/PKU.jpg
ADDED

examples/baby.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da6126bce64c64a3d6f7ce889fbe15b5f1c2e3f978846351d8c7a79a950b429e
size 463547

examples/boya.jpg
ADDED

examples/drum.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a4376821dc498cc34a24df8a4eafebc470f721caacb78305c9a6c596d8f79510
size 170882

examples/laugh.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95ee91ae63342a3122a77a12ee08ec52ac6dbd5b9be870a2e2951f648b4da528
size 566798

examples/logo.jpg
ADDED

examples/scream.wav
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba023ea3c16eede8c4925960b3f328df3a43dbdeab8f4c0f51fc63d91199d0ec
size 410266
multi_image_inference.py
ADDED
@@ -0,0 +1,206 @@
# Copyright 2025 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""


import argparse
import os
from openai import OpenAI
import gradio as gr
import base64
import json
import random

random.seed(42)

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


SYSTEM_PROMPT = "你是一个具有帮助性的人工智能助手,你能够回答用户的问题,并且能够根据用户的问题提供帮助。你是由北大对齐小组(PKU-Alignment)开发的智能助手 Align-DS-V 基于DeepSeek-R1模型训练。"

openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"

# NOTE: replace with your own model path (the BASE_MODEL_PATH served by deploy_align_ds_v.sh)
model = ''


def encode_base64_content_from_local_file(content_url: str) -> str:
    """Encode content read from a local file to base64 format."""
    with open(content_url, 'rb') as file:
        result = base64.b64encode(file.read()).decode('utf-8')
    return result


IMAGE_EXAMPLES = [
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/PKU.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg'),
        ],
        'text': '比较这两张图片的异同',
    },
    {
        'files': [
            os.path.join(CURRENT_DIR, 'examples/boya.jpg'),
            os.path.join(CURRENT_DIR, 'examples/logo.jpg'),
        ],
        'text': '这些图片有什么共同主题?',
    },
]

AUDIO_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/drum.wav')],
        'text': 'What is the emotion of this drumbeat like?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/laugh.wav')],
        'text': 'Is this laughter evil, and why?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/scream.wav')],
        'text': 'What is the main event of this scream?',
    },
]

VIDEO_EXAMPLES = [
    {'files': [os.path.join(CURRENT_DIR, 'examples/baby.mp4')], 'text': 'What is the video about?'},
]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def text_conversation(text: str, role: str = 'user'):
    return [{'role': role, 'content': text}]


def image_conversation(image_base64_list: list, text: str = None):
    # Pack every image into a single user message, followed by the text prompt.
    content = []
    for image_base64 in image_base64_list:
        content.append({
            'type': 'image_url',
            'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"},
        })
    content.append({'type': 'text', 'text': text})
    return [{'role': 'user', 'content': content}]


def question_answering(message: dict, history: list, file):
    # NOTE 2: Gradio uploads multiple images via the extra File input; the preprocessing
    # below merges them into the current message before building the conversation.
    # print('history:', history)
    # print('file:', file)
    message['files'] = file if file is not None else []
    # print('message:', message)
    multi_modal_info = []
    conversation = text_conversation(SYSTEM_PROMPT)
    # NOTE: replay the chat history into an OpenAI-style conversation
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
        elif isinstance(past_message, dict):
            if past_message['role'] == 'user':
                if isinstance(past_message['content'], str):
                    text = past_message['content']
                    if i + 1 < len(history) and isinstance(history[i + 1]['content'], tuple):
                        raw_images = history[i + 1]['content']
                        image_base64_list = []
                        if isinstance(raw_images, str):
                            image_base64 = encode_base64_content_from_local_file(raw_images)
                            image_base64_list.append(image_base64)
                        elif isinstance(raw_images, tuple):
                            # NOTE: decode the uploaded images to base64 one by one
                            for image in raw_images:
                                image_base64 = encode_base64_content_from_local_file(image)
                                image_base64_list.append(image_base64)
                        multi_modal_info.extend(image_base64_list)
                        conversation.extend(image_conversation(image_base64_list, text))
                    elif i - 1 >= 0 and isinstance(history[i - 1]['content'], tuple):
                        raw_images = history[i - 1]['content']
                        image_base64_list = []
                        if isinstance(raw_images, str):
                            image_base64 = encode_base64_content_from_local_file(raw_images)
                            image_base64_list.append(image_base64)
                        elif isinstance(raw_images, tuple):
                            # NOTE: decode the uploaded images to base64 one by one
                            for image in raw_images:
                                image_base64 = encode_base64_content_from_local_file(image)
                                image_base64_list.append(image_base64)
                        multi_modal_info.extend(image_base64_list)
                        conversation.extend(image_conversation(image_base64_list, text))
                    else:
                        conversation.extend(text_conversation(past_message['content'], 'user'))
            elif past_message['role'] == 'assistant':
                conversation.extend(text_conversation(past_message['content'], 'assistant'))

    if len(message['files']) == 0:
        current_question = message['text']
        conversation.extend(text_conversation(current_question))
    else:
        current_question = message['text']
        current_multi_modal_info = message['files']
        image_base64_list = []
        for file in current_multi_modal_info:
            image_base64 = encode_base64_content_from_local_file(file)
            image_base64_list.append(image_base64)
        multi_modal_info.extend(image_base64_list)
        conversation.extend(image_conversation(image_base64_list, current_question))

    # print('Conversation:', conversation)
    # NOTE 1: the OpenAI client supports multiple images in a single request
    outputs = client.chat.completions.create(
        model=model,
        stream=False,
        messages=conversation,
    )

    # Extract the predicted answer and split off the reasoning, if present
    answer = outputs.choices[0].message.content
    if "**Final Answer**" in answer:
        reasoning_content, final_answer = answer.split("**Final Answer**", 1)
        if len(reasoning_content) > 5:
            answer = f"""🤔 思考过程:\n```bash{reasoning_content}\n```\n✨ 最终答案:\n{final_answer}"""

    return answer


if __name__ == '__main__':
    # Define the Gradio interface
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES

    with gr.Blocks() as demo:
        # upload_button = gr.UploadButton(render=False)

        multiple_files = gr.File(file_count="multiple")
        gr.ChatInterface(
            fn=question_answering,
            additional_inputs=[multiple_files],
            type='messages',
            multimodal=True,
            title='Align-DS-V Reasoning CLI',
            description='Better life with Stronger Align-DS-V.',
            # examples=examples,
            theme=gr.themes.Ocean(
                text_size='lg',
                spacing_size='lg',
                radius_size='lg',
            ),
        )

    demo.launch(share=True)
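To make the multi-image path above concrete, this is roughly the conversation payload that `question_answering` assembles for one user turn with two attached files (the base64 strings and the question are illustrative placeholders; note the system prompt is sent as a user turn by `text_conversation`). Since `deploy_align_ds_v.sh` passes `--limit-mm-per-prompt "image=6"` to vLLM, a single request carrying more than six images will be rejected by the server.

```python
# Illustrative shape of the conversation built by text_conversation + image_conversation;
# <BASE64_1>/<BASE64_2> stand in for the real base64-encoded file contents.
conversation = [
    {"role": "user", "content": SYSTEM_PROMPT},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64_1>"}},
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<BASE64_2>"}},
            {"type": "text", "text": "比较这两张图片的异同"},
        ],
    },
]
```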
stream_inference.py
ADDED
@@ -0,0 +1,210 @@
# Copyright 2024 PKU-Alignment Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Command line interface for interacting with a multi-modal model."""


import argparse
import os
from openai import OpenAI
import gradio as gr
import base64

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


SYSTEM_PROMPT = """\
你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医导大模型(Medguide),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**(如 X 光、CT、MRI、超声波、病理切片)以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。

**能力范围**
1. 医学问答:
   - 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。
   - 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。
2. 影像分析:
   - 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。
   - 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。
3. 数据解析:
   - 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。
   - 你的任务是解释数据的含义,并提供可能的医学解读。
4. 多模态能力:
   - 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。
   - 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。

**回答风格**
- 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。
- 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。
- 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。
"""

openai_api_key = "pku"
openai_api_base = "http://0.0.0.0:8231/v1"

# NOTE: replace with your own model path (the BASE_MODEL_PATH served by deploy_align_ds_v.sh)
model = ""


def encode_base64_content_from_local_file(content_url: str) -> str:
    """Encode content read from a local file to base64 format."""
    with open(content_url, 'rb') as file:
        result = base64.b64encode(file.read()).decode('utf-8')
    return result


IMAGE_EXAMPLES = [
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/PKU.jpg')],
        'text': '图中的地点在哪里?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/logo.jpg')],
        'text': '图片中有什么?',
    },
    {
        'files': [os.path.join(CURRENT_DIR, 'examples/cough.png')],
        'text': '这张图片展示了什么?',
    },
]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def text_conversation(text: str, role: str = 'user'):
    return [{'role': role, 'content': text.replace('[begin of think]', '<think>').replace('[end of think]', '</think>')}]


def image_conversation(image_base64: str, text: str = None):
    return [
        {
            'role': 'user',
            'content': [
                {'type': 'image_url', 'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
                {'type': 'text', 'text': text},
            ],
        }
    ]


def question_answering(message: dict, history: list):
    multi_modal_info = []
    conversation = text_conversation(SYSTEM_PROMPT)
    # Replay the chat history into an OpenAI-style conversation
    for i, past_message in enumerate(history):
        if isinstance(past_message, str):
            conversation.extend(text_conversation(past_message))
        elif isinstance(past_message, dict):
            if past_message['role'] == 'user':
                if isinstance(past_message['content'], str):
                    text = past_message['content']
                    if i + 1 < len(history) and isinstance(history[i + 1]['content'], tuple):
                        raw_image = history[i + 1]['content']
                        if isinstance(raw_image, str):
                            image_base64 = encode_base64_content_from_local_file(raw_image)
                            multi_modal_info.append(image_base64)
                            conversation.extend(image_conversation(image_base64, text))
                        elif isinstance(raw_image, tuple):
                            for image in raw_image:
                                image_base64 = encode_base64_content_from_local_file(image)
                                multi_modal_info.append(image_base64)
                                conversation.extend(image_conversation(image_base64, text))
                    elif i - 1 >= 0 and isinstance(history[i - 1]['content'], tuple):
                        raw_image = history[i - 1]['content']
                        if isinstance(raw_image, str):
                            image_base64 = encode_base64_content_from_local_file(raw_image)
                            multi_modal_info.append(image_base64)
                            conversation.extend(image_conversation(image_base64, text))
                        elif isinstance(raw_image, tuple):
                            for image in raw_image:
                                image_base64 = encode_base64_content_from_local_file(image)
                                multi_modal_info.append(image_base64)
                                conversation.extend(image_conversation(image_base64, text))
                    else:
                        conversation.extend(text_conversation(past_message['content'], 'user'))
            elif past_message['role'] == 'assistant':
                conversation.extend(text_conversation(past_message['content'], 'assistant'))

    if len(message['files']) == 0:
        current_question = message['text']
        conversation.extend(text_conversation(current_question))
    else:
        current_question = message['text']
        current_multi_modal_info = message['files']
        for file in current_multi_modal_info:
            image_base64 = encode_base64_content_from_local_file(file)
            multi_modal_info.append(image_base64)
            conversation.extend(image_conversation(image_base64, current_question))

    # Stream the response back to the chat interface
    outputs = client.chat.completions.create(
        model=model,
        stream=True,  # enable streaming output
        messages=conversation,
        temperature=0.4,
    )

    # Collect and yield the text incrementally
    collected_answer = ""
    for chunk in outputs:
        if chunk.choices[0].delta.content is not None:
            content = chunk.choices[0].delta.content
            collected_answer += content

            # Rewrite the think tags for display
            if '<think>' in collected_answer and '</think>' in collected_answer:
                formatted_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
            elif '<think>' in collected_answer:
                formatted_answer = collected_answer.replace('<think>', '[begin of think]')
            else:
                formatted_answer = collected_answer

            yield formatted_answer

    # Make sure the final output is formatted correctly
    if '<think>' in collected_answer and '</think>' in collected_answer:
        final_answer = collected_answer.replace('<think>', '[begin of think]').replace('</think>', '[end of think]')
    elif '<think>' in collected_answer:
        final_answer = collected_answer.replace('<think>', '[begin of think]')
    else:
        final_answer = collected_answer

    print(final_answer)


if __name__ == '__main__':
    # Define the Gradio interface
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    examples = IMAGE_EXAMPLES

    logo_path = os.path.join(CURRENT_DIR, "PUTH.png")
    with open(logo_path, "rb") as f:
        logo_base64 = base64.b64encode(f.read()).decode('utf-8')
    logo_img_html = f'<img src="data:image/png;base64,{logo_base64}" style="vertical-align:middle; margin-right:10px;" width="150"/>'

    iface = gr.ChatInterface(
        fn=question_answering,
        type='messages',
        multimodal=True,
        title=logo_img_html,
        description='Align-DS-V 北大对齐小组多模态DS-R1',
        examples=examples,
        theme=gr.themes.Soft(
            text_size='lg',
            spacing_size='lg',
            radius_size='lg',
            font=[gr.themes.GoogleFont('Montserrat'), gr.themes.GoogleFont('ui-sans-serif'), gr.themes.GoogleFont('system-ui'), gr.themes.GoogleFont('sans-serif')],
        ),
    )

    iface.launch(share=True)