Dhruv-Ty's picture
model updated model_id = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF"
1ee0803 verified
raw
history blame
3.43 kB
import os
import warnings
from typing import *
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain.llms import HuggingFacePipeline
from langgraph.checkpoint.memory import MemorySaver
from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *
# Install a global "ignore" filter so that every warning raised from this point
# on is suppressed (presumably to keep the server console output readable).
warnings.filterwarnings("ignore")
def initialize_agent(
    prompt_file: str,
    tools_to_use: Optional[Iterable[str]] = None,
    model_dir: str = "./model-weights",
    temp_dir: str = "temp",
    device: str = "cuda",
    temperature: float = 0.7,
    top_p: float = 0.95,
    hf_model_id: str = "TheBloke/Mistral-7B-Instruct-v0.2-GGUF",
    gguf_file: Optional[str] = None,
) -> Tuple[Any, Dict[str, Any]]:
    """Initialize the MedRAX agent with specified tools and configuration.

    Args:
        prompt_file: Path passed to ``load_prompts_from_file``. The loaded
            mapping must contain a ``"MEDICAL_ASSISTANT"`` entry, which is
            used as the agent's system prompt.
        tools_to_use: Names of the tools to instantiate. When omitted or
            empty, every registered tool is instantiated. Names that are not
            registered are skipped silently.
        model_dir: Directory handed to the tools as their weight cache; the
            image generator tool reads from ``{model_dir}/roentgen``.
        temp_dir: Scratch directory handed to the tools that accept one.
        device: Device string forwarded to the tools (e.g. ``"cuda"``).
        temperature: Sampling temperature for the language model. A value
            that is not strictly positive switches generation to greedy
            decoding.
        top_p: Nucleus-sampling threshold, used only when sampling is active.
        hf_model_id: Hugging Face repository id (or local path) of the
            language model driving the agent.
        gguf_file: Name of the ``.gguf`` checkpoint inside ``hf_model_id``.
            Forwarded to ``from_pretrained`` only when given.

    Returns:
        A ``(agent, tools_dict)`` tuple, where ``tools_dict`` maps each
        selected tool name to its instantiated tool.

    Raises:
        KeyError: If the prompt file has no ``"MEDICAL_ASSISTANT"`` entry.
    """
    prompts = load_prompts_from_file(prompt_file)
    prompt = prompts["MEDICAL_ASSISTANT"]

    # Factories rather than instances: a tool (and whatever weights it loads)
    # is only constructed if it was actually requested.
    tool_factories = {
        "ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
        "ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
        "LlavaMedTool": lambda: LlavaMedTool(
            cache_dir=model_dir, device=device, load_in_8bit=True
        ),
        "XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
        "ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
            cache_dir=model_dir, device=device
        ),
        "XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
            cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
        ),
        "ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
            model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
        ),
        "ImageVisualizerTool": lambda: ImageVisualizerTool(),
        "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
    }

    # Fall back to the full registry when no explicit selection was made;
    # unknown names are ignored rather than treated as an error.
    requested = tools_to_use or tool_factories.keys()
    tools_dict = {
        name: tool_factories[name]() for name in requested if name in tool_factories
    }

    checkpointer = MemorySaver()

    # Load local Hugging Face model.
    # NOTE(review): the default repository id points at a GGUF-only repo;
    # transformers expects ``gguf_file=<name>.gguf`` for such checkpoints, so
    # callers most likely need to supply ``gguf_file`` -- confirm which
    # quantization file is intended.
    load_kwargs = {"gguf_file": gguf_file} if gguf_file else {}
    tokenizer = AutoTokenizer.from_pretrained(hf_model_id, **load_kwargs)
    raw_model = AutoModelForCausalLM.from_pretrained(
        hf_model_id, device_map="auto", **load_kwargs
    )

    # temperature / top_p only influence generation when sampling is enabled;
    # without do_sample=True they are ignored and decoding stays greedy.
    if temperature is not None and temperature > 0:
        generation_kwargs = {
            "do_sample": True,
            "temperature": temperature,
            "top_p": top_p,
        }
    else:
        generation_kwargs = {"do_sample": False}

    pipe = pipeline(
        "text-generation",
        model=raw_model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        return_full_text=False,
        **generation_kwargs,
    )
    model = HuggingFacePipeline(pipeline=pipe)

    agent = Agent(
        model,
        tools=list(tools_dict.values()),
        log_tools=True,
        log_dir="logs",
        system_prompt=prompt,
        checkpointer=checkpointer,
    )
    print("Agent initialized")
    return agent, tools_dict
if __name__ == "__main__":
    # Script entry point: build the agent with a fixed tool selection and
    # serve the demo interface returned by create_demo.
    print("Starting server...")

    # Tools enabled for this launch. LlavaMedTool, XRayPhraseGroundingTool and
    # ChestXRayGeneratorTool are registered in initialize_agent but left out.
    enabled_tool_names = [
        "ImageVisualizerTool",
        "DicomProcessorTool",
        "ChestXRayClassifierTool",
        "ChestXRaySegmentationTool",
        "ChestXRayReportGeneratorTool",
        "XRayVQATool",
    ]

    # Keyword configuration forwarded verbatim to initialize_agent.
    agent_options = {
        "tools_to_use": enabled_tool_names,
        "model_dir": "./model-weights",
        "temp_dir": "temp",
        "device": "cuda",
        "temperature": 0.7,
        "top_p": 0.95,
    }
    agent, tools_dict = initialize_agent(
        "medrax/docs/system_prompts.txt", **agent_options
    )

    demo = create_demo(agent, tools_dict)
    demo.launch(debug=True, ssr_mode=False)