import os
import warnings
from typing import *
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline  # moved out of langchain.llms in newer LangChain releases
from langgraph.checkpoint.memory import MemorySaver
from interface import create_demo
from medrax.agent import *
from medrax.tools import *
from medrax.utils import *
warnings.filterwarnings("ignore")
def initialize_agent(
    prompt_file: str,
    tools_to_use: Optional[List[str]] = None,
    model_dir: str = "./model-weights",
    temp_dir: str = "temp",
    device: str = "cuda",
    temperature: float = 0.7,
    top_p: float = 0.95,
):
    """Initialize the MedRAX agent with the specified tools and configuration.

    Returns a tuple of (agent, tools_dict), where tools_dict maps each enabled
    tool name to its instantiated tool.
    """
prompts = load_prompts_from_file(prompt_file)
prompt = prompts["MEDICAL_ASSISTANT"]
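    # Map tool names to zero-argument factories so that heavyweight models are
    # only instantiated for the tools actually requested via tools_to_use.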
all_tools = {
"ChestXRayClassifierTool": lambda: ChestXRayClassifierTool(device=device),
"ChestXRaySegmentationTool": lambda: ChestXRaySegmentationTool(device=device),
"LlavaMedTool": lambda: LlavaMedTool(cache_dir=model_dir, device=device, load_in_8bit=True),
"XRayVQATool": lambda: XRayVQATool(cache_dir=model_dir, device=device),
"ChestXRayReportGeneratorTool": lambda: ChestXRayReportGeneratorTool(
cache_dir=model_dir, device=device
),
"XRayPhraseGroundingTool": lambda: XRayPhraseGroundingTool(
cache_dir=model_dir, temp_dir=temp_dir, load_in_8bit=True, device=device
),
"ChestXRayGeneratorTool": lambda: ChestXRayGeneratorTool(
model_path=f"{model_dir}/roentgen", temp_dir=temp_dir, device=device
),
"ImageVisualizerTool": lambda: ImageVisualizerTool(),
"DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
}
tools_dict = {}
tools_to_use = tools_to_use or all_tools.keys()
for tool_name in tools_to_use:
if tool_name in all_tools:
tools_dict[tool_name] = all_tools[tool_name]()
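    # In-memory checkpointer that lets the agent persist conversation state
    # across turns (state is lost when the process exits).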
checkpointer = MemorySaver()
    # Load a local Hugging Face model. Note: GGUF-quantized repos such as
    # TheBloke/Mistral-7B-Instruct-v0.2-GGUF cannot be loaded directly with
    # AutoModelForCausalLM, so the full-precision repo is used instead.
    hf_model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
    raw_model = AutoModelForCausalLM.from_pretrained(hf_model_id, device_map="auto")
pipe = pipeline(
"text-generation",
model=raw_model,
tokenizer=tokenizer,
        max_new_tokens=512,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        return_full_text=False,
)
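    # Wrap the transformers pipeline as a LangChain LLM so the agent can invoke it.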
model = HuggingFacePipeline(pipeline=pipe)
agent = Agent(
model,
tools=list(tools_dict.values()),
log_tools=True,
log_dir="logs",
system_prompt=prompt,
checkpointer=checkpointer,
)
print("Agent initialized")
return agent, tools_dict
if __name__ == "__main__":
print("Starting server...")
selected_tools = [
"ImageVisualizerTool",
"DicomProcessorTool",
"ChestXRayClassifierTool",
"ChestXRaySegmentationTool",
"ChestXRayReportGeneratorTool",
"XRayVQATool",
# "LlavaMedTool",
# "XRayPhraseGroundingTool",
# "ChestXRayGeneratorTool",
]
agent, tools_dict = initialize_agent(
"medrax/docs/system_prompts.txt",
tools_to_use=selected_tools,
model_dir="./model-weights",
temp_dir="temp",
device="cuda",
temperature=0.7,
top_p=0.95
)
demo = create_demo(agent, tools_dict)
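    # ssr_mode=False disables Gradio's server-side rendering, which can be
    # flaky in some hosted environments.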
demo.launch(debug=True, ssr_mode=False)