# Mentalchatbot / app.py
import os
import time
from typing import List, Tuple
import pandas as pd
import gradio as gr
# For LLM and embeddings:
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
# For RAG prompting
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
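# Assumed runtime dependencies (install via pip): pandas, gradio,
# langchain-groq, langchain-huggingface, langchain-chroma, sentence-transformers.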
##########################################
# 1. Load & Prepare Mental Health Data #
##########################################
def load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3):
"""
Loads multiple mental health datasets,
limiting the number of rows and columns to reduce resource usage.
"""
dataframes = {}
context_data = []
for key, file in file_paths.items():
try:
df = pd.read_csv(file)
dataframes[key] = df
# Limit rows to reduce size
limited_df = df.head(max_rows_per_file)
# Generate minimal context strings for each row
for i in range(len(limited_df)):
context = f"Dataset: {key} | "
for j in range(min(max_columns, len(limited_df.columns))):
column_name = limited_df.columns[j]
                    cell_value = limited_df.iloc[i, j]
context += f"{column_name}: {cell_value} "
context_data.append(context.strip())
except Exception as e:
print(f"Error loading {file}: {e}")
return context_data
##########################################
# 2. Specify File Paths #
##########################################
file_paths = {
'df1': '1- mental-illnesses-prevalence.csv',
'df2': '2- burden-disease-from-each-mental-illness(1).csv',
'df3': '3- adult-population-covered-in-primary-data-on-the-prevalence-of-major-depression.csv',
'df4': '4- adult-population-covered-in-primary-data-on-the-prevalence-of-mental-illnesses.csv',
'df5': '5- anxiety-disorders-treatment-gap.csv',
'df6': '6- depressive-symptoms-across-us-population.csv',
'df7': '7- number-of-countries-with-primary-data-on-prevalence-of-mental-illnesses-in-the-global-burden-of-disease-study.csv',
'df8': 'train.csv'
}
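# These CSV files are expected to sit next to app.py (e.g. uploaded to the same
# Space repository); files that fail to load are simply skipped by the loader above.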
##########################################
# 3. Generate Context Data #
##########################################
context_data = load_mental_health_data(
file_paths=file_paths,
max_rows_per_file=100,
max_columns=3
)
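# Each entry in context_data is a short string of the form
# "Dataset: <key> | <col1>: <val1> <col2>: <val2> <col3>: <val3>",
# i.e. at most 100 rows x 3 columns per file are embedded below.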
##########################################
# 4. Initialize LLM & Embeddings #
##########################################
# Read the Groq API key from the environment (e.g. a Hugging Face Space secret)
# instead of hard-coding it in the source.
groq_api = os.environ.get("GROQ_API_KEY")
# (If Groq has retired this model ID, swap in a currently supported Llama model.)
llm = ChatGroq(model="llama-3.1-70b-versatile", api_key=groq_api)
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
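# The mxbai-embed-large-v1 weights are fetched from the Hugging Face Hub on the
# first run and, by default, run locally on CPU via sentence-transformers.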
##########################################
# 5. Create Vector Store & Retriever #
##########################################
vectorstore = Chroma(
collection_name="mental_health_store",
embedding_function=embed_model,
persist_directory="./" # Where the index is stored
)
vectorstore.add_texts(context_data)
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
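# k=5: the five most similar context rows are retrieved and injected into the
# prompt below. Note that add_texts() runs on every startup, so the persisted
# collection can accumulate duplicate entries across restarts.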
##########################################
# 6. Define Prompt Template for RAG #
##########################################
template = """You are a professional, empathetic mental health support AI assistant.
Your role is to offer supportive, informative responses while maintaining appropriate
boundaries and encouraging professional help when needed.
Context from mental health research: {context}
User Message: {question}
Please provide a response that:
1. Shows empathy and understanding
2. Provides relevant information based on research data when applicable
3. Encourages professional help when appropriate
4. Offers practical coping strategies when suitable
5. Maintains appropriate boundaries and disclaimers
Note: Always clarify that you're an AI assistant and not a replacement for professional mental health care.
Supportive Response:
"""
rag_prompt = PromptTemplate.from_template(template)
rag_chain = (
{"context": retriever, "question": RunnablePassthrough()}
| rag_prompt
| llm
| StrOutputParser()
)
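# Chain flow: a plain question string goes in; the retriever fills {context},
# RunnablePassthrough forwards the same string as {question}, the prompt is
# sent to the Groq LLM, and StrOutputParser returns the reply as plain text.
# Illustrative usage (hypothetical question):
#   rag_chain.invoke("How common are anxiety disorders worldwide?")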
##########################################
# 7. Processing Incoming Messages #
##########################################
def process_message(message: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
"""
Process each message in a single shot (no streaming)
to avoid the 'Data incompatible with tuples format' error.
"""
if history is None:
history = []
    # Flatten the conversation into one string so the model sees prior turns
    history_str = ""
    for user_msg, assistant_msg in history:
        history_str += f"User: {user_msg}\nAssistant: {assistant_msg}\n"
    # The chain expects a single query string: the retriever uses it to look up
    # research context and RunnablePassthrough forwards it as {question}, so a
    # dict input would break retrieval. Prepend the flattened history to the
    # new message instead.
    chain_input = f"{history_str}User: {message}" if history_str else message
    # Call the RAG chain to get the response
    try:
        answer = rag_chain.invoke(chain_input)
except Exception as e:
print(f"Error generating response: {e}")
answer = (
"I’m sorry, but something went wrong while I was generating a response. "
"If you need immediate support, please reach out to a mental health professional."
)
# Update the conversation history with the user’s question and the assistant’s reply
history.append((message, answer))
return history
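# The (user, assistant) tuples returned above match gr.Chatbot's default
# "tuples" message format in the Gradio versions this app targets.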
##########################################
# 8. Build the Gradio Interface #
##########################################
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
gr.Markdown("""
# Professional Mental Health Support Assistant
**Providing empathetic guidance and insights.**
**Important Note:** This AI assistant offers supportive information based on mental health research data.
It is NOT a replacement for professional mental health care or immediate crisis resources.
If you are in an emergency or crisis situation, please contact emergency services (114 in Rwanda)
or reach out to a trusted healthcare professional immediately.
""")
chatbot = gr.Chatbot(
height=500,
show_copy_button=True,
bubble_full_width=False,
avatar_images=["👤", "🤖"]
)
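    # Note: avatar_images normally expects two image file paths or URLs; plain
    # emoji strings may not render as avatars in every Gradio version.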
with gr.Row():
msg = gr.Textbox(
placeholder="How are you feeling today? Type your question here...",
container=False,
scale=9
)
submit = gr.Button("Send", scale=1, variant="primary")
gr.Examples(
examples=[
"What are common symptoms of depression?",
"How prevalent is anxiety in the global population?",
"What are some coping strategies for stress?",
"Can you tell me about treatment options for mental health issues?",
"How do I know if I should seek professional help?"
],
inputs=msg
)
gr.Markdown("""
### Helpful Resources
- **Emergency Services**: Call your local emergency number (114 in Rwanda).
- **Helplines**: Contact your country’s mental health helpline if you or someone you know needs urgent assistance.
""")
# Link the input and output
msg.submit(
process_message,
[msg, chatbot],
chatbot,
queue=True
)
submit.click(
process_message,
[msg, chatbot],
chatbot,
queue=True
)
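    # Optional (not wired up here): clear the textbox after each send, e.g. by
    # chaining msg.submit(...).then(lambda: "", None, msg).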
##########################################
# 9. Launch the Gradio App #
##########################################
if __name__ == "__main__":
demo.queue(max_size=20).launch(
share=True,
show_error=True,
server_name="0.0.0.0",
server_port=7860
)