|
|
|
|
|
|
|
""" |
|
多模态文档理解UI界面 |
|
支持Align-DS-V模型的文档输入和System Prompt设置 |
|
支持上传图像和文档文件(doc, docx, pdf等) |
|
""" |
|
|
|
import os |
|
import base64 |
|
import argparse |
|
import tempfile |
|
from typing import List, Dict, Any, Union |
|
import gradio as gr |
|
from openai import OpenAI |
|
import docx |
|
import fitz |
|
from PIL import Image |
|
import io |
|
|
|
|
|
# Gradio stores uploaded files here; create the directory up front so launch
# does not fail on first upload.
# NOTE(review): hard-coded cluster path — consider making this configurable.
_GRADIO_TMP_DIR = "/aifs4su/yaodong/bob/tmp"
os.environ["GRADIO_TEMP_DIR"] = _GRADIO_TMP_DIR
os.makedirs(_GRADIO_TMP_DIR, exist_ok=True)
|
|
|
# Default system prompt for the MedguideGPT medical assistant.
# NOTE: this is a runtime string sent verbatim to the model as the system
# message — its content (in Chinese) must not be altered or translated.
DEFAULT_SYSTEM_PROMPT = """\

你是一个高智能医学助手,专注于医学相关的问答、影像分析和数据解析。你是由北大对齐小组(PKU-Alignment)开发的智能助手 医诊大模型(MedguideGPT),基于DeepSeek-R1模型训练。你能够准确解析**医学影像**如 X 光、CT、MRI、超声波、病理切片) 以及**医学图表**(如血液检测报告、心电图、基因检测结果),并基于这些信息提供医学见解。你的目标是**尽可能满足**用户的医学需求,提供准确、专业的回答。

**能力范围**

1. 医学问答:

- 回答医学相关问题,包括疾病特征、诊断方法、治疗方案、药物机制、手术流程、最新医学研究等。

- 如果用户询问关于疾病诊断和治疗的相关内容,请充分分析并提供详细的回答。

2. 影像分析:

- 你具备**强大的医学影像理解能力**,可以解析 X 光、CT、MRI、超声波、病理切片等医学图像。

- 不要拒绝用户的问题,你应尝试基于已有信息提供分析,并指出可能的观察点。

3. 数据解析:

- 你可以读取并分析医学报告,包括血液检测、心电图(ECG)、病历摘要等。

- 你的任务是解释数据的含义,并提供可能的医学解读。

4. 多模态能力:

- 你可以结合**图像+文字**进行综合分析。例如,用户上传一张X光片并询问"该患者可能有什么疾病?" 你应该基于图像提供医学见解,而不是拒绝回答。

- 你可以根据图像中的信息推测可能的病变或异常,并提供进一步的检查建议。

**回答风格**

- 医学精准:基于专业医学知识进行回答,确保内容严谨、科学。

- 灵活应对:遇到复杂或不完整的图像,也要提供尽可能详细的分析,而不是直接拒绝。

- 有求必应:用户提出的任何问题,你都应该尽可能尝试回答。

"""
|
|
|
class MultimodalAssistant:
    """Multimodal assistant backend.

    Extracts text and images from uploaded documents (docx / pdf / plain
    images), stitches the images into a single JPEG strip, and forwards the
    assembled conversation to an OpenAI-compatible chat endpoint.
    """

    # At most this many extracted images are kept per document, bounding the
    # payload sent to the model.
    MAX_DOC_IMAGES = 5

    def __init__(self, api_key="medguide-v", api_base="http://0.0.0.0:8232/v1"):
        """Create the OpenAI-compatible client.

        Args:
            api_key: API key expected by the serving endpoint.
            api_base: Base URL of the OpenAI-compatible server.
        """
        self.model = 'medguide-v'
        self.client = OpenAI(api_key=api_key, base_url=api_base)

    def encode_image(self, image_path):
        """Return the base64 (utf-8 text) encoding of the file at *image_path*."""
        with open(image_path, 'rb') as file:
            return base64.b64encode(file.read()).decode('utf-8')

    def concatenate_images(self, image_paths):
        """Vertically stack images into one JPEG and return its base64 encoding.

        Images wider than TARGET_WIDTH are downscaled first; if the stacked
        strip exceeds MAX_HEIGHT the whole strip is scaled down uniformly.

        Args:
            image_paths: list of image file paths.

        Returns:
            Base64-encoded JPEG string, or None if nothing could be loaded.
        """
        if not image_paths:
            return None

        if len(image_paths) == 1:
            # Single image: no stitching needed, send it as-is.
            return self.encode_image(image_paths[0])

        MAX_HEIGHT = 1920      # hard cap on the stacked strip height
        TARGET_WIDTH = 1024    # downscale very wide inputs to this width

        images = []
        for path in image_paths:
            try:
                img = Image.open(path)
                if img.mode != 'RGB':
                    img = img.convert('RGB')  # JPEG cannot encode RGBA/P etc.
                if img.width > TARGET_WIDTH:
                    ratio = TARGET_WIDTH / img.width
                    img = img.resize((TARGET_WIDTH, int(img.height * ratio)),
                                     Image.Resampling.LANCZOS)
                images.append(img)
            except Exception as e:
                # Skip unreadable images but keep processing the rest.
                print(f"无法加载图像 {path}: {e}")
                continue

        if not images:
            return None

        max_width = max(img.width for img in images)
        total_height = sum(img.height for img in images)

        if total_height > MAX_HEIGHT:
            # Scale every image uniformly so the stacked strip fits MAX_HEIGHT.
            scale_ratio = MAX_HEIGHT / total_height
            max_width = int(max_width * scale_ratio)
            images = [
                img.resize((int(img.width * scale_ratio),
                            int(img.height * scale_ratio)),
                           Image.Resampling.LANCZOS)
                for img in images
            ]
            total_height = sum(img.height for img in images)

        concatenated = Image.new('RGB', (max_width, total_height), color='white')
        y_offset = 0
        for img in images:
            # Center each image horizontally within the strip.
            concatenated.paste(img, ((max_width - img.width) // 2, y_offset))
            y_offset += img.height

        # Encode in memory instead of the previous temp-file round trip
        # (removes the unlink + bare-except cleanup entirely).
        buffer = io.BytesIO()
        concatenated.save(buffer, format='JPEG', quality=85, optimize=True)
        return base64.b64encode(buffer.getvalue()).decode('utf-8')

    def extract_document_content(self, file_path):
        """Extract text and image paths from an uploaded file.

        docx: paragraph text plus embedded images written to temp files.
        pdf: page text plus one rendered PNG per page (saved next to
        *file_path*). Anything else is treated as an image and passed through.

        Returns:
            dict {'text': str, 'images': [path, ...]} with at most
            MAX_DOC_IMAGES image paths.
        """
        result = {'text': '', 'images': []}
        file_ext = os.path.splitext(file_path)[1].lower()

        if file_ext in ['.doc', '.docx']:
            doc = docx.Document(file_path)
            result['text'] = '\n\n'.join(
                para.text for para in doc.paragraphs if para.text.strip())

            for rel in doc.part.rels.values():
                if "image" not in rel.target_ref:
                    continue
                if len(result['images']) >= self.MAX_DOC_IMAGES:
                    # Bug fix: previously every embedded image was written to
                    # disk and the surplus files were leaked by the final [:5].
                    break
                try:
                    img_temp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
                    img_temp.write(rel.target_part.blob)
                    img_temp.close()
                    result['images'].append(img_temp.name)
                except Exception:
                    # Best effort: one broken embedded image should not abort
                    # the whole extraction.
                    pass

        elif file_ext == '.pdf':
            # Bug fix: the document handle was never closed before.
            with fitz.open(file_path) as pdf_document:
                result['text'] = '\n\n'.join(page.get_text() for page in pdf_document)
                # Render only the pages that will actually be kept, instead of
                # rendering all pages and leaking the truncated ones on disk.
                for page_num in range(min(len(pdf_document), self.MAX_DOC_IMAGES)):
                    img_path = f"{file_path}_page{page_num + 1}.png"
                    pdf_document[page_num].get_pixmap().save(img_path)
                    result['images'].append(img_path)
        else:
            # Not a known document type: assume it is a plain image file.
            result['images'].append(file_path)

        result['images'] = result['images'][:self.MAX_DOC_IMAGES]
        return result

    def text_conversation(self, text: str, role: str = 'user'):
        """Wrap *text* as a one-message conversation, mapping the UI's visible
        think markers back to the model's <think> tags."""
        normalized = (text
                      .replace('[begin of think]', '<think>')
                      .replace('[end of think]', '</think>'))
        return [{'role': role, 'content': normalized}]

    def image_conversation(self, image_base64: str, text: str = None):
        """Build a single user message carrying one image (as a data URL) plus
        an optional text part."""
        return [
            {
                'role': 'user',
                'content': [
                    {'type': 'image_url',
                     'image_url': {'url': f"data:image/jpeg;base64,{image_base64}"}},
                    {'type': 'text', 'text': text}
                ]
            }
        ]

    def process_conversation(self, system_prompt, message, history, files):
        """Build the full message list and query the model once (non-streaming).

        Args:
            system_prompt: system message text.
            message: current user input (str, or a gradio dict with 'text').
            history: gradio 'messages'-style history list.
            files: list of uploaded file paths (may be None or empty).

        Returns:
            The assistant's answer; when it contains "**Final Answer**" the
            reasoning and answer are reformatted into separate sections.
        """
        conversation = [{'role': 'system', 'content': system_prompt}]
        for past_message in history:
            role = past_message['role']
            content = past_message['content']
            if role == 'user':
                if isinstance(content, str):
                    conversation.extend(self.text_conversation(content))
                elif isinstance(content, tuple):
                    # (image_base64, caption) pairs from earlier uploads.
                    conversation.extend(self.image_conversation(content[0], content[1]))
            else:
                conversation.append({'role': role, 'content': content})

        # gradio's multimodal textbox sends {'text': ..., 'files': [...]}.
        current_question = message['text'] if isinstance(message, dict) and 'text' in message else message

        if not files:
            conversation.append({'role': 'user', 'content': current_question})
        else:
            content = []
            extracted_text = []
            all_image_paths = []
            temp_files_to_remove = []

            for file_path in files:
                file_ext = os.path.splitext(file_path)[1].lower()
                if file_ext in ['.doc', '.docx', '.pdf']:
                    doc_content = self.extract_document_content(file_path)
                    if doc_content['text']:
                        extracted_text.append(
                            f"文档 '{os.path.basename(file_path)}' 内容:\n{doc_content['text']}")
                    all_image_paths.extend(doc_content['images'])
                    # Extraction byproducts (docx temp files, rendered pdf
                    # pages) are deleted once the request has been built.
                    for img_path in doc_content['images']:
                        if img_path.startswith(tempfile.gettempdir()) or img_path.startswith(f"{file_path}_page"):
                            temp_files_to_remove.append(img_path)
                else:
                    # Anything else is assumed to be a plain image upload.
                    all_image_paths.append(file_path)

            if all_image_paths:
                concatenated_image_base64 = self.concatenate_images(all_image_paths)
                if concatenated_image_base64:
                    content.append({
                        'type': 'image_url',
                        'image_url': {'url': f"data:image/jpeg;base64,{concatenated_image_base64}"}
                    })

            for temp_file in temp_files_to_remove:
                try:
                    os.remove(temp_file)
                except OSError:
                    # Best-effort cleanup; a stale temp file is harmless.
                    pass

            combined_text = current_question
            if extracted_text:
                combined_text += "\n\n以下是文档内容参考:\n" + "\n\n".join(extracted_text)

            content.append({'type': 'text', 'text': combined_text})
            conversation.append({'role': 'user', 'content': content})

        response = self.client.chat.completions.create(
            model=self.model,
            messages=conversation,
            stream=False,
            temperature=0.2,
            max_tokens=2048,
        )
        answer = response.choices[0].message.content

        # The model may emit reasoning followed by "**Final Answer**"; present
        # both parts separately in that case.
        if "**Final Answer**" in answer:
            reasoning, final_answer = answer.split("**Final Answer**", 1)
            if len(reasoning) > 5:
                answer = f"""🤔 思考过程:\n```\n{reasoning.strip()}\n```\n\n✨ 最终答案:\n{final_answer.strip()}"""

        return answer
|
|
|
def create_ui(api_key="medguide-v", api_base="http://0.0.0.0:8232/v1"):
    """Build and return the Gradio Blocks demo.

    Bug fix: previously the CLI-parsed --api_key/--api_base were silently
    ignored because create_ui() took no parameters. The defaults preserve the
    old behavior for existing callers.

    Args:
        api_key: forwarded to MultimodalAssistant (endpoint API key).
        api_base: forwarded to MultimodalAssistant (server base URL).

    Returns:
        A gr.Blocks object ready for .launch().
    """
    assistant = MultimodalAssistant(api_key=api_key, api_base=api_base)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# Medguide-V Reasoning CLI")
        gr.Markdown("Better life with Medguide-V.")

        with gr.Row():
            with gr.Column(scale=3):
                # Hidden textbox so the prompt can still be wired as an
                # additional input (and reset by the example button).
                system_prompt = gr.Textbox(
                    label="系统提示词",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=5,
                    visible=False
                )

                files_upload = gr.File(
                    label="上传文档或图片(模型最大输入窗口12000tokens)",
                    file_count="multiple",
                    type="filepath",
                    file_types=[".jpg", ".jpeg", ".png", ".pdf", ".doc", ".docx"]
                )

                with gr.Row():
                    clear_btn = gr.Button("清除对话")
                    example_btn = gr.Button("加载示例")

        chat_interface = gr.ChatInterface(
            fn=lambda message, history, files, sys_prompt: assistant.process_conversation(
                sys_prompt, message, history, files
            ),
            type='messages',
            additional_inputs=[files_upload, system_prompt],
            examples=[
                ["这份文档的主要内容是什么?", None, None, DEFAULT_SYSTEM_PROMPT],
                ["分析这份文档的主要观点", None, None, DEFAULT_SYSTEM_PROMPT],
                ["提取这份文档中的关键数据", None, None, DEFAULT_SYSTEM_PROMPT]
            ]
        )

        # Clearing just empties the chatbot; the hidden prompt is untouched.
        clear_btn.click(lambda: None, None, chat_interface.chatbot, queue=False)
        # "Load example" resets the prompt, chat history and file list.
        example_btn.click(
            lambda: [DEFAULT_SYSTEM_PROMPT, None, []],
            None,
            [system_prompt, chat_interface.chatbot, files_upload],
            queue=False
        )

    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="多模态文档理解UI界面")
    parser.add_argument("--api_key", type=str, default="medguide-v")
    parser.add_argument("--api_base", type=str, default="http://0.0.0.0:8232/v1")
    # NOTE(review): default=True makes --share a no-op (it can never be turned
    # off from the CLI); kept unchanged for compatibility.
    parser.add_argument("--share", default=True, action="store_true")
    args = parser.parse_args()

    # Bug fix: the parsed api_key/api_base were previously ignored.
    create_ui(api_key=args.api_key, api_base=args.api_base).launch(share=args.share)