Dhruv-Ty commited on
Commit
8af4e75
·
verified ·
1 Parent(s): 070c106

restored to original

Browse files

added the support for the gpt-4o-mini

Files changed (1) hide show
  1. app.py +50 -29
app.py CHANGED
@@ -1,10 +1,13 @@
1
  import os
2
  import warnings
3
  from typing import *
4
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
5
- from langchain.llms import HuggingFacePipeline
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
 
 
 
8
 
9
  from interface import create_demo
10
  from medrax.agent import *
@@ -12,6 +15,8 @@ from medrax.tools import *
12
  from medrax.utils import *
13
 
14
  warnings.filterwarnings("ignore")
 
 
15
 
16
  def initialize_agent(
17
  prompt_file,
@@ -19,11 +24,27 @@ def initialize_agent(
19
  model_dir="./model-weights",
20
  temp_dir="temp",
21
  device="cuda",
 
22
  temperature=0.7,
23
- top_p=0.95
 
24
  ):
25
- """Initialize the MedRAX agent with specified tools and configuration."""
26
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  prompts = load_prompts_from_file(prompt_file)
28
  prompt = prompts["MEDICAL_ASSISTANT"]
29
 
@@ -45,6 +66,7 @@ def initialize_agent(
45
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
46
  }
47
 
 
48
  tools_dict = {}
49
  tools_to_use = tools_to_use or all_tools.keys()
50
  for tool_name in tools_to_use:
@@ -52,24 +74,7 @@ def initialize_agent(
52
  tools_dict[tool_name] = all_tools[tool_name]()
53
 
54
  checkpointer = MemorySaver()
55
-
56
- # Load local Hugging Face model
57
- hf_model_id = model_id = "HuggingFaceH4/zephyr-7b-alpha"
58
- tokenizer = AutoTokenizer.from_pretrained(hf_model_id)
59
- raw_model = AutoModelForCausalLM.from_pretrained(hf_model_id, device_map="auto")
60
-
61
- pipe = pipeline(
62
- "text-generation",
63
- model=raw_model,
64
- tokenizer=tokenizer,
65
- max_new_tokens=512,
66
- temperature=temperature,
67
- top_p=top_p,
68
- return_full_text=False,
69
- )
70
-
71
- model = HuggingFacePipeline(pipeline=pipe)
72
-
73
  agent = Agent(
74
  model,
75
  tools=list(tools_dict.values()),
@@ -82,10 +87,15 @@ def initialize_agent(
82
  print("Agent initialized")
83
  return agent, tools_dict
84
 
85
-
86
  if __name__ == "__main__":
 
 
 
 
87
  print("Starting server...")
88
 
 
 
89
  selected_tools = [
90
  "ImageVisualizerTool",
91
  "DicomProcessorTool",
@@ -98,15 +108,26 @@ if __name__ == "__main__":
98
  # "ChestXRayGeneratorTool",
99
  ]
100
 
 
 
 
 
 
 
 
 
101
  agent, tools_dict = initialize_agent(
102
  "medrax/docs/system_prompts.txt",
103
  tools_to_use=selected_tools,
104
- model_dir="./model-weights",
105
- temp_dir="temp",
106
- device="cuda",
 
107
  temperature=0.7,
108
- top_p=0.95
 
109
  )
110
-
111
  demo = create_demo(agent, tools_dict)
 
 
112
  demo.launch(debug=True, ssr_mode=False)
 
1
  import os
2
  import warnings
3
  from typing import *
4
+ from dotenv import load_dotenv
5
+ from transformers import logging
6
 
7
  from langgraph.checkpoint.memory import MemorySaver
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.checkpoint.memory import MemorySaver
10
+ from langchain_openai import ChatOpenAI
11
 
12
  from interface import create_demo
13
  from medrax.agent import *
 
15
  from medrax.utils import *
16
 
17
  warnings.filterwarnings("ignore")
18
+ logging.set_verbosity_error()
19
+ _ = load_dotenv()
20
 
21
  def initialize_agent(
22
  prompt_file,
 
24
  model_dir="./model-weights",
25
  temp_dir="temp",
26
  device="cuda",
27
+ model="gpt-4o-mini",
28
  temperature=0.7,
29
+ top_p=0.95,
30
+ openai_kwargs={}
31
  ):
32
+ """Initialize the MedRAX agent with specified tools and configuration.
33
+
34
+ Args:
35
+ prompt_file (str): Path to file containing system prompts
36
+ tools_to_use (List[str], optional): List of tool names to initialize. If None, all tools are initialized.
37
+ model_dir (str, optional): Directory containing model weights. Defaults to "/model-weights".
38
+ temp_dir (str, optional): Directory for temporary files. Defaults to "temp".
39
+ device (str, optional): Device to run models on. Defaults to "cuda".
40
+ model (str, optional): Model to use. Defaults to "chatgpt-4o-latest".
41
+ temperature (float, optional): Temperature for the model. Defaults to 0.7.
42
+ top_p (float, optional): Top P for the model. Defaults to 0.95.
43
+ openai_kwargs (dict, optional): Additional keyword arguments for OpenAI API, such as API key and base URL.
44
+
45
+ Returns:
46
+ Tuple[Agent, Dict[str, BaseTool]]: Initialized agent and dictionary of tool instances
47
+ """
48
  prompts = load_prompts_from_file(prompt_file)
49
  prompt = prompts["MEDICAL_ASSISTANT"]
50
 
 
66
  "DicomProcessorTool": lambda: DicomProcessorTool(temp_dir=temp_dir),
67
  }
68
 
69
+ # Initialize only selected tools or all if none specified
70
  tools_dict = {}
71
  tools_to_use = tools_to_use or all_tools.keys()
72
  for tool_name in tools_to_use:
 
74
  tools_dict[tool_name] = all_tools[tool_name]()
75
 
76
  checkpointer = MemorySaver()
77
+ model = ChatOpenAI(model=model, temperature=temperature, top_p=top_p, **openai_kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  agent = Agent(
79
  model,
80
  tools=list(tools_dict.values()),
 
87
  print("Agent initialized")
88
  return agent, tools_dict
89
 
 
90
  if __name__ == "__main__":
91
+ """
92
+ This is the main entry point for the MedRAX application.
93
+ It initializes the agent with the selected tools and creates the demo.
94
+ """
95
  print("Starting server...")
96
 
97
+ # Example: initialize with only specific tools
98
+ # Here three tools are commented out, you can uncomment them to use them
99
  selected_tools = [
100
  "ImageVisualizerTool",
101
  "DicomProcessorTool",
 
108
  # "ChestXRayGeneratorTool",
109
  ]
110
 
111
+ # Collect the ENV variables
112
+ openai_kwargs = {}
113
+ if api_key := os.getenv("OPENAI_API_KEY"):
114
+ openai_kwargs["api_key"] = api_key
115
+
116
+ if base_url := os.getenv("OPENAI_BASE_URL"):
117
+ openai_kwargs["base_url"] = base_url
118
+
119
  agent, tools_dict = initialize_agent(
120
  "medrax/docs/system_prompts.txt",
121
  tools_to_use=selected_tools,
122
+ model_dir="./model-weights", # Change this to the path of the model weights
123
+ temp_dir="temp", # Change this to the path of the temporary directory
124
+ device="cuda", # Change this to the device you want to use
125
+ model="gpt-4o-mini", # Change this to the model you want to use, e.g. gpt-4o-mini
126
  temperature=0.7,
127
+ top_p=0.95,
128
+ openai_kwargs=openai_kwargs
129
  )
 
130
  demo = create_demo(agent, tools_dict)
131
+
132
+ # demo.launch(server_name="0.0.0.0", server_port=8585, share=True)
133
  demo.launch(debug=True, ssr_mode=False)