# Hugging Face Spaces page residue captured with this file (not program code):
# Spaces: Runtime error
import ast
import base64
import logging
import mimetypes
import os
import re
import shutil
import tempfile
import time
import uuid
from pathlib import Path
from typing import AsyncGenerator, List, Optional, Tuple
from urllib.parse import quote

import bcrypt
import boto3
import gradio as gr
from fpdf import FPDF
from gradio import ChatMessage
from gradio_pdf import PDF
# Generated PDF reports are written here (relative to the current working directory).
REPORT_DIR = Path("reports")
REPORT_DIR.mkdir(exist_ok=True)

# bcrypt salt prefix; every stored hash in USERS below begins with these bytes.
# NOTE(review): SALT is not referenced in this part of the file, and a single salt
# shared by all accounts defeats per-password salting — confirm this is intended.
SALT = b'$2b$12$MC7djiqmIR7154Syul5Wme'

# S3 storage for chat histories. Credentials are read from the AWS_ACCESS_KEY /
# AWS_SECRET_KEY environment variables when this module is imported.
BUCKET_NAME = "molx-data-storage"
s3_client = boto3.client(
    "s3",
    region_name="ca-central-1",
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY"),
    aws_secret_access_key=os.getenv("AWS_SECRET_KEY"),
)

