Santhosh1705kumar committed on
Commit decc13a · verified · 1 Parent(s): fe48a9c

Update app.py

Files changed (1)
  1. app.py +112 -76
app.py CHANGED
@@ -1,88 +1,124 @@
  import gradio as gr
- import torch
- from transformers import AutoModelForCausalLM, AutoTokenizer
  import pandas as pd

- # Load the symptom dataset once
- df = pd.read_csv("enhanced_symptom_tree_with_measures.csv")

- # Load model and tokenizer once to optimize performance
- def load_model():
-     model_name = "microsoft/phi-2"  # Lightweight model
-     tokenizer = AutoTokenizer.from_pretrained(model_name)
-     model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
-     model.eval()  # Set model to evaluation mode
-     return tokenizer, model

- tokenizer, model = load_model()

- # Global variable to track last detected symptom
- last_detected_symptom = None

- # Function to find symptom match
- def find_symptom_match(user_input):
-     match = df[df["Primary Symptom"].str.lower().str.contains(user_input.lower(), na=False)]
-     if not match.empty:
-         symptom = match.iloc[0]
-         response = f"It seems like you're experiencing {symptom['Primary Symptom']}. "
-         if pd.notna(symptom["Follow-up Question"]):
-             response += f"{symptom['Follow-up Question']} "
-         response += f"\nPossible conditions: {symptom['Possible Diseases']} \n"
-         response += f"Recommended measures: {symptom['Recommended Measures']}"
-         return symptom['Primary Symptom'], response  # Return symptom name and response
-     return None, None

- # Main chatbot response function
- def chatbot_response(user_input, history):
-     global last_detected_symptom  # Maintain previous symptom context

-     if not user_input.strip():
-         return history, ""

-     # Step 1: If it's a follow-up response, continue from the last known symptom
-     if last_detected_symptom:
-         match = df[df["Primary Symptom"].str.lower() == last_detected_symptom.lower()]
-         if not match.empty:
-             follow_up_options = match.iloc[0]["Follow-up Question"]
-             if pd.notna(follow_up_options) and user_input.lower() in follow_up_options.lower():
-                 response = f"Got it. Based on that, possible conditions: {match.iloc[0]['Possible Diseases']} \nRecommended: {match.iloc[0]['Recommended Measures']}"
-                 last_detected_symptom = None  # Reset symptom tracking
-                 history.append((user_input, response))
-                 return history, ""

-     # Step 2: Otherwise, check for a new symptom
-     detected_symptom, symptom_info = find_symptom_match(user_input)
-     if symptom_info:
-         last_detected_symptom = detected_symptom  # Store new symptom for next response
-         response = symptom_info
-     else:
-         # Step 3: Use LLM as a fallback if no match is found
-         prompt = f"User: {user_input}\nChatbot:"
-         inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-         with torch.no_grad():
-             outputs = model.generate(**inputs, max_length=150, pad_token_id=tokenizer.eos_token_id)
-         response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("Chatbot:")[-1].strip()

-     history.append((user_input, response))
-     return history, ""

- # Function to clear chat
- def clear_chat():
-     global last_detected_symptom
-     last_detected_symptom = None  # Reset symptom tracking
-     return [], ""

  # Gradio UI
- with gr.Blocks(theme="compact") as demo:
-     gr.Markdown("### Symptom Chatbot 🏥")
-     chatbot = gr.Chatbot()
-     user_input = gr.Textbox(placeholder="Type your symptom and get advice...", interactive=True)
-     submit = gr.Button("Send")
-     clear = gr.Button("Clear Chat")

-     submit.click(chatbot_response, [user_input, chatbot], [chatbot, user_input])
-     clear.click(clear_chat, [], [chatbot, user_input])

- # Launch app in Hugging Face Spaces environment
- if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
  import pandas as pd
+ from collections import defaultdict
+ from transformers import pipeline

+ # Zero-shot classifier used to normalize free-text input onto the predefined symptom labels.
+ # (Zero-shot classification requires an NLI-style model; microsoft/phi-2 is a causal LM without
+ # a classification head, so an NLI model, facebook/bart-large-mnli, is used here instead.)
+ symptom_normalizer = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
 
+ # Predefined symptoms
+ symptom_data = {
+     "Shortness of breath": {
+         "questions": [
+             "Do you also have chest pain?",
+             "Do you feel fatigued often?",
+             "Have you noticed swelling in your legs?"
+         ],
+         "diseases": ["Atelectasis", "Emphysema", "Edema"],
+         "weights_yes": [30, 30, 40],
+         "weights_no": [10, 20, 30]
+     },
+     "Persistent cough": {
+         "questions": [
+             "Is your cough dry or with mucus?",
+             "Do you experience fever?",
+             "Do you have difficulty breathing?"
+         ],
+         "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
+         "weights_yes": [35, 30, 35],
+         "weights_no": [10, 15, 20]
+     },
+     "Sharp chest pain": {
+         "questions": [
+             "Does it worsen with deep breaths?",
+             "Do you feel lightheaded?",
+             "Have you had recent trauma or surgery?"
+         ],
+         "diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
+         "weights_yes": [40, 30, 30],
+         "weights_no": [15, 20, 25]
+     },
+     "Fatigue & swelling": {
+         "questions": [
+             "Do you feel breathless when lying down?",
+             "Have you gained weight suddenly?",
+             "Do you experience irregular heartbeat?"
+         ],
+         "diseases": ["Edema", "Cardiomegaly"],
+         "weights_yes": [50, 30, 20],
+         "weights_no": [20, 15, 15]
+     },
+     "Chronic wheezing": {
+         "questions": [
+             "Do you have a history of smoking?",
+             "Do you feel tightness in your chest?",
+             "Do you have frequent lung infections?"
+         ],
+         "diseases": ["Emphysema", "Fibrosis"],
+         "weights_yes": [40, 30, 30],
+         "weights_no": [15, 25, 20]
+     }
+ }
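
+ # For each follow-up answer, weights_yes[i] (answer "yes") or weights_no[i] (any other answer)
+ # is added to the running score of diseases[i]; the disease with the highest total is reported.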

+ # Global variables to track user state
+ user_state = {}

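+ # normalize_symptom maps arbitrary user phrasing onto one of the predefined symptom keys,
+ # using the zero-shot classifier loaded above; the top-scoring label is returned.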
+ def normalize_symptom(user_input):
+     labels = list(symptom_data.keys())
+     predictions = symptom_normalizer(user_input, candidate_labels=labels)
+     return predictions["labels"][0]  # labels are sorted by score, highest first
 
+ def chatbot(user_input):
+     if "state" not in user_state:
+         user_state["state"] = "greet"

+     if user_state["state"] == "greet":
+         user_state["state"] = "ask_symptom"
+         return "Hello! I'm a medical AI assistant. Please describe your primary symptom."

+     elif user_state["state"] == "ask_symptom":
+         normalized_symptom = normalize_symptom(user_input)
+         if normalized_symptom not in symptom_data:
+             return "I don't recognize that symptom. Please enter one of these: " + ", ".join(symptom_data.keys())
+         user_state["symptom"] = normalized_symptom
+         user_state["state"] = "ask_duration"
+         return "How long have you been experiencing this symptom? (Less than a week / More than a week)"

+     elif user_state["state"] == "ask_duration":
+         if user_input.lower() == "less than a week":
+             user_state.clear()
+             return "It might be a temporary issue. Please monitor your symptoms and consult a doctor if they persist."
+         elif user_input.lower() == "more than a week":
+             user_state["state"] = "follow_up"
+             user_state["current_question"] = 0
+             user_state["disease_scores"] = defaultdict(int)
+             return symptom_data[user_state["symptom"]]["questions"][0]
+         else:
+             return "Please respond with 'Less than a week' or 'More than a week'."

+     elif user_state["state"] == "follow_up":
+         symptom = user_state["symptom"]
+         question_index = user_state["current_question"]

+         # Update probabilities
+         if user_input.lower() == "yes":
+             for i, disease in enumerate(symptom_data[symptom]["diseases"]):
+                 user_state["disease_scores"][disease] += symptom_data[symptom]["weights_yes"][i]
+         else:
+             for i, disease in enumerate(symptom_data[symptom]["diseases"]):
+                 user_state["disease_scores"][disease] += symptom_data[symptom]["weights_no"][i]

+         # Move to the next question or finish
+         user_state["current_question"] += 1
+         if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
+             return symptom_data[symptom]["questions"][user_state["current_question"]]

+         # Final diagnosis
+         probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
+         user_state.clear()
+         return f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor for confirmation."

  # Gradio UI
+ demo = gr.Interface(fn=chatbot, inputs=gr.Textbox(placeholder="Enter your response..."), outputs="text", title="Symptom Chatbot")

+ demo.launch()
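
A minimal sketch of how the updated normalization step is expected to behave (the input sentence is a hypothetical example):

    # Hypothetical check of the zero-shot symptom normalization
    candidates = list(symptom_data.keys())
    result = symptom_normalizer("I can't catch my breath when I climb stairs", candidate_labels=candidates)
    print(result["labels"][0])  # expected to be "Shortness of breath", the highest-scoring label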