import os
import warnings
from typing import *

from dotenv import load_dotenv
from transformers import logging
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI

from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *

warnings.filterwarnings("ignore")
logging.set_verbosity_error()

_ = load_dotenv()  # Load the .env file (expects OPENAI_API_KEY and OPENAI_BASE_URL)
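# Example .env layout (illustrative values; the base URL shown assumes an
# OpenRouter-compatible endpoint and is not dictated by this script):
#   OPENAI_API_KEY=sk-or-...
#   OPENAI_BASE_URL=https://openrouter.ai/api/v1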


def initialize_agent(
    prompt_file,
    tools_to_use=None,
    model_dir="./model-weights",
    temp_dir="temp",
    device="cuda",
    model="google/gemini-1.5-flash-latest",  # OpenRouter-style model identifier
    temperature=0.7,
    top_p=0.95,
    openai_kwargs=None,
):
    """Build the MedRAX agent with the requested tools and chat model."""
    openai_kwargs = openai_kwargs or {}  # guard against unpacking None below

    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]
    # Registry of available tools; each entry is constructed lazily so heavy
    # models are only loaded when actually selected.
    all_tools = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }
    # Instantiate only the requested tools (default: every registered tool);
    # unknown tool names are silently skipped.
    tools_dict = {}
    tools_to_use = tools_to_use or all_tools.keys()
    for tool_name in tools_to_use:
        if tool_name in all_tools:
            tools_dict[tool_name] = all_tools[tool_name]()

    checkpointer = MemorySaver()
    llm = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
    agent = Agent(
        llm,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")
    return agent, tools_dict


if __name__ == "__main__":
    print("Starting server...")

    selected_tools = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
        # "LlavaMedTool",
        # "XRayPhraseGroundingTool",
        # "ChestXRayGeneratorTool",
    ]

    # Pass API credentials from the environment through to ChatOpenAI.
    openai_kwargs = {}
    if api_key := os.getenv("OPENAI_API_KEY"):
        openai_kwargs["openai_api_key"] = api_key
    if base_url := os.getenv("OPENAI_BASE_URL"):
        openai_kwargs["openai_api_base"] = base_url

    agent, tools_dict = initialize_agent(
        "medrax/docs/system_prompts.txt",
        tools_to_use=selected_tools,
        model_dir="./model-weights",
        temp_dir="temp",
        device="cuda",
        model="google/gemini-1.5-flash-latest",  # served via the OpenAI-compatible endpoint configured above
        temperature=0.7,
        top_p=0.95,
        openai_kwargs=openai_kwargs,
    )

    demo = create_demo(agent, tools_dict)
    demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
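
# Usage note (illustrative; the script filename here is an assumption):
#   python app.py
# The Gradio UI is served on port 8585; with share=True, Gradio also prints a
# temporary public link.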