import sys
import asyncio
from pathlib import Path
from datetime import datetime
from typing import Any, Tuple

import gradio as gr

# Add parent directory to path so the core/ui packages resolve when run directly
sys.path.append(str(Path(__file__).parent.parent))

from core.ai_engine import OptimizedGazaRAGSystem
from core.utils.logger import logger
from core.utils.translations import translations
from scenario_builder import ScenarioBuilder
from ui.components import (
    get_custom_css,
    create_header_section,
    create_query_input_section,
    create_response_output_section,
    create_quick_access_section,
    create_example_scenarios,
    gradio_user_selector,
    gradio_sidebar_controls,
    gradio_show_response,
)

# Global system instances (lazily initialized)
optimized_rag_system = None
scenario_builder = None
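# Illustrative only: the quiz helpers below expect each question dict to carry
# these keys (key names taken from the code); the medical content of this sample
# is a made-up placeholder, not output captured from ScenarioBuilder.
_EXAMPLE_QUESTION: dict[str, Any] = {
    "question": "What is the first step when treating a minor burn?",
    "options": {
        "A": "Apply ice directly to the skin",
        "B": "Cool the area under cool running water",
        "C": "Cover the burn with ointment immediately",
    },
    "correct_answer": "B",
    "feedback": "Cooling with running water limits further tissue damage.",
}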
def build_dynamic_mcq_ui(scenario_data: list[dict[str, Any]]):
    """Build a standalone Blocks app that renders one Radio per question."""
    with gr.Blocks() as mcq_interface:
        gr.Markdown("## 🧪 Interactive Scenario Quiz")
        radio_inputs = []
        for idx, q in enumerate(scenario_data):
            gr.Markdown(f"**Q{idx + 1}: {q['question']}**")
            choices = [f"{opt}: {txt}" for opt, txt in q["options"].items()]
            radio = gr.Radio(choices, label=None, type="index")
            radio_inputs.append(radio)
        submit = gr.Button("✅ Submit")
        output = gr.Markdown()

        def evaluate(*user_indices):
            results = []
            for i, selected_index in enumerate(user_indices):
                correct_key = scenario_data[i]["correct_answer"]
                correct_label = f"{correct_key}: {scenario_data[i]['options'][correct_key]}"
                user_key = list(scenario_data[i]["options"].keys())[selected_index] if selected_index is not None else "N/A"
                user_label = f"{user_key}: {scenario_data[i]['options'].get(user_key, 'No answer')}"
                correct = user_key == correct_key
                results.append(
                    f"""
### Question {i + 1}
**Your answer:** {user_label}
**Correct answer:** {correct_label}
{"✅ Correct!" if correct else "❌ Incorrect."}
🧠 {scenario_data[i]['feedback']}
""")
            return "\n".join(results)

        submit.click(fn=evaluate, inputs=radio_inputs, outputs=output)
    return mcq_interface
def generate_scenario_data(query: str, num_questions: int = 5):
    """Generate MCQ data for a topic; mirrors get_mcq_data() below."""
    global scenario_builder
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    system = initialize_optimized_system()
    if not scenario_builder:
        scenario_builder = ScenarioBuilder(system)
    try:
        result = loop.run_until_complete(scenario_builder.create_scenario_from_query(query, num_questions))
    finally:
        loop.close()
    if "error" in result:
        return {"error": result["error"], "details": result.get("details", "")}
    return result["questions"]
def render_mcq_quiz(container: gr.Group, questions: list[dict]):
    """Render quiz widgets inside an existing container.

    Gradio layout blocks have no clear()/append() API, so the components are
    created by re-entering the container context during Blocks construction.
    """
    with container:
        gr.Markdown("## 🧪 Interactive Quiz")
        radios = []
        for idx, q in enumerate(questions):
            gr.Markdown(f"**Q{idx + 1}: {q['question']}**")
            options = [f"{k}: {v}" for k, v in q["options"].items()]
            radio = gr.Radio(choices=options, label=None, type="index")
            radios.append(radio)
        submit = gr.Button("✅ Submit")
        result_box = gr.Markdown()
        submit.click(
            fn=lambda *selected: format_results(questions, selected),
            inputs=radios,
            outputs=result_box,
        )
def format_results(questions, selections):
    results = []
    for i, selected_index in enumerate(selections):
        q = questions[i]
        correct_key = q["correct_answer"]
        correct_label = f"{correct_key}: {q['options'][correct_key]}"
        user_key = list(q["options"].keys())[selected_index] if selected_index is not None else "N/A"
        user_label = f"{user_key}: {q['options'].get(user_key, 'No answer')}"
        is_correct = user_key == correct_key
        results.append(f"""
### Q{i + 1}
**Your answer:** {user_label}
**Correct answer:** {correct_label}
{"✅ Correct!" if is_correct else "❌ Incorrect."}
🧠 {q['feedback']}
""")
    return "\n---\n".join(results)
def get_mcq_data(query: str, num_questions: int = 5):
    """Run the async scenario builder synchronously and return its question list."""
    global scenario_builder
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    system = initialize_optimized_system()
    if not scenario_builder:
        scenario_builder = ScenarioBuilder(system)
    try:
        result = loop.run_until_complete(scenario_builder.create_scenario_from_query(query, num_questions))
    finally:
        loop.close()
    if "error" in result:
        return {"error": result["error"], "details": result.get("details", "")}
    return result["questions"]
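# Minimal standalone usage sketch (assumption: for quick manual testing outside
# the main interface, where build_dynamic_mcq_ui is not otherwise wired in):
#
#   questions = get_mcq_data("burns", num_questions=3)
#   if isinstance(questions, list):
#       build_dynamic_mcq_ui(questions).launch()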
def initialize_optimized_system(vector_store_dir: str = "./vector_store"):
    """Create and initialize the RAG system (and scenario builder) once."""
    global optimized_rag_system, scenario_builder
    if optimized_rag_system is None:
        try:
            optimized_rag_system = OptimizedGazaRAGSystem(vector_store_dir)
            scenario_builder = ScenarioBuilder(optimized_rag_system)
            optimized_rag_system.initialize()
            logger.info("✅ Optimized Gaza RAG System initialized successfully")
        except Exception as e:
            logger.error(f"❌ Failed to initialize optimized system: {e}")
            raise
    return optimized_rag_system
def process_medical_query_with_progress(query: str, language: str, progress=gr.Progress()) -> Tuple[str, str, str, Any, Any, Any]:
    # Translation table for the selected UI language
    t = translations.get(language, translations["English"])
    if not query.strip():
        return (
            "Please enter a medical question.",
            "",
            "⚠️ No query provided",
            gr.update(value=""),
            gr.update(value=""),
            gr.update(value="")
        )
    try:
        progress(0.05, desc="🔧 Initializing optimized system...")
        system = initialize_optimized_system()
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        def progress_callback(value, desc):
            progress(value, desc=desc)

        try:
            result = loop.run_until_complete(
                system.generate_response_async(query, progress_callback, language=language)
            )
        finally:
            loop.close()
        response = result["response"]
        metadata_parts = [
            f"🎯 Confidence: {result.get('confidence', 0):.1%}",
            f"⏱️ Response: {result.get('response_time', 0)}s",
            f"📚 Sources: {result.get('search_results_count', 0)} found"
        ]
        if result.get("cached"):
            metadata_parts.append("💾 Cached")
        if result.get("sources"):
            metadata_parts.append(f"📖 Refs: {', '.join(result['sources'][:2])}")
        metadata = " | ".join(metadata_parts)
        status_parts = []
        if result.get("safety_warnings"):
            status_parts.append(f"⚠️ {len(result['safety_warnings'])} warnings")
        if result.get("safety_issues"):
            status_parts.append(f"🚨 {len(result['safety_issues'])} issues")
        if not status_parts:
            status_parts.append("✅ Safe response")
        status = " | ".join(status_parts)
        return (
            response,
            metadata,
            status,
            gr.update(value=f"### AI Response\n{response}"),
            gr.update(value=f"### Metadata\n{metadata}"),
            gr.update(value=f"### Safety Check\n{status}")
        )
    except Exception as e:
        logger.error(f"❌ Error processing query: {e}")
        error_response = f"⚠️ Error processing your query: {str(e)}\n\n🚨 For medical emergencies, seek immediate professional help."
        error_metadata = f"❌ Error at {datetime.now().strftime('%H:%M:%S')}"
        error_status = "🚨 System error occurred"
        return (error_response, error_metadata, error_status, gr.update(value=""), gr.update(value=""), gr.update(value=""))
def get_system_stats() -> str:
    try:
        system = initialize_optimized_system()
        stats = system.knowledge_base.get_stats()
        if stats["status"] == "initialized":
            return f"""
📊 **System Statistics:**
- Status: ✅ Initialized
- Total Chunks: {stats['total_chunks']:,}
- Vector Dimension: {stats['embedding_dimension']}
- Index Type: {stats['index_type']}
- Sources: {len(stats['sources'])} documents
- Available Sources: {', '.join(stats['sources'][:5])}{'...' if len(stats['sources']) > 5 else ''}
"""
        else:
            return "📊 System Status: ❌ Not Initialized"
    except Exception as e:
        return f"📊 System Status: ❌ Error - {str(e)}"
def create_optimized_gradio_interface():
    with gr.Blocks(
        css=get_custom_css() + """
        #role-selection-box {
            display: flex;
            flex-direction: column;
            gap: 1rem;
        }
        .gr-button {
            width: 350px;
            font-size: 1.1rem;
        }
        .highlight-flash {
            animation: flash-highlight 1.6s ease-in-out;
        }
        """,
        title="🏥 Optimized Gaza First Aid Assistant",
        theme=gr.themes.Soft(primary_hue="blue", secondary_hue="green", neutral_hue="slate")
    ) as interface:
        user_role = gr.State()
        default_language = "English"

        # Role selection UI
        with gr.Column(elem_id="role-selection-box", visible=True) as role_selection_group:
            role_title = gr.Markdown("### 🧑‍⚕️ Select Your Role to Begin")
            volunteer_btn = gr.Button("🎯 I'm a Volunteer")
            organizer_btn = gr.Button("📋 I'm an Event Organizer")
            divider = gr.Markdown("---")
        # Main UI (hidden until a role is selected)
        with gr.Column(visible=False) as main_ui:
            with gr.Row(elem_classes=["main-container"]):
                create_header_section()
            with gr.Row(elem_classes=["main-container"]):
                with gr.Group(elem_classes=["stats-container"]):
                    stats_display = gr.Markdown(
                        value=get_system_stats(),
                        label="📊 System Status"
                    )
            with gr.Row(elem_classes=["main-container"]):
                with gr.Column(scale=2):
                    query_container, query_input, submit_btn, clear_btn = create_query_input_section(default_language)
                    create_example_scenarios(query_input, default_language)
                with gr.Column(scale=1):
                    create_quick_access_section(default_language)
                    user_type_dropdown = gradio_user_selector()
                    language_dropdown = gradio_sidebar_controls()
            with gr.Row(elem_classes=["main-container"]):
                with gr.Column():
                    _, response_output, metadata_output, status_output = create_response_output_section()
                    show_response_output = gr.Markdown(label="AI Response", elem_classes=["highlight-flash"])
                    show_metadata_output = gr.Markdown(label="Metadata")
                    show_safety_output = gr.Markdown(label="Safety Check")
            with gr.Tab("🧪 Scenario Generator"):
                with gr.Column(elem_classes=["scenario-generator-container"]):
                    gr.Markdown("### Generate an Interactive Medical Scenario")

                    # Input controls
                    scenario_query_input = gr.Textbox(
                        label="Enter a medical topic",
                        placeholder="e.g., 'burns', 'fractures', 'CPR'",
                        value=""
                    )
                    num_questions_slider = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=5,
                        step=1,
                        label="Number of Questions"
                    )
                    scenario_submit = gr.Button("🚀 Generate Scenario", variant="primary")

                    # Status display
                    scenario_status = gr.Markdown("Ready to generate quiz...")

                    # Output containers
                    scenario_quiz_block = gr.Column(visible=False)
                    scenario_result_output = gr.Markdown(visible=False, elem_classes=["quiz-result-output"])
                    questions_state = gr.State()

                    # Pre-allocate quiz components (up to 10 questions)
                    quiz_questions = []
                    with scenario_quiz_block:
                        gr.Markdown("## 🧪 Interactive Medical Quiz")
                        for i in range(10):
                            q_md = gr.Markdown(visible=False)
                            q_radio = gr.Radio(choices=[], type="index", visible=False)
                            quiz_questions.append((q_md, q_radio))
                        submit_quiz_btn = gr.Button("✅ Submit Answers", variant="primary")
                    def on_generate_click(query, num_q):
                        try:
                            if not query or not query.strip():
                                return generate_error_response("Please enter a medical topic")
                            num_q = int(num_q)
                            if num_q < 1 or num_q > 10:
                                return generate_error_response("Number of questions must be between 1 and 10")
                            questions = get_mcq_data(query.strip(), num_q)
                            if isinstance(questions, dict) and "error" in questions:
                                error_msg = f"❌ Error: {questions['error']}"
                                if "details" in questions:
                                    error_msg += f" - {questions['details']}"
                                return generate_error_response(error_msg)
                            if not questions or not isinstance(questions, list):
                                return generate_error_response("No valid questions were generated")
                            return generate_success_response(questions, query)
                        except Exception as e:
                            logger.error(f"Error in on_generate_click: {e}")
                            return generate_error_response(f"Unexpected error: {str(e)}")
                    def generate_error_response(error_message):
                        updates = []
                        # Hide all 10 question markdowns
                        for _ in range(10):
                            updates.append(gr.update(visible=False))
                        # Hide all 10 answer radios
                        for _ in range(10):
                            updates.append(gr.update(visible=False))
                        updates.extend([
                            gr.update(visible=False),                      # scenario_quiz_block
                            gr.update(value=error_message, visible=True),  # scenario_result_output
                            None,                                          # questions_state
                            gr.update(value=error_message)                 # scenario_status
                        ])
                        return tuple(updates)
                    def generate_success_response(questions, query):
                        updates = []
                        # Question markdowns
                        for i in range(10):
                            if i < len(questions):
                                q = questions[i]
                                updates.append(gr.update(value=f"**Question {i + 1}:** {q['question']}", visible=True))
                            else:
                                updates.append(gr.update(visible=False))
                        # Answer radios
                        for i in range(10):
                            if i < len(questions):
                                q = questions[i]
                                choices = [f"{k}: {v}" for k, v in q["options"].items()]
                                updates.append(gr.update(choices=choices, visible=True, value=None, label=f"Select your answer for Question {i + 1}:"))
                            else:
                                updates.append(gr.update(visible=False, choices=[]))
                        updates.extend([
                            gr.update(visible=True),             # scenario_quiz_block
                            gr.update(value="", visible=False),  # scenario_result_output
                            questions,                           # questions_state
                            gr.update(value=f"✅ Generated {len(questions)} questions about '{query}'. Answer the questions below and click Submit!")
                        ])
                        return tuple(updates)
                    def evaluate_quiz(*args):
                        try:
                            if not args:
                                return gr.update(value="❌ No data received for evaluation", visible=True)
                            # Last positional input is questions_state; the rest are radio selections
                            questions = args[-1]
                            selections = args[:-1]
                            if not questions:
                                return gr.update(value="❌ No questions available for evaluation", visible=True)
                            results = []
                            score = 0
                            total = len(questions)
                            for i, selected_index in enumerate(selections):
                                if i >= len(questions):
                                    break
                                question = questions[i]
                                correct_key = question["correct_answer"]
                                correct_answer = f"{correct_key}: {question['options'][correct_key]}"
                                if selected_index is not None and selected_index < len(question["options"]):
                                    user_key = list(question["options"].keys())[selected_index]
                                    user_answer = f"{user_key}: {question['options'][user_key]}"
                                    is_correct = user_key == correct_key
                                else:
                                    user_answer = "No answer selected"
                                    is_correct = False
                                if is_correct:
                                    score += 1
                                status_icon = "✅" if is_correct else "❌"
                                results.append(f"""
### Question {i + 1}
**{question['question']}**
🔹 **Your Answer:** {user_answer}
🔹 **Correct Answer:** {correct_answer}
🔹 **Result:** {status_icon} {'Correct!' if is_correct else 'Incorrect'}
💡 **Explanation:** {question.get('feedback', 'No explanation provided.')}
---
""")
                            percentage = (score / total) * 100 if total > 0 else 0
                            if percentage >= 90:
                                grade_emoji, grade_text = "🏆", "Outstanding!"
                            elif percentage >= 80:
                                grade_emoji, grade_text = "🌟", "Excellent work!"
                            elif percentage >= 70:
                                grade_emoji, grade_text = "👍", "Good job!"
                            elif percentage >= 60:
                                grade_emoji, grade_text = "📖", "Keep studying!"
                            else:
                                grade_emoji, grade_text = "💪", "More practice needed!"
                            summary = f"""
# 📊 Quiz Results
## {grade_emoji} Final Score: {score}/{total} ({percentage:.0f}%)
**{grade_text}**
---
## 📋 Detailed Feedback:
"""
                            return gr.update(value=summary + "\n".join(results), visible=True)
                        except Exception as e:
                            logger.error(f"Error in evaluate_quiz: {e}")
                            return gr.update(value=f"❌ Error evaluating quiz: {str(e)}", visible=True)
                    scenario_submit.click(
                        fn=on_generate_click,
                        inputs=[scenario_query_input, num_questions_slider],
                        outputs=(
                            [md for md, _ in quiz_questions]
                            + [radio for _, radio in quiz_questions]
                            + [scenario_quiz_block, scenario_result_output, questions_state, scenario_status]
                        ),
                        show_progress=True
                    )
                    submit_quiz_btn.click(
                        fn=evaluate_quiz,
                        inputs=[q[1] for q in quiz_questions] + [questions_state],
                        outputs=[scenario_result_output]
                    )
                    gr.HTML("""
                    <script>
                    window.scrollToQuizResult = function() {
                        const el = document.querySelector('.quiz-result-output');
                        if (el) el.scrollIntoView({ behavior: 'smooth' });
                    }
                    </script>
                    """)
        # Event handlers for main UI
        submit_btn.click(
            process_medical_query_with_progress,
            inputs=[query_input, language_dropdown],
            outputs=[response_output, metadata_output, status_output, show_response_output, show_metadata_output, show_safety_output],
            show_progress=True
        )
        query_input.submit(
            process_medical_query_with_progress,
            inputs=[query_input, language_dropdown],
            outputs=[response_output, metadata_output, status_output, show_response_output, show_metadata_output, show_safety_output],
            show_progress=True
        )
        clear_btn.click(
            # Seven outputs, so return seven cleared values (query box, three text
            # outputs, and the three Markdown panels)
            lambda: ("", "", "", "", gr.update(value=""), gr.update(value=""), gr.update(value="")),
            outputs=[query_input, response_output, metadata_output, status_output, show_response_output, show_metadata_output, show_safety_output]
        )
        # Role selection handlers
        def show_main_ui(role):
            return (
                role,                      # update user_role state
                gr.update(visible=True),   # show main UI
                gr.update(visible=False),  # hide role title
                gr.update(visible=False),  # hide volunteer_btn
                gr.update(visible=False),  # hide organizer_btn
                gr.update(visible=False)   # hide divider
            )

        volunteer_btn.click(
            lambda: show_main_ui("volunteer"),
            outputs=[user_role, main_ui, role_title, volunteer_btn, organizer_btn, divider]
        )
        organizer_btn.click(
            lambda: show_main_ui("organizer"),
            outputs=[user_role, main_ui, role_title, volunteer_btn, organizer_btn, divider]
        )
    return interface
if __name__ == "__main__":
    interface = create_optimized_gradio_interface()
    interface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        debug=True
    )