File size: 7,990 Bytes
ff5dbd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f25019
ff5dbd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49e8b23
 
 
 
 
 
 
 
 
 
 
ff5dbd8
 
 
 
49e8b23
ff5dbd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
379afa7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
import hashlib
import os
import time
from typing import List, Tuple

import pandas as pd
import gradio as gr

# For LLM and embeddings:
from langchain_groq import ChatGroq
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma

# For RAG prompting
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough

##########################################
# 1. Load & Prepare Mental Health Data   #
##########################################
def load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3):
    """
    Build short text snippets from several mental health CSV datasets.

    Only the first ``max_rows_per_file`` rows and the first ``max_columns``
    columns of each file are turned into text, to keep the number and size of
    snippets small. Each row becomes one string of the form
    ``"Dataset: <key> | <col1>: <val1> <col2>: <val2> ..."``, with every
    value formatted using its own column's dtype.

    Args:
        file_paths: Mapping of dataset key -> CSV path. The key is used as the
            dataset label in every snippet.
        max_rows_per_file: Number of leading rows kept from each file.
        max_columns: Number of leading columns kept from each row.

    Returns:
        A list of snippet strings, in file order then row order. Files that
        fail to load are reported on stdout and skipped.
    """
    context_data = []

    for key, file in file_paths.items():
        try:
            df = pd.read_csv(file)
        except Exception as e:
            # Best effort: one unreadable file must not stop the others loading.
            print(f"Error loading {file}: {e}")
            continue

        # Positional slice of the leading rows/columns (same rows as df.head(n));
        # a negative column limit selects no columns, as before.
        subset = df.iloc[:max_rows_per_file, :max(max_columns, 0)]
        columns = list(subset.columns)

        # itertuples yields each value with its column's own dtype (iloc[i]
        # upcasts a whole row, e.g. int years to floats) and avoids positional
        # ``row[j]`` lookups on a label-indexed Series, which pandas 2 deprecates.
        # The index is kept (and discarded) so a row is still produced when no
        # columns are selected.
        for _index, *values in subset.itertuples(name=None):
            fields = " ".join(f"{name}: {value}" for name, value in zip(columns, values))
            context_data.append(f"Dataset: {key} | {fields}".strip())

    return context_data

##########################################
# 2. Specify File Paths                  #
##########################################
# Dataset key -> CSV path (relative to the working directory). Keys are
# numbered df1..df8 in list order, matching the numeric file-name prefixes;
# train.csv is the eighth dataset.
file_paths = {
    f"df{number}": csv_name
    for number, csv_name in enumerate(
        [
            "1- mental-illnesses-prevalence.csv",
            "2- burden-disease-from-each-mental-illness(1).csv",
            "3- adult-population-covered-in-primary-data-on-the-prevalence-of-major-depression.csv",
            "4- adult-population-covered-in-primary-data-on-the-prevalence-of-mental-illnesses.csv",
            "5- anxiety-disorders-treatment-gap.csv",
            "6- depressive-symptoms-across-us-population.csv",
            "7- number-of-countries-with-primary-data-on-prevalence-of-mental-illnesses-in-the-global-burden-of-disease-study.csv",
            "train.csv",
        ],
        start=1,
    )
}

##########################################
# 3. Generate Context Data               #
##########################################
# Retrieval corpus: one text snippet per row, capped at 100 rows and the first
# 3 columns of every dataset listed in file_paths.
context_data = load_mental_health_data(file_paths, max_rows_per_file=100, max_columns=3)

##########################################
# 4. Initialize LLM & Embeddings         #
##########################################
# The Groq API key comes from the environment instead of being hard-coded:
# a key committed to source is readable by anyone with access to the repo.
# Any key that was ever committed here should be treated as leaked and revoked.
groq_api = os.environ.get("GROQ_API_KEY")
if not groq_api:
    raise RuntimeError(
        "GROQ_API_KEY is not set; export your Groq API key before starting the app."
    )