# Login accounts: username -> bcrypt hash of the password.
USERS = dict(
    test_user=b'$2b$12$MC7djiqmIR7154Syul5WmeQwebwsNOK5svMX08zMYhvpF9P9IVXe6',
    pna=b'$2b$12$MC7djiqmIR7154Syul5WmeWTzYft1UnOV4uGVn54FGfmbH3dRNq1C',
    dr_rajat=b'$2b$12$MC7djiqmIR7154Syul5WmeKZX8DXEs48GWbFpO3nRtFLbB5W/2suW',
)
class ChatInterface:
    """
    A chat interface for interacting with a medical AI agent through Gradio.

    Handles file uploads, message processing, and chat history management.
    Supports both regular image files and DICOM medical imaging files.

    NOTE(review): create_demo builds one instance per demo, so
    ``current_thread_id`` and the file paths are shared by every user session.
    """

    def __init__(self, agent, tools_dict):
        """
        Initialize the chat interface.

        Args:
            agent: The medical AI agent to handle requests; ``process_message``
                calls its ``workflow.stream(...)``.
            tools_dict (dict): Dictionary of available tools for image processing;
                ``handle_upload`` uses its ``"DicomProcessorTool"`` entry.
        """
        self.agent = agent
        self.tools_dict = tools_dict
        self.upload_dir = Path("temp")
        self.upload_dir.mkdir(exist_ok=True)
        self.current_thread_id = None
        # Separate storage for original and display paths
        self.original_file_path = None  # Given to the agent's tools (.dcm or other)
        self.display_file_path = None  # Shown in the UI (always a viewable image)

    def handle_upload(self, file_path: str) -> Optional[Tuple]:
        """
        Copy an uploaded file into ``self.upload_dir`` and record its paths.

        ``.dcm`` files are also converted for display with
        ``tools_dict["DicomProcessorTool"]``; the original file is kept for tools.

        Args:
            file_path (str): Path to the uploaded file.

        Returns:
            None if ``file_path`` is empty; otherwise a 3-tuple of the display
            path and two ``gr.update(interactive=True)`` updates.
        """
        if not file_path:
            return None
        source = Path(file_path)
        suffix = source.suffix.lower()
        # Timestamp plus a random token: two uploads in the same second must not
        # overwrite each other.
        saved_path = self.upload_dir / f"upload_{int(time.time())}_{uuid.uuid4().hex[:8]}{suffix}"
        shutil.copy2(file_path, saved_path)
        self.original_file_path = str(saved_path)
        # Handle DICOM conversion for display only
        if suffix == ".dcm":
            output, _ = self.tools_dict["DicomProcessorTool"]._run(str(saved_path))
            self.display_file_path = output["image_path"]
        else:
            self.display_file_path = str(saved_path)
        return self.display_file_path, gr.update(interactive=True), gr.update(interactive=True)

    def _viewable_path(self, path: Optional[str]) -> Optional[str]:
        """Swap a ``.dcm`` path for its converted display image; return other paths unchanged."""
        if path and Path(path).suffix.lower() == ".dcm" and self.display_file_path:
            return self.display_file_path
        return path

    def add_message(
        self, message: str, display_image: str, history: List[dict]
    ) -> Tuple[List[dict], gr.Textbox]:
        """
        Add the current image and the user's text to the chat history.

        Args:
            message (str): Text message to add; empty text is not added.
            display_image (str): Path to image being displayed.
            history (List[dict]): Current chat history ("messages" format).

        Returns:
            Tuple[List[dict], gr.Textbox]: Updated history, and the textbox made
            non-interactive while the agent answers.
        """
        history = history if history is not None else []
        # The chat panel cannot render a raw DICOM, so show its converted image.
        image_path = self._viewable_path(self.original_file_path or display_image)
        if image_path is not None:
            history.append({"role": "user", "content": {"path": image_path}})
        if message:
            history.append({"role": "user", "content": message})
        return history, gr.Textbox(value=message, interactive=False)

    def _build_agent_messages(self, message: str, display_image: Optional[str]) -> List[dict]:
        """Build the agent input: image path for tools, inline image data, then the text."""
        messages = []
        image_path = self.original_file_path or display_image
        if image_path is not None:
            # Send path for tools (they open the original file, e.g. a .dcm).
            messages.append({"role": "user", "content": f"image_path: {image_path}"})
            # Inline copy for the multimodal model; a raw .dcm is replaced by its
            # converted image because it is not an image format the model decodes.
            inline_path = self._viewable_path(image_path)
            guessed_type = mimetypes.guess_type(inline_path)[0]
            mime_type = guessed_type if guessed_type and guessed_type.startswith("image/") else "image/jpeg"
            with open(inline_path, "rb") as img_file:
                img_base64 = base64.b64encode(img_file.read()).decode("utf-8")
            messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {"url": f"data:{mime_type};base64,{img_base64}"},
                        }
                    ],
                }
            )
        if message:
            messages.append({"role": "user", "content": [{"type": "text", "text": message}]})
        return messages

    @staticmethod
    def _parse_tool_output(raw: str):
        """
        Return element 0 of a tool message's content, parsed as a Python literal.

        SECURITY: this used to be eval(), which executes arbitrary code if tool
        output (possibly shaped by model-generated text) contains an expression.
        ast.literal_eval accepts only literals and raises ValueError/SyntaxError
        otherwise; process_message then reports the error in the chat.
        """
        return ast.literal_eval(raw)[0]

    @staticmethod
    def _persist_history(session_details: dict, chat_history) -> None:
        """Best-effort upload of the chat history; failures are logged, not raised."""
        try:
            store_chat_history(session_details["username"], session_details["session_id"], chat_history)
        except Exception:
            # Called from process_message's finally: raising here would turn a
            # finished conversation into an error for the user.
            logging.getLogger(__name__).exception("Failed to store chat history")

    async def process_message(
        self,
        message: str,
        display_image: Optional[str],
        session_details: dict,
        chat_history: List[ChatMessage],
    ) -> AsyncGenerator[Tuple[List[ChatMessage], Optional[str], str], None]:
        """
        Send the user's message (and current image) to the agent and stream replies.

        Args:
            message (str): User message to process.
            display_image (Optional[str]): Path to currently displayed image; used
                when nothing was uploaded through ``handle_upload``.
            session_details (dict): Expected to hold ``"username"`` and
                ``"session_id"``; used to store the history when streaming ends.
            chat_history (List[ChatMessage]): Current chat history.

        Yields:
            Tuple[List[ChatMessage], Optional[str], str]: Updated chat history,
            display path, and an empty string (clears the input textbox).
        """
        chat_history = chat_history or []
        # Initialize thread if needed
        if not self.current_thread_id:
            self.current_thread_id = str(time.time())
        try:
            # Built inside the try so an unreadable image is reported in the chat.
            messages = self._build_agent_messages(message, display_image)
            config = {"configurable": {"thread_id": self.current_thread_id}}
            for event in self.agent.workflow.stream({"messages": messages}, config):
                if not isinstance(event, dict):
                    continue
                if "process" in event:
                    content = event["process"]["messages"][-1].content
                    if content:
                        # Hide internal temp/ file paths from the user.
                        content = re.sub(r"temp/[^\s]*", "", content)
                        chat_history.append(ChatMessage(role="assistant", content=content))
                        yield chat_history, self.display_file_path, ""
                elif "execute" in event:
                    for tool_msg in event["execute"]["messages"]:
                        tool_name = tool_msg.name
                        tool_result = self._parse_tool_output(tool_msg.content)
                        if not tool_result:
                            continue
                        formatted_result = " ".join(
                            line.strip() for line in str(tool_result).splitlines()
                        ).strip()
                        metadata = {
                            "title": f"🖼️ Image from tool: {tool_name}",
                            "description": formatted_result,
                        }
                        chat_history.append(
                            ChatMessage(role="assistant", content=formatted_result, metadata=metadata)
                        )
                        # For image_visualizer, use display path
                        if tool_name == "image_visualizer":
                            self.display_file_path = tool_result["image_path"]
                            chat_history.append(
                                ChatMessage(role="assistant", content={"path": self.display_file_path})
                            )
                        yield chat_history, self.display_file_path, ""
        except Exception as e:
            chat_history.append(
                ChatMessage(role="assistant", content=f"❌ Error: {str(e)}", metadata={"title": "Error"})
            )
            # Same 3-value shape as every other yield (the UI wires three outputs).
            yield chat_history, self.display_file_path, ""
        finally:
            self._persist_history(session_details, chat_history)
def _format_chat_entry(entry) -> str:
    """
    Render one chat-history entry as a single line of plain text.

    Handles plain strings, ``{"role": ..., "content": ...}`` dicts and objects
    exposing ``role``/``content`` attributes (e.g. ``gradio.ChatMessage``).
    File attachments given as ``{"path": ...}`` are written as their path.
    """
    if isinstance(entry, str):
        return entry
    if isinstance(entry, dict):
        role, content = entry.get("role"), entry.get("content")
    else:
        role = getattr(entry, "role", None)
        content = getattr(entry, "content", entry)
    if isinstance(content, dict) and "path" in content:
        content = f"[file: {content['path']}]"
    return f"{role}: {content}" if role else str(content)


def store_chat_history(username, session_id, chat_history):
    """
    Store the chat history in S3 as a text file with a unique name.

    Each entry becomes one ``"role: content"`` line (see ``_format_chat_entry``).
    The text is written to a temporary local file, uploaded with
    ``upload_to_s3``, and the local copy is removed afterwards.

    Args:
        username (str): The username of the user.
        session_id (str): A unique session identifier.
        chat_history (list): Chat entries: strings, message dicts, or
            ChatMessage-like objects. ``None`` is treated as empty.

    Returns:
        str: The URL of the uploaded chat history in S3.
    """
    # Entries are not plain strings (process_message appends ChatMessage objects
    # to Gradio message dicts), so a bare "\n".join() would raise TypeError.
    chat_history_text = "\n".join(_format_chat_entry(entry) for entry in chat_history or [])
    timestamp = str(int(time.time()))  # Current timestamp for a unique name
    chat_history_path = Path(tempfile.gettempdir()) / f"{session_id}_chat_history_{timestamp}.txt"
    chat_history_path.write_text(chat_history_text, encoding="utf-8")
    try:
        # upload_to_s3 keys the object by file name, so the S3 key is unchanged.
        return upload_to_s3(str(chat_history_path), username, session_id)
    finally:
        chat_history_path.unlink(missing_ok=True)  # local copy only existed for the upload
