Spaces:
Sleeping
Sleeping
import os | |
import time | |
from typing import List, Tuple | |
import pandas as pd | |
import gradio as gr | |
# For LLM and embeddings: | |
from langchain_groq import ChatGroq | |
from langchain_huggingface import HuggingFaceEmbeddings | |
from langchain_chroma import Chroma | |
# For RAG prompting | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_core.runnables import RunnablePassthrough | |
##########################################
#  1. Load & Prepare Mental Health Data  #
##########################################
def load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3):
    """Build short text snippets from the first rows/columns of several CSVs.

    Only the first ``max_rows_per_file`` rows of each CSV are parsed, and only
    its first ``max_columns`` columns are used, to keep memory and embedding
    cost low.

    Args:
        file_paths: Mapping of dataset label -> CSV path. The label is
            prefixed to every snippet built from that file.
        max_rows_per_file: Maximum number of data rows read from each file
            (a non-negative int, or None for all rows).
        max_columns: Maximum number of leading columns included per row.

    Returns:
        A list of strings of the form
        ``"Dataset: <label> | <col>: <value> <col>: <value>"``, one per row,
        in file order. Files that cannot be read are reported on stdout and
        skipped.
    """
    context_data = []
    for key, file in file_paths.items():
        try:
            # nrows stops the parser after the rows we keep, instead of
            # loading the whole file and then discarding most of it.
            df = pd.read_csv(file, nrows=max_rows_per_file)
        except Exception as e:  # best-effort: report and skip unreadable files
            print(f"Error loading {file}: {e}")
            continue
        # Clamp to [0, number of columns], matching the original min() bound.
        n_cols = max(0, min(max_columns, len(df.columns)))
        columns = df.columns[:n_cols]
        for row_pos in range(len(df)):
            # iat reads each cell with its own column dtype. Going through
            # df.iloc[row] first builds a row Series that upcasts ints to
            # floats whenever another column is float (2019 -> "2019.0").
            fields = " ".join(
                f"{col}: {df.iat[row_pos, col_pos]}"
                for col_pos, col in enumerate(columns)
            )
            context_data.append(f"Dataset: {key} | {fields}".strip())
    return context_data
##########################################
#         2. Specify File Paths          #
##########################################
# Dataset label -> CSV file name. The label is prefixed to every text
# snippet generated from that file.
file_paths = dict(
    df1='1- mental-illnesses-prevalence.csv',
    df2='2- burden-disease-from-each-mental-illness(1).csv',
    df3='3- adult-population-covered-in-primary-data-on-the-prevalence-of-major-depression.csv',
    df4='4- adult-population-covered-in-primary-data-on-the-prevalence-of-mental-illnesses.csv',
    df5='5- anxiety-disorders-treatment-gap.csv',
    df6='6- depressive-symptoms-across-us-population.csv',
    df7='7- number-of-countries-with-primary-data-on-prevalence-of-mental-illnesses-in-the-global-burden-of-disease-study.csv',
    df8='train.csv',
)
##########################################
#        3. Generate Context Data        #
##########################################
# One text snippet per CSV row: at most 100 rows and 3 columns per file.
context_data = load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3)
##########################################
#     4. Initialize LLM & Embeddings     #
##########################################
# The Groq API key comes from the environment (e.g. a Hugging Face Space
# secret) rather than being committed to source control. The key that was
# previously hard-coded here is public and must be revoked.
groq_api = os.environ.get("GROQ_API_KEY")
if not groq_api:
    raise RuntimeError(
        "GROQ_API_KEY is not set. Export it (or add it as a Space secret) "
        "before starting the app."
    )
# NOTE(review): llama-3.1-70b-versatile may have been retired by Groq;
# set GROQ_MODEL to a currently served model if requests fail.
llm = ChatGroq(
    model=os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile"),
    api_key=groq_api,
)
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")
##########################################
#   5. Create Vector Store & Retriever   #
##########################################
vectorstore = Chroma(
    collection_name="mental_health_store",
    embedding_function=embed_model,
    persist_directory="./",  # Where the index is stored
)
# The collection is persisted, so it survives restarts. With random ids,
# every start appended another full copy of each snippet, filling the top-k
# results with duplicates. Stable, position-based ids make re-adding
# overwrite the same entries instead. Entries added earlier under random
# ids remain until the persist directory is cleared.
if context_data:  # nothing loaded -> don't send Chroma an empty batch
    vectorstore.add_texts(
        context_data,
        ids=[f"mental_health_{i}" for i in range(len(context_data))],
    )
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
##########################################
#   6. Define Prompt Template for RAG    #
##########################################
template = """You are a professional, empathetic mental health support AI assistant.
Your role is to offer supportive, informative responses while maintaining appropriate
boundaries and encouraging professional help when needed.
Context from mental health research: {context}
User Message: {question}
Please provide a response that:
1. Shows empathy and understanding
2. Provides relevant information based on research data when applicable
3. Encourages professional help when appropriate
4. Offers practical coping strategies when suitable
5. Maintains appropriate boundaries and disclaimers
Note: Always clarify that you're an AI assistant and not a replacement for professional mental health care.
Supportive Response:
"""
rag_prompt = PromptTemplate.from_template(template)