# NOTE(review): hosted model IDs get retired over time; GROQ_MODEL overrides the
# default without a code change — confirm the default is still served by Groq.
llm = ChatGroq(
    model=os.environ.get("GROQ_MODEL", "llama-3.1-70b-versatile"),
    api_key=groq_api,
)

# Embedding model used to index and query the vector store.
embed_model = HuggingFaceEmbeddings(model_name="mixedbread-ai/mxbai-embed-large-v1")

##########################################
# 5. Create Vector Store & Retriever     #
##########################################
vectorstore = Chroma(
    collection_name="mental_health_store",
    embedding_function=embed_model,
    persist_directory="./",  # Where the index is stored (the current working directory)
)

# The collection is persisted, so on every start it may already hold the rows
# written by a previous run. Keying each snippet by a hash of its text makes
# add_texts (which upserts by id) overwrite the existing entry instead of
# appending a duplicate that would crowd the top-k results.
# dict.fromkeys drops repeated snippets first, because one write batch may not
# contain the same id twice.
# NOTE(review): an index directory built before this change may still contain
# duplicates under random ids; clear it once to start clean.
if context_data:
    _unique_texts = list(dict.fromkeys(context_data))
    vectorstore.add_texts(
        _unique_texts,
        ids=[hashlib.sha256(text.encode("utf-8")).hexdigest() for text in _unique_texts],
    )

# Return the 5 most similar snippets for each query.
retriever = vectorstore.as_retriever(search_kwargs={"k": 5})

##########################################
# 6. Define Prompt Template for RAG      #
##########################################
# Prompt slots: {context} = retrieved research snippets, {history} = the
# flattened conversation so far, {question} = the user's latest message.
template = """You are a professional, empathetic mental health support AI assistant. 
Your role is to offer supportive, informative responses while maintaining appropriate 
boundaries and encouraging professional help when needed.

Context from mental health research: {context}

Conversation so far:
{history}

User Message: {question}

Please provide a response that:
1. Shows empathy and understanding
2. Provides relevant information based on research data when applicable
3. Encourages professional help when appropriate
4. Offers practical coping strategies when suitable
5. Maintains appropriate boundaries and disclaimers

Note: Always clarify that you're an AI assistant and not a replacement for professional mental health care.

Supportive Response:
"""


def _chain_question(inputs):
    """Return the user question from a chain input (a plain string, or a dict with "question")."""
    if isinstance(inputs, dict):
        return inputs.get("question", "")
    return inputs


def _chain_history(inputs):
    """Return the conversation transcript from a chain input, or "" if none was given.

    The transcript is read from "history"; "context" is also accepted so that
    callers which pass the conversation under that key keep working.
    """
    if isinstance(inputs, dict):
        return inputs.get("history", inputs.get("context", "")) or ""
    return ""


def _format_docs(docs):
    """Join the text of retrieved documents, one per line, for the prompt."""
    return "\n".join(doc.page_content for doc in docs)


rag_prompt = PromptTemplate.from_template(template)

# rag_chain.invoke accepts either a question string or a dict
# {"question": ..., "history": ...}. Only the question text is sent to the
# retriever: it embeds its input as a string query, so passing the whole dict
# would fail on every call.
rag_chain = (
    {
        "context": _chain_question | retriever | _format_docs,
        "history": _chain_history,
        "question": _chain_question,
    }
    | rag_prompt
    | llm
    | StrOutputParser()
)

