# app.py — Professional Mental Health Support Assistant (Gradio + LangChain RAG)
import os
import time
from typing import List, Tuple
import pandas as pd
import gradio as gr
# For LLM and embeddings:
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# For RAG prompting
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
##########################################
# 1. Load & Prepare Mental Health Data #
##########################################
def load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3):
    """Load mental health CSV datasets and flatten rows into context strings.

    Each retained row becomes one string of the form
    ``"Dataset: <key> | <col>: <val> <col>: <val>"`` — the unit of text
    that gets embedded into the vector store.

    Args:
        file_paths: Mapping of dataset key -> CSV file path.
        max_rows_per_file: Rows kept per file, to cap memory/index size.
        max_columns: Number of leading columns included per row.

    Returns:
        list[str]: One context string per retained row. A file that fails
        to load is skipped with a printed message (best-effort startup),
        never raised.
    """
    context_data = []
    for key, file in file_paths.items():
        try:
            df = pd.read_csv(file)
        except Exception as e:
            # Best-effort: a missing or corrupt file must not kill startup.
            print(f"Error loading {file}: {e}")
            continue
        # Limit rows to reduce size.
        limited_df = df.head(max_rows_per_file)
        n_cols = min(max_columns, len(limited_df.columns))
        columns = limited_df.columns[:n_cols]
        for i in range(len(limited_df)):
            # iloc[i, j] does scalar access directly; the original
            # limited_df.iloc[i][j] relied on the deprecated positional
            # fallback of Series.__getitem__ (removed in pandas 3).
            parts = [f"Dataset: {key} |"]
            parts.extend(
                f"{col}: {limited_df.iloc[i, j]}" for j, col in enumerate(columns)
            )
            context_data.append(" ".join(parts))
    return context_data
##########################################
# 2. Specify File Paths #
##########################################
# Dataset key -> CSV path for every corpus that feeds the vector store.
file_paths = dict(
    df1='1- mental-illnesses-prevalence.csv',
    df2='2- burden-disease-from-each-mental-illness(1).csv',
    df3='3- adult-population-covered-in-primary-data-on-the-prevalence-of-major-depression.csv',
    df4='4- adult-population-covered-in-primary-data-on-the-prevalence-of-mental-illnesses.csv',
    df5='5- anxiety-disorders-treatment-gap.csv',
    df6='6- depressive-symptoms-across-us-population.csv',
    df7='7- number-of-countries-with-primary-data-on-prevalence-of-mental-illnesses-in-the-global-burden-of-disease-study.csv',
    df8='train.csv',
)
##########################################
# 3. Generate Context Data #
##########################################
# Flatten every configured dataset into embedding-ready context strings
# (100 rows and 3 columns per file, matching the loader's defaults).
context_data = load_mental_health_data(file_paths, 100, 3)
##########################################
# 4. Initialize LLM & Embeddings #
##########################################
# SECURITY: never commit API keys to source control — the previous key was
# hard-coded here and must be rotated. Read the key from the environment.
groq_api = os.environ.get("GROQ_API_KEY")
if not groq_api:
    raise RuntimeError(
        "GROQ_API_KEY environment variable is not set; "
        "export it before launching the app."
    )
llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api)
# Embedding model used both for indexing and query-time retrieval.
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
##########################################
# 5. Create Vector Store & Retriever #
##########################################
# Chroma vector store persisted to the working directory; it indexes the
# flattened dataset rows using the HuggingFace embedding model above.
vectorstore = Chroma(
    collection_name="mental_health_store",
    embedding_function=embed_model,
    persist_directory="./"  # Where the index is stored
)
# Embed and index every context string generated at startup.
vectorstore.add_texts(context_data)
# The retriever returns the 5 most similar rows for a query string.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
##########################################
# 6. Define Prompt Template for RAG #
##########################################
# Prompt contract for the RAG chain: {context} is filled with retrieved
# dataset rows and {question} with the user's message.
template = """You are a professional, empathetic mental health support AI assistant.
Your role is to offer supportive, informative responses while maintaining appropriate
boundaries and encouraging professional help when needed.
Context from mental health research: {context}
User Message: {question}
Please provide a response that:
1. Shows empathy and understanding
2. Provides relevant information based on research data when applicable
3. Encourages professional help when appropriate
4. Offers practical coping strategies when suitable
5. Maintains appropriate boundaries and disclaimers
Note: Always clarify that you're an AI assistant and not a replacement for professional mental health care.
Supportive Response:
"""
# Input variables ({context}, {question}) are inferred from the braces.
rag_prompt = PromptTemplate.from_template(template)
# LCEL pipeline: the raw input is fanned out — through the retriever to
# fill {context}, and passed through unchanged to fill {question} — then
# formatted, sent to the LLM, and parsed down to a plain string.
# NOTE(review): this chain expects a *string* input; invoking it with a
# dict hands the dict to the retriever. See process_message.
rag_chain = (
    {"context": retriever, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
##########################################
# 7. Processing Incoming Messages #
##########################################
def process_message(message: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Handle one chat turn and return the updated conversation history.

    Args:
        message: The user's new message.
        history: Prior turns as (user, assistant) tuples; may be None on
            the very first turn.

    Returns:
        The history with the new (message, answer) tuple appended — the
        tuples format gr.Chatbot expects.
    """
    if history is None:
        history = []
    try:
        # BUG FIX: the chain's first stage ({"context": retriever, ...})
        # maps the *raw input* through the retriever, so the chain must be
        # invoked with the question string itself. The previous code passed
        # a dict {"context": ..., "question": ...}, which reached the
        # retriever, raised, and made every reply the fallback apology.
        answer = rag_chain.invoke(message)
    except Exception as e:
        # Best-effort fallback so the UI never surfaces a stack trace.
        print(f"Error generating response: {e}")
        answer = (
            "I’m sorry, but something went wrong while I was generating a response. "
            "If you need immediate support, please reach out to a mental health professional."
        )
    # Append the new turn so Gradio re-renders the full conversation.
    history.append((message, answer))
    return history
##########################################
# 8. Build the Gradio Interface #
##########################################
# Gradio UI: chat pane, input row, canned examples, and resource links.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    gr.Markdown("""
# Professional Mental Health Support Assistant
**Providing empathetic guidance and insights.**
**Important Note:** This AI assistant offers supportive information based on mental health research data.
It is NOT a replacement for professional mental health care or immediate crisis resources.
If you are in an emergency or crisis situation, please contact emergency services (114 in Rwanda)
or reach out to a trusted healthcare professional immediately.
""")
    # Conversation pane; its value is a list of (user, assistant) tuples.
    chatbot = gr.Chatbot(
        height=500,
        show_copy_button=True,
        bubble_full_width=False,
        # NOTE(review): gr.Chatbot documents avatar_images as image file
        # paths/URLs; emoji strings may not render in all versions — verify.
        avatar_images=["👤", "🤖"]
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="How are you feeling today? Type your question here...",
            container=False,
            scale=9
        )
        submit = gr.Button("Send", scale=1, variant="primary")
    gr.Examples(
        examples=[
            "What are common symptoms of depression?",
            "How prevalent is anxiety in the global population?",
            "What are some coping strategies for stress?",
            "Can you tell me about treatment options for mental health issues?",
            "How do I know if I should seek professional help?"
        ],
        inputs=msg
    )
    # BUG FIX: the emergency number shown here previously read
    # "144 in the Rwanda", contradicting the 114 shown in the banner above;
    # both now agree.
    gr.Markdown("""
### Helpful Resources
- **Emergency Services**: Call your local emergency number (114 in Rwanda).
- **Helplines**: Contact your country’s mental health helpline if you or someone you know needs urgent assistance.
""")
    # Pressing Enter and clicking Send both run the same handler.
    msg.submit(
        process_message,
        [msg, chatbot],
        chatbot,
        queue=True
    )
    submit.click(
        process_message,
        [msg, chatbot],
        chatbot,
        queue=True
    )
##########################################
# 9. Launch the Gradio App               #
##########################################
if __name__ == "__main__":
    # Queue bounds concurrent requests; share=True exposes a public URL.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(
        share=True,
        show_error=True,
        server_name="0.0.0.0",
        server_port=7860,
    )