def upload_to_s3(file_path, username, session_id):
    """
    Upload a file to S3 under ``<username>/<session_id>/<file name>``.

    Args:
        file_path (str | os.PathLike): Path to the local file to upload.
        username (str): The username of the user; first component of the key.
        session_id (str): A unique session identifier; second component of the key.

    Returns:
        str: Virtual-hosted-style URL of the uploaded object in ``BUCKET_NAME``.
    """
    file_path = str(file_path)
    s3_key = f"{username}/{session_id}/{Path(file_path).name}"
    s3_client.upload_file(file_path, BUCKET_NAME, s3_key)
    # Percent-encode the key so spaces or reserved characters still yield a valid
    # URL; "/" stays unescaped as the path separator.
    return f"https://{BUCKET_NAME}.s3.amazonaws.com/{quote(s3_key, safe='/')}"
def create_demo(agent, tools_dict):
    """
    Create a Gradio demo interface for the medical AI agent.

    The app opens on a login page; after a successful bcrypt check against
    ``USERS`` it shows the main page (chat, image actions, report tab).

    Args:
        agent: The medical AI agent to handle requests.
        tools_dict (dict): Dictionary of available tools for image processing.

    Returns:
        gr.Blocks: Gradio Blocks interface.
    """
    # NOTE(review): this single ChatInterface (thread id, file paths) is shared by
    # every connected user; only the login session below is per-user.
    interface = ChatInterface(agent, tools_dict)

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        auth_state = gr.State(False)
        # Per-browser-session login details ({"username", "session_id"}). Event
        # inputs/outputs must be Gradio components, so a plain dict cannot be used.
        session_state = gr.State({})

        with gr.Column(visible=True) as login_page:
            gr.Markdown("## 🔐 Login")
            username_tb = gr.Textbox(label="Username")
            password_tb = gr.Textbox(label="Password", type="password")
            login_button = gr.Button("Login")
            login_error = gr.Markdown(visible=False)

        with gr.Column(visible=False) as main_page:
            gr.Markdown("# 🏥 MOLx - Powered by MedRAX")
            with gr.Row():
                with gr.Column(scale=3):
                    chatbot = gr.Chatbot(
                        [],
                        height=800,
                        container=True,
                        show_label=True,
                        elem_classes="chat-box",
                        type="messages",
                        label="Agent",
                        avatar_images=(None, "assets/medrax_logo.jpg"),
                    )
                    with gr.Row():
                        with gr.Column(scale=3):
                            txt = gr.Textbox(
                                show_label=False,
                                placeholder="Ask about the X-ray...",
                                container=False,
                            )
                with gr.Column(scale=3):
                    with gr.Tabs():
                        with gr.Tab(label="Image section"):
                            image_display = gr.Image(
                                label="Image", type="filepath", height=685, container=True
                            )
                            with gr.Row():
                                analyze_btn = gr.Button("Analyze 1")
                                analyze2_btn = gr.Button("Analyze 2")
                                segment_btn = gr.Button("Segment")
                            with gr.Row():
                                clear_btn = gr.Button("Clear Chat")
                                new_thread_btn = gr.Button("New Patient")
                        with gr.Tab(label="Report section"):
                            generate_report_btn = gr.Button("Generate Report")
                            conclusion_tb = gr.Textbox(label="Conclusion", interactive=False)
                            with gr.Row():
                                approve_btn = gr.Button("Approve", visible=False)
                                download_pdf_btn = gr.DownloadButton(label="📥 Download PDF", visible=False)
                            pdf_preview = PDF(visible=False)

        # Event handlers
        def authenticate(username, password):
            """Check credentials; on success reveal the main page and start a session."""
            hashed = USERS.get(username)
            if hashed and bcrypt.checkpw((password or "").encode(), hashed):
                session = {
                    "username": username,
                    "session_id": str(uuid.uuid4()) + str(time.time()),
                }
                return (
                    gr.update(visible=False),  # hide login
                    gr.update(visible=True),  # show main
                    gr.update(visible=False),  # hide error
                    True,  # auth_state
                    session,  # session_state
                )
            return (
                gr.update(),  # leave login page unchanged
                gr.update(),  # leave main page hidden
                gr.update(value="❌ Incorrect username or password", visible=True),
                False,
                {},
            )

        def clear_chat():
            interface.original_file_path = None
            interface.display_file_path = None
            return [], None

        def new_thread():
            interface.current_thread_id = str(time.time())
            return (
                [],
                interface.display_file_path,
                gr.update(value=None, interactive=False),
                gr.update(visible=False),
                gr.update(value=None, visible=False),
                gr.update(value=None, visible=False),
            )

        # NOTE(review): not connected to any component in this demo.
        def handle_file_upload(file):
            # Accept a tempfile-like object (with .name) or a plain path string.
            return interface.handle_upload(getattr(file, "name", file))

        def generate_report():
            """Summarize the current thread into the editable Conclusion box."""
            if not interface.current_thread_id:
                raise gr.Error("Start a conversation before generating a report.")
            result = interface.agent.summarize_message(interface.current_thread_id)
            return (
                gr.update(value=result["Conclusion"], lines=4, interactive=True),
                gr.update(visible=True),
            )

        def records_to_pdf(conclusion) -> Path:
            """
            Write a PDF report under ./reports/ and return its Path.

            FPDF's built-in Helvetica font only covers Latin-1, so other characters
            in the conclusion are replaced with "?" instead of making FPDF raise.
            """
            text = (conclusion or "").encode("latin-1", "replace").decode("latin-1")
            pdf = FPDF()
            pdf.set_auto_page_break(auto=True, margin=15)
            pdf.add_page()
            pdf.set_font(family="Helvetica", size=12)
            pdf.cell(0, 10, "Chest-X-ray Report", ln=1, align="C")
            pdf.ln(4)
            pdf.set_font(family="Helvetica", style="")
            pdf.multi_cell(0, 8, text)
            pdf_path = REPORT_DIR / f"report_{uuid.uuid4().hex}.pdf"
            pdf.output(str(pdf_path))
            return pdf_path

        def build_pdf_and_preview(conclusion):
            pdf_path = records_to_pdf(conclusion)
            return (
                gr.update(value=pdf_path, visible=True),  # for DownloadButton
                gr.update(value=str(pdf_path), visible=True),  # for PDF preview
            )

        async def run_agent(message, display_image, chat_history, session):
            """Stream agent replies, passing this browser session's login details."""
            if not (session or {}).get("username"):
                raise gr.Error("Please log in first.")
            async for update in interface.process_message(message, display_image, session, chat_history):
                yield update

        def unlock_textbox():
            return gr.Textbox(interactive=True)

        add_inputs = [txt, image_display, chatbot]
        add_outputs = [chatbot, txt]
        agent_inputs = [txt, image_display, chatbot, session_state]
        agent_outputs = [chatbot, image_display, txt]

        def wire_prompt_button(button, prompt):
            """Fill the textbox with a canned prompt, then run the normal chat flow."""
            (
                button.click(lambda: gr.update(value=prompt), None, txt)
                .then(interface.add_message, inputs=add_inputs, outputs=add_outputs)
                .then(run_agent, inputs=agent_inputs, outputs=agent_outputs)
                .then(unlock_textbox, None, [txt])
            )

        login_button.click(
            authenticate,
            [username_tb, password_tb],
            [login_page, main_page, login_error, auth_state, session_state],
        )
        (
            txt.submit(interface.add_message, inputs=add_inputs, outputs=add_outputs)
            .then(run_agent, inputs=agent_inputs, outputs=agent_outputs)
            .then(unlock_textbox, None, [txt])
        )
        wire_prompt_button(
            analyze_btn,
            "Analyze this xray and give me a detailed response. Use the medgemma_xray_expert tool",
        )
        wire_prompt_button(
            analyze2_btn,
            "Analyze this xray and give me a detailed response. Use the chest_xray_expert tool",
        )
        wire_prompt_button(segment_btn, "Segment the major affected lung")
        clear_btn.click(clear_chat, outputs=[chatbot, image_display])
        new_thread_btn.click(
            new_thread,
            outputs=[chatbot, image_display, conclusion_tb, approve_btn, download_pdf_btn, pdf_preview],
        )
        generate_report_btn.click(generate_report, outputs=[conclusion_tb, approve_btn])
        approve_btn.click(
            build_pdf_and_preview,
            inputs=[conclusion_tb],
            outputs=[download_pdf_btn, pdf_preview],
        )
    return demo