def _format_docs(docs):
    """Join the retrieved documents' text, one per line, for {context}."""
    return "\n".join(doc.page_content for doc in docs)


# Input: the user's question as a plain string. It is sent to the retriever
# (whose documents are flattened to text, not Document reprs) and passed
# through unchanged as {question}.
rag_chain = (
    {"context": retriever | _format_docs, "question": RunnablePassthrough()}
    | rag_prompt
    | llm
    | StrOutputParser()
)
##########################################
#    7. Processing Incoming Messages     #
##########################################
def process_message(message: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Answer one chat message and append the exchange to the history.

    The message is answered in a single (non-streaming) call so the Chatbot
    receives complete ``(user, assistant)`` tuples, avoiding the
    'Data incompatible with tuples format' error.

    Args:
        message: The user's latest message.
        history: Prior ``(user, assistant)`` pairs from the Chatbot, or None.

    Returns:
        The same history list (created if None) with ``(message, answer)``
        appended. Blank messages are ignored and the history is returned
        unchanged. If the chain raises, a fixed apology is used as the answer.
    """
    if history is None:
        history = []
    if not (message and message.strip()):
        return history
    # rag_chain expects the question as a plain string: it is used as the
    # retriever's search query and passed through as {question}. Passing a
    # dict here (as before) gave the retriever a non-text query, so every
    # call failed. NOTE: earlier turns in `history` are not sent to the
    # model; the prompt has no slot for conversation history.
    try:
        answer = rag_chain.invoke(message)
    except Exception as e:
        print(f"Error generating response: {e}")
        answer = (
            "I’m sorry, but something went wrong while I was generating a response. "
            "If you need immediate support, please reach out to a mental health professional."
        )
    # Record the user's question and the assistant's reply.
    history.append((message, answer))
    return history
##########################################
#     8. Build the Gradio Interface      #
##########################################
# NOTE(review): both notices below advertise 114 (they previously disagreed:
# 114 vs 144). Verify this is the right number to give for emergencies in
# Rwanda before relying on it.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    gr.Markdown("""
    # Professional Mental Health Support Assistant

    **Providing empathetic guidance and insights.**

    **Important Note:** This AI assistant offers supportive information based on mental health research data.
    It is NOT a replacement for professional mental health care or immediate crisis resources.
    If you are in an emergency or crisis situation, please contact emergency services (114 in Rwanda)
    or reach out to a trusted healthcare professional immediately.
    """)
    chatbot = gr.Chatbot(
        height=500,
        show_copy_button=True,
        bubble_full_width=False,
        # NOTE(review): Gradio documents avatar_images as a pair of image
        # file paths/URLs; confirm that emoji strings render as intended.
        avatar_images=["👤", "🤖"],
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="How are you feeling today? Type your question here...",
            container=False,
            scale=9,
        )
        submit = gr.Button("Send", scale=1, variant="primary")
    gr.Examples(
        examples=[
            "What are common symptoms of depression?",
            "How prevalent is anxiety in the global population?",
            "What are some coping strategies for stress?",
            "Can you tell me about treatment options for mental health issues?",
            "How do I know if I should seek professional help?",
        ],
        inputs=msg,
    )
    gr.Markdown("""
    ### Helpful Resources

    - **Emergency Services**: Call your local emergency number (114 in Rwanda).
    - **Helplines**: Contact your country’s mental health helpline if you or someone you know needs urgent assistance.
    """)
    # Enter and the Send button both answer the message and update the
    # chat; the follow-up step clears the textbox for the next message.
    msg.submit(process_message, [msg, chatbot], chatbot, queue=True).then(
        lambda: "", None, msg
    )
    submit.click(process_message, [msg, chatbot], chatbot, queue=True).then(
        lambda: "", None, msg
    )
##########################################
#        9. Launch the Gradio App        #
##########################################
if __name__ == "__main__":
    # Hold at most 20 pending events in the queue, then serve on every
    # interface at port 7860 with share=True and errors shown in the UI.
    queued_app = demo.queue(max_size=20)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )