import os
import warnings
from typing import *

from dotenv import load_dotenv
from transformers import logging

from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI

from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

warnings.filterwarnings("ignore")
logging.set_verbosity_error()

load_dotenv()
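
# load_dotenv() reads a .env file from the working directory (if one exists),
# which is expected to provide OPENAI_API_KEY and OPENAI_BASE_URL for the checks below.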

# Set environment variables explicitly to ensure they're available
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_BASE_URL")

if not api_key:
    raise ValueError("OPENAI_API_KEY not found in environment variables")
if not base_url:
    raise ValueError("OPENAI_BASE_URL not found in environment variables")

# Set them in environment for libraries that might read directly from os.environ
os.environ["OPENAI_API_KEY"] = api_key
os.environ["OPENAI_BASE_URL"] = base_url


def initialize_agent(
    prompt_file,
    tools_to_use=None,
    model_dir="./model-weights",
    temp_dir="temp",
    device="cuda",
    model="qwen/qwen2.5-vl-3b-instruct:free",
    temperature=0.7,
    top_p=0.95,
):
    """Initialize the MedRAX agent with the specified tools and configuration.

    Returns:
        tuple: (agent, tools_dict), where tools_dict maps each enabled tool name to its instance.
    """
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]
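
    # Registry of tool factories. Each entry is a zero-argument lambda so that
    # model weights are loaded only for the tools actually requested below.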
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    tools_dict = {}
    tools_to_use = tools_to_use or all_tools.keys()
    for tool_name in tools_to_use:
        if tool_name in all_tools:
            tools_dict[tool_name] = all_tools[tool_name]()
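
    # MemorySaver keeps LangGraph checkpoints (the conversation state) in
    # process memory only; nothing is persisted across restarts.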
    checkpointer = MemorySaver()

    # Explicitly pass the API key and base URL
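    # base_url can point at any OpenAI-compatible endpoint; the ":free" model IDs
    # used in this file look like OpenRouter identifiers, but that is an assumption.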
    llm = ChatOpenAI(
        model_name=model,
        api_key=api_key,
        base_url=base_url,
        temperature=temperature,
        top_p=top_p,
    )
    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")
    return agent, tools_dict
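

# Illustrative usage sketch (tool choice and device here are assumptions, not
# the project's defaults; the full configuration lives in the __main__ block below):
#   agent, tools_dict = initialize_agent(
#       "medrax/docs/system_prompts.txt",
#       tools_to_use=["ChestXRayClassifierTool", "ImageVisualizerTool"],
#       device="cpu",
#   )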


if __name__ == "__main__":
    print("Starting server...")

    selected_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
        # "LlavaMedTool",
        # "XRayPhraseGroundingTool",
        # "ChestXRayGeneratorTool",
    ]
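    # The three tools commented out above typically need additional model weights
    # under model_dir (and, for the load_in_8bit variants, a CUDA-capable GPU);
    # enable them only once those weights are available.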

    agent, tools_dict = initialize_agent(
        "medrax/docs/system_prompts.txt",
        tools_to_use=selected_tools,
        model_dir="./model-weights",
        temp_dir="temp",
        device="cuda",
        model="google/gemini-2.5-pro-exp-03-25:free",
        temperature=0.7,
        top_p=0.95,
    )
    demo = create_demo(agent, tools_dict)
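
    # The commented-out launch calls below show alternatives: binding to all
    # interfaces on a fixed port with a public Gradio share link, or enabling
    # request queueing. The default call keeps the app local.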
    # demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
    # demo.launch(debug=True, queue=True, ssr_mode=False)
    demo.launch(debug=True, ssr_mode=False)