##########################################
# 7. Processing Incoming Messages        #
##########################################
def process_message(message: str, history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """
    Generate one assistant reply and return the updated chat history.

    The reply is produced in a single call (no streaming) to avoid the
    'Data incompatible with tuples format' error, and is appended as a
    (user, assistant) pair.

    Args:
        message: The text the user just submitted.
        history: Prior (user, assistant) pairs; None is treated as empty.

    Returns:
        A new list holding the prior pairs plus (message, reply). A blank or
        whitespace-only message returns the prior pairs unchanged without
        calling the model.
    """
    # Work on a copy so the caller's list is never mutated in place.
    history = list(history or [])

    # Don't send empty submissions to the model.
    if not (message and message.strip()):
        return history

    # Flatten prior turns into a transcript; a missing reply renders as empty
    # text rather than "None".
    history_str = "".join(
        f"User: {user_msg}\nAssistant: {assistant_msg or ''}\n"
        for user_msg, assistant_msg in history
    )

    # The transcript goes in "history"; retrieved research fills the prompt's
    # own context slot.
    chain_input = {
        "question": message,
        "history": history_str,
    }

    # Call the RAG chain to get the response
    try:
        answer = rag_chain.invoke(chain_input)
    except Exception as e:
        # UI boundary: log and show a safe fallback instead of a stack trace.
        print(f"Error generating response: {e}")
        answer = (
            "I’m sorry, but something went wrong while I was generating a response. "
            "If you need immediate support, please reach out to a mental health professional."
        )

    history.append((message, answer))
    return history


##########################################
# 8. Build the Gradio Interface          #
##########################################
# Local number shown in both safety notices, defined once so they cannot disagree.
# NOTE(review): the notices previously said 114 in one place and 144 in the other;
# 114 is used here — confirm it is the correct Rwandan number before release.
EMERGENCY_NUMBER_RWANDA = "114"

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="purple")) as demo:
    # Header and safety disclaimer shown above the chat.
    gr.Markdown(f"""
    # Professional Mental Health Support Assistant
    **Providing empathetic guidance and insights.**

    **Important Note:** This AI assistant offers supportive information based on mental health research data.
    It is NOT a replacement for professional mental health care or immediate crisis resources. 
    If you are in an emergency or crisis situation, please contact emergency services ({EMERGENCY_NUMBER_RWANDA} in Rwanda) 
    or reach out to a trusted healthcare professional immediately.
    """)

    # Conversation display; process_message returns its (user, assistant) pairs.
    chatbot = gr.Chatbot(
        height=500,
        show_copy_button=True,
        bubble_full_width=False,
        # NOTE(review): Gradio documents avatar_images as image paths/URLs;
        # confirm these emoji strings render as intended.
        avatar_images=["👤", "🤖"],
    )

    with gr.Row():
        msg = gr.Textbox(
            placeholder="How are you feeling today? Type your question here...",
            container=False,
            scale=9,
        )
        submit = gr.Button("Send", scale=1, variant="primary")

    # Clicking an example copies it into the input textbox.
    gr.Examples(
        examples=[
            "What are common symptoms of depression?",
            "How prevalent is anxiety in the global population?",
            "What are some coping strategies for stress?",
            "Can you tell me about treatment options for mental health issues?",
            "How do I know if I should seek professional help?",
        ],
        inputs=msg,
    )

    gr.Markdown(f"""
    ### Helpful Resources
    - **Emergency Services**: Call your local emergency number ({EMERGENCY_NUMBER_RWANDA} in Rwanda).
    - **Helplines**: Contact your country’s mental health helpline if you or someone you know needs urgent assistance.
    """)

    # Enter in the textbox and the Send button run the same handler. Once the
    # reply is shown, the textbox is cleared so the question isn't resent by accident.
    for trigger in (msg.submit, submit.click):
        trigger(
            process_message,
            [msg, chatbot],
            chatbot,
            queue=True,
        ).then(lambda: "", None, msg, queue=False)

##########################################
# 9. Launch the Gradio App               #
##########################################
if __name__ == "__main__":
    # Bounded request queue (at most 20 pending events). The port defaults to
    # 7860 but honours GRADIO_SERVER_PORT, which an explicit hard-coded
    # server_port would otherwise override.
    demo.queue(max_size=20).launch(
        share=True,  # also create a public share link
        show_error=True,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )