Spaces:
Sleeping
Sleeping
import os | |
import time | |
import asyncio | |
import hashlib | |
from typing import Optional | |
from datetime import datetime | |
import re | |
from concurrent.futures import ThreadPoolExecutor | |
from typing import List, Dict, Any | |
from groq import Groq | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline, AutoModelForCausalLM | |
import torch | |
import numpy as np | |
from core.models.knowledge_base import OptimizedGazaKnowledgeBase | |
from core.fact_checker import MedicalFactChecker, clean_ocr_artifacts | |
from core.utils.config import ( | |
MEDICAL_SYSTEM_PROMPT, | |
GROQ_API_KEY, | |
FLAN_MODEL_NAME, | |
FALLBACK_MODEL_NAME, | |
MAX_CACHE_SIZE, | |
MAX_CONTEXT_CHARS | |
) | |
from transformers import pipeline | |
print("🧪 Logger test: ", 'logger' in globals()) | |
from core.utils.logger import logger | |
from transformers import pipeline | |
class ArabicTranslator:
    """Bidirectional English/Arabic translator built on two M2M100 pipelines.

    Both pipelines are created in ``__init__``. If loading fails, both
    attributes are set to ``None`` and every translate call returns its
    input unchanged, so callers never need to special-case a missing model.
    """

    # Character caps applied before inference to stay below the token limit.
    MAX_ARABIC_INPUT_CHARS = 1000   # input to translate_to_english
    MAX_ENGLISH_INPUT_CHARS = 900   # input to translate_to_arabic
    # Markdown punctuation that would otherwise come back as stray "* * *".
    _MARKDOWN_CHARS = re.compile(r'[*_`~#>]')

    def __init__(self):
        """Load both translation pipelines; fall back to ``None`` on any failure."""
        try:
            self.en_to_ar = pipeline("translation", model="facebook/m2m100_418M", src_lang="en", tgt_lang="ar")
            self.ar_to_en = pipeline("translation", model="facebook/m2m100_418M", src_lang="ar", tgt_lang="en")
            print("✅ Translation models loaded")
        except Exception as e:
            print(f"❌ Failed to load translation models: {e}")
            self.en_to_ar = None
            self.ar_to_en = None

    def translate_to_english(self, text):
        """Translate Arabic ``text`` to English.

        Returns ``text`` unchanged when the model is unavailable or when
        there is nothing to translate (empty / whitespace-only input).
        Input longer than ``MAX_ARABIC_INPUT_CHARS`` is truncated.
        """
        if not self.ar_to_en:
            print("⚠️ Arabic-to-English translation model not available.")
            return text
        # Do not run the model on empty input; there is nothing to translate.
        if not text or not text.strip():
            return text
        clipped = text[:self.MAX_ARABIC_INPUT_CHARS]
        return self.ar_to_en(clipped, max_length=1024)[0]["translation_text"]

    def translate_to_arabic(self, text):
        """Translate English ``text`` to Arabic after stripping markdown markers.

        Returns ``text`` unchanged when the model is unavailable or the input
        is empty / whitespace-only. If only markdown markers were present,
        the stripped (blank) text is returned without calling the model.
        Input longer than ``MAX_ENGLISH_INPUT_CHARS`` is truncated.
        """
        if not self.en_to_ar:
            print("⚠️ English-to-Arabic translation model not available.")
            return text
        if not text or not text.strip():
            return text
        # Strip markdown artifacts (prevent "* * *")
        clean_text = self._MARKDOWN_CHARS.sub('', text)
        if not clean_text.strip():
            return clean_text
        if len(clean_text) > self.MAX_ENGLISH_INPUT_CHARS:
            print(f"⚠️ Input too long ({len(clean_text)}), truncating for translation.")
            clean_text = clean_text[:self.MAX_ENGLISH_INPUT_CHARS]
        return self.en_to_ar(clean_text, max_length=1024)[0]["translation_text"]

    def translate(self, text: str, direction: str = "to_en") -> str:
        """Dispatch to the directional translator.

        Args:
            text: Text to translate.
            direction: ``"to_en"`` (default) or ``"to_ar"``.

        Raises:
            ValueError: If ``direction`` is neither ``"to_en"`` nor ``"to_ar"``.
        """
        if direction == "to_en":
            return self.translate_to_english(text)
        if direction == "to_ar":
            return self.translate_to_arabic(text)
        raise ValueError("Invalid translation direction: choose 'to_en' or 'to_ar'")
class OptimizedGazaRAGSystem: | |
"""Optimized RAG system using pre-made assets""" | |
def __init__(self, vector_store_dir: str = "./vector_store"):
    """Create the RAG system's collaborators; no LLM is loaded here.

    Args:
        vector_store_dir: Directory handed to ``OptimizedGazaKnowledgeBase``.

    The Groq client, FLAN model, tokenizer and generation pipeline all start
    as ``None`` and are populated lazily by other methods.
    """
    self.knowledge_base = OptimizedGazaKnowledgeBase(vector_store_dir)
    self.fact_checker = MedicalFactChecker()
    self.groq_client = None
    self.llm = None
    self.tokenizer = None
    self.use_native_generation = True  # or False by config/env
    self.system_prompt = self._create_system_prompt()
    self.arabic_translator = ArabicTranslator()
    self.generation_pipeline = None
    self.response_cache = {}
    self.executor = ThreadPoolExecutor(max_workers=2)
    # Phrases that make definitive claims or mention dangerous procedures.
    self.definitive_patterns = [
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r'will\s+(?:cure|heal|fix)\b',  # Only block definitive claims
            r'guaranteed\s+to',
            r'completely\s+(?:safe|effective)\b',
            r'\b(?:inject|syringe)\b',  # Added dangerous procedures
        )
    ]
    # Start-up smoke test of the translator. The sample sentence is English,
    # so it must go through the English->Arabic direction; the previous
    # default ("to_en") fed English text to the Arabic->English model.
    # A failing smoke test must not prevent the system from being built.
    try:
        translated_test = self.arabic_translator.translate("How do I treat a wound?", direction="to_ar")
        print("🔥 Arabic test translation:", translated_test)
    except Exception as e:
        print(f"⚠️ Arabic test translation failed: {e}")
def initialize(self):
    """Bring the system online by initializing its knowledge base.

    Logs one message before and one after the knowledge base's own
    ``initialize()`` call; nothing else is set up here.
    """
    logger.info("🚀 Initializing Optimized Gaza RAG System...")
    kb = self.knowledge_base
    kb.initialize()
    logger.info("✅ Optimized Gaza RAG System ready!")
def _initialize_groq(self):
    """Build and verify a Groq client from the ``GROQ_API_KEY`` env variable.

    Returns:
        A ``Groq`` client whose key was accepted by a ``models.list()``
        probe, or ``None`` when the variable is unset, construction fails,
        or the probe raises. Every failure is logged, never raised.
    """
    # Stage 1: read the key and construct the client.
    try:
        api_key = os.getenv("GROQ_API_KEY")
        if api_key:
            client = Groq(api_key=api_key)
        else:
            logger.warning("⚠️ GROQ_API_KEY environment variable not set")
            return None
    except Exception as e:
        logger.error(f"❌ Groq initialization failed: {e}")
        return None

    # Stage 2: one cheap API call to confirm the key actually works.
    try:
        client.models.list()  # Simple API call to verify key
        logger.info("✅ Groq client initialized successfully")
    except Exception as auth_error:
        logger.error(f"❌ Groq API key invalid: {auth_error}")
        return None
    return client
def generate_raw_text(self, prompt: str) -> str:
    """Direct text generation without RAG, safety checks, or translation.

    Lazily loads the generation pipeline through ``_initialize_llm`` on
    first use.

    Args:
        prompt: Text passed verbatim to the generation pipeline.

    Returns:
        The stripped ``generated_text`` of the first result, or ``""`` when
        the pipeline returns nothing.

    Raises:
        RuntimeError: If the pipeline is still unavailable after the lazy
            load (``_initialize_llm`` leaves it as ``None`` on failure).
    """
    if not self.generation_pipeline:
        self._initialize_llm()
    if not self.generation_pipeline:
        # Previously this fell through to `None(prompt)` and surfaced as an
        # opaque "'NoneType' object is not callable" TypeError.
        raise RuntimeError("Generation pipeline is not available (model failed to load)")
    output = self.generation_pipeline(prompt)
    return output[0]["generated_text"].strip() if output else ""
def _format_kb_response(self, text: str) -> str:
    """Present a knowledge-base entry, enriching it through Groq when it is short.

    The entry is cleaned with ``clean_ocr_artifacts``. Entries of more than
    200 words that mention one of a few injury keywords are returned as-is
    under a "Comprehensive" header. Anything else is sent through
    ``_create_prompt_from_rag`` / ``_generate_with_groq``; if that raises,
    the cleaned text is returned under a "Basic" header with a notice.
    """
    cleaned = clean_ocr_artifacts(text).strip()
    lowered = cleaned.lower()

    word_count = len(cleaned.split())
    mentions_injury = any(
        keyword in lowered
        for keyword in ("fracture", "bleeding", "wound", "infection")
    )
    if word_count > 200 and mentions_injury:
        return f"📚 **Comprehensive Medical Guidance:**\n\n{cleaned}"

    # Otherwise, enrich it dynamically using FLAN + Groq
    try:
        refined_prompt = self._create_prompt_from_rag(query=cleaned, rag_results=[])
        enriched = self._generate_with_groq(query=cleaned, refined_prompt=refined_prompt)
    except Exception as e:
        logger.warning(f"[FormatKB] Enrichment failed: {e}")
        return f"📚 **Basic Medical Information:**\n\n{cleaned}\n\n⚠️ Could not expand this content automatically."
    return f"📚 **Comprehensive Medical Guidance:**\n\n{enriched}"
def _create_system_prompt(self) -> str: | |
"""Enhanced system prompt for Gaza context""" | |
MEDICAL_SYSTEM_PROMPT = """ | |
[STRICT GAZA MEDICAL PROTOCOL] | |
You are a WHO-certified medical assistant for Gaza. You MUST: | |
1. Follow WHO war-zone protocols | |
2. Reject unsafe treatments (ESPECIALLY syringe use for burns) | |
3. Prioritize resource-scarce solutions | |
4. Add Islamic medical considerations | |
5. Format responses clearly with: | |
- 🩹 Immediate Actions | |
- ⚠️ Contraindications | |
- 💡 Resource Alternatives | |
6. Include source references [Source X] | |
7. Always add: "📞 Verify with Gaza Red Crescent (101)" for serious cases | |
OUTPUT EXAMPLE: | |
### Burn Treatment ### | |
🩹 Cool with clean water for 10-20 mins [Source 1] | |
⚠️ Never apply ice directly [Source 2] | |
💡 Use clean damp cloth if water scarce [Source 3] | |
📍 Gaza Context: Adapt based on available supplies | |
📞 Verify with Gaza Red Crescent (101) if severe | |
""" | |
return MEDICAL_SYSTEM_PROMPT | |
def _initialize_llm(self):
    """Load the medical FLAN-T5 tokenizer, model and text2text pipeline.

    On success ``self.tokenizer``, ``self.llm`` and
    ``self.generation_pipeline`` are populated. On any exception the error
    is logged and all three are reset to ``None``; nothing is raised.
    """
    model_name = "rivapereira123/medical-flan-t5"
    pipeline_options = {
        "max_length": 512,  # Prevent OOM errors
        "truncation": True,
    }
    try:
        logger.info(f"🔄 Loading medical FLAN-T5 model: {model_name}")
        # Optional local caching for the tokenizer files
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="./model_cache")
        # low_cpu_mem_usage is critical for CPU-only hosts
        self.llm = AutoModelForSeq2SeqLM.from_pretrained(model_name, low_cpu_mem_usage=True)
        self.generation_pipeline = pipeline(
            "text2text-generation",
            model=self.llm,
            tokenizer=self.tokenizer,
            **pipeline_options,
        )
        logger.info("✅ Medical FLAN-T5 loaded successfully (CPU mode)")
    except Exception as e:
        logger.error(f"❌ Critical error loading model: {str(e)}")
        logger.warning("⚠️ Medical QA features will be disabled")
        self.llm = self.tokenizer = self.generation_pipeline = None
def _initialize_fallback_llm(self):
    """Load the small DialoGPT fallback model and its text-generation pipeline.

    On success ``self.tokenizer``, ``self.llm`` and
    ``self.generation_pipeline`` are populated. On any exception the error
    is logged and all three are reset to ``None``; nothing is raised.
    """
    try:
        logger.info("🔄 Loading fallback model...")
        fallback_model = "microsoft/DialoGPT-small"
        self.tokenizer = AutoTokenizer.from_pretrained(fallback_model)
        self.llm = AutoModelForCausalLM.from_pretrained(
            fallback_model,
            torch_dtype=torch.float32,
            low_cpu_mem_usage=True
        )
        # The tokenizer may ship without a pad token; reuse EOS so batching
        # and padding inside the pipeline do not fail.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.generation_pipeline = pipeline(
            "text-generation",
            model=self.llm,
            tokenizer=self.tokenizer,
            return_full_text=False
        )
        logger.info("✅ Fallback model loaded successfully")
    except Exception as e:
        logger.error(f"❌ Fallback model failed: {e}")
        # Reset the tokenizer too, matching _initialize_llm: a tokenizer
        # left behind without its model is a half-initialized state.
        self.llm = None
        self.tokenizer = None
        self.generation_pipeline = None
async def generate_response_async(self, query: str, progress_callback=None, language="English") -> Dict[str, Any]:
    """Answer a medical query: translate, check cache, retrieve, generate, validate.

    Args:
        query: The user's question, in English or Arabic.
        progress_callback: Optional callable ``(fraction, message)`` invoked
            at each stage.
        language: ``"Arabic"`` (case-insensitive) forces Arabic output;
            Arabic is also detected from the query text itself.

    Returns:
        The dict built by ``_prepare_final_response`` with ``translated``
        and ``original_language`` added. On a cache hit a copy of the
        cached dict is returned with ``cached=True``. On any exception in
        the generation stages, the dict from ``_create_error_response``.
    """
    start_time = time.time()

    def report(fraction, message):
        if progress_callback:
            progress_callback(fraction, message)

    # Step 0: Translate Arabic -> English if needed
    is_arabic_request = (language or "").lower() == "arabic" or self._is_arabic(query)
    wants_arabic = bool(is_arabic_request and self.arabic_translator)
    if wants_arabic:
        logger.info("🈸 Arabic input detected - translating to English for processing")
        query = self.arabic_translator.translate_to_english(query)

    report(0.1, "🔍 Checking cache...")

    # Cache is keyed on the English query, so both languages share entries.
    # md5 is used only as a compact key, not for security.
    query_hash = hashlib.md5(query.encode()).hexdigest()
    cached_entry = self.response_cache.get(query_hash)
    if cached_entry is not None:
        # Work on a copy: the stored entry must stay English and untouched,
        # otherwise the next hit would translate already-translated text or
        # hand Arabic to an English request.
        cached_response = dict(cached_entry)
        if wants_arabic:
            cached_response["response"] = self.arabic_translator.translate_to_arabic(cached_response["response"])
        # Language metadata describes this request, not the one that
        # populated the cache.
        cached_response["translated"] = wants_arabic
        cached_response["original_language"] = "Arabic" if wants_arabic else "English"
        cached_response["cached"] = True
        cached_response["response_time"] = 0.1
        report(1.0, "💾 Retrieved from cache!")
        return cached_response

    try:
        loop = asyncio.get_running_loop()

        report(0.2, "🤖 Initializing LLM...")
        # Initialize LLM only when needed
        if self.llm is None:
            await loop.run_in_executor(self.executor, self._initialize_llm)

        report(0.4, "🔍 Searching knowledge base...")
        search_results = await loop.run_in_executor(
            self.executor, self.knowledge_base.search, query, 5
        )

        report(0.6, "📝 Preparing context...")
        context = self._prepare_context(search_results)

        report(0.8, "🧠 Generating response...")
        english_response = await loop.run_in_executor(
            self.executor, self._generate_response, query, context
        )

        report(0.9, "🛡️ Validating safety...")
        safety_check = self.fact_checker.check_medical_accuracy(english_response, context)

        final_response = self._prepare_final_response(
            english_response,
            search_results,
            safety_check,
            time.time() - start_time
        )
        # Keep the fully prepared English text (safety caution and any
        # enhancement included) so the cache serves the same content the
        # first caller saw, not the raw generation.
        english_final_text = final_response["response"]

        if wants_arabic:
            logger.info("🌐 Translating final response to Arabic")
            final_response["response"] = self.arabic_translator.translate_to_arabic(english_final_text)
            final_response["translated"] = True
            final_response["original_language"] = "Arabic"
        else:
            final_response["translated"] = False
            final_response["original_language"] = "English"

        # Cache the English version (for consistency)
        if len(self.response_cache) < 100:
            cache_entry = final_response.copy()
            cache_entry["response"] = english_final_text
            self.response_cache[query_hash] = cache_entry

        report(1.0, "✅ Complete!")
        return final_response
    except Exception as e:
        logger.error(f"❌ Error generating response: {e}")
        report(1.0, f"❌ Error: {str(e)}")
        return self._create_error_response(str(e))
# def _generate_with_flan(self, query: str, context: Optional[str] = None) -> str: | |
# """Generate a response using the FLAN model directly (no Groq).""" | |
# if not self.generation_pipeline: | |
# raise RuntimeError("FLAN generation pipeline not initialized") | |
# # Build simple instructional prompt | |
# prompt = f""" | |
# You are a medical assistant working in Gaza. | |
# Query: | |
# {query} | |
# Context: | |
# {context if context else "No additional context"} | |
# Respond with: | |
# - 🩹 Immediate Actions | |
# - ⚠️ Contraindications | |
# - 💡 Alternatives | |
# - 🚨 When to seek emergency help | |
# """.strip() | |
# result = self.generation_pipeline(prompt) | |
# return result[0]["generated_text"].strip() if result else "⚠️ No response generated." | |
def _create_prompt_from_rag(self, query: str, rag_results: List[Dict[str, Any]]) -> str: | |
"""Use FLAN-T5 to condense RAG results into a clean prompt""" | |
if not self.llm or not rag_results: | |
return query # Fallback to original query | |
# Create context string from RAG results | |
context = "\n".join([f"[Source {i+1}]: {res['text']}" | |
for i, res in enumerate(rag_results[:6])]) | |
# Create prompt for FLAN | |
prompt = f"""You are a medical researcher, Expand this medical context into a concise prompt for detailed response: | |
Original Query: {query} | |
Context: | |
{context} | |
Create a comprehensive response that includes: | |
1. Step-by-step treatment instructions | |
2. Gaza-specific adaptations | |
3. Alternative methods for resource-limited situations | |
4. Clear danger signs requiring professional care | |
5. Proper wound care timeline | |
Create a detailed question incorporating key points from the context, Structure your response with: | |
- Immediate Actions - Contraindications - Follow-up Care - Emergency Indicators:""" | |
# Generate with FLAN-T5 | |
inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512) | |
outputs = self.llm.generate( | |
input_ids=inputs.input_ids, | |
max_new_tokens=200, | |
num_beams=3, | |
early_stopping=True | |
) | |
refined_prompt = self.tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return f"{refined_prompt}\n\nContext References:\n{context}" | |
def _is_arabic(self, text): | |
return any('\u0600' <= c <= '\u06FF' for c in text) | |
def _prepare_context(self, search_results: List[Dict[str, Any]]) -> str: | |
MAX_CHARS = 1500 | |
if not search_results: | |
return "First aid protocol: " | |
context_parts = [] | |
for result in search_results[:5]: # Top 3 results only | |
text = str(result.get('text', '')).strip() | |
context_parts.append({ | |
'text': clean_ocr_artifacts(text), | |
'source': result.get('source', 'unknown'), | |
'score': result.get('score', 0.0) | |
}) | |
return "\n\n".join( | |
f"[Reference {i+1}]: {ctx['text']}" | |
for i, ctx in enumerate(context_parts) | |
)[:MAX_CHARS] | |
def _format_context_with_groq(self, context_parts: List[Dict[str, Any]]) -> str: | |
"""Format context using Groq with comprehensive error handling""" | |
def _get_groq_client(self): | |
from groq import Groq | |
return Groq(api_key=os.getenv("GROQ_API_KEY")) | |
logger.info(f"GROQ_API_KEY prefix-format context with groq: {os.getenv('GROQ_API_KEY')[:5]}****") | |
if not hasattr(self, 'groq_client') or not self.groq_client: | |
raise ValueError("Groq client not initialized") | |
if not context_parts: | |
return "No context available" | |
try: | |
# Prepare context string | |
context_str = "\n\n".join( | |
f"Source {i+1} ({ctx['source']}, relevance {ctx['score']:.2f}):\n{ctx['text']}" | |
for i, ctx in enumerate(context_parts) | |
) | |
response = self.groq_client.chat.completions.create( | |
model="deepseek-r1-distill-llama-70b", | |
messages=[ | |
{ | |
"role": "system", | |
"content": self.system_prompt | |
}, | |
{ | |
"role": "user", | |
"content": f"Organize this medical context:\n\n{context_str}" | |
} | |
], | |
temperature=0.3, | |
max_tokens=2000, | |
top_p=0.9 | |
) | |
# Validate response structure | |
if not response or not response.choices: | |
raise ValueError("Empty Groq response") | |
if not hasattr(response.choices[0], 'message') or not hasattr(response.choices[0].message, 'content'): | |
raise ValueError("Invalid response format") | |
formatted = response.choices[0].message.content | |
if formatted is None: | |
raise ValueError("Empty response content") | |
# Post-processing safety checks | |
if not isinstance(formatted, str): | |
raise ValueError("Formatted context is not a string") | |
if "syringe" in formatted.lower(): | |
formatted = "⚠️ SAFETY ALERT: Rejected dangerous suggestion\n\n" + formatted | |
return formatted | |
except Exception as e: | |
logger.error(f"Groq formatting failed: {str(e)}") | |
raise ValueError(f"Groq processing failed: {str(e)}") | |
def _generate_with_groq(self, query: str, context: str = None, refined_prompt: str = None) -> str: | |
""" | |
Generate medical response using Groq with two modes: | |
1. Direct mode (query + context) | |
2. Refined prompt mode (FLAN-processed prompt) | |
Args: | |
query: Original user query | |
context: Optional RAG context | |
refined_prompt: Optional FLAN-processed prompt | |
""" | |
def _get_groq_client(): | |
from groq import Groq | |
return Groq(api_key=os.getenv("GROQ_API_KEY")) | |
# Verify Groq client | |
if not self.groq_client: | |
self.groq_client = _get_groq_client() | |
if not self.groq_client: | |
raise ValueError("Groq client not available") | |
try: | |
# Determine which mode to use | |
if refined_prompt: | |
# Refined prompt mode (RAG→FLAN→Groq pipeline) | |
messages = [ | |
{ | |
"role": "system", | |
"content": ( | |
f"""{self.system_prompt}\n | |
Your task is to expand the medical information into comprehensive, | |
descriptive guidance while preserving all safety considerations. | |
You are a WHO medical advisor for Gaza. Provide 1) Extremely detailed step-by-step guidance | |
2) Multiple treatment options for different resource scenarios | |
3)Clear timeframes for each action | |
4)Islamic medical considerations | |
5)Format with emoji section headers | |
""" | |
) | |
}, | |
{ | |
"role": "user", | |
"content": f"""Expand this into comprehensive medical guidance: | |
{refined_prompt} | |
Include: | |
1. Detailed procedural steps | |
2. Pain management techniques | |
3. Infection prevention measures | |
4. When to seek emergency care""" | |
} | |
] | |
max_tokens = 2000 # Allow longer responses for descriptive answers | |
temperature = 0.5 # Slightly higher for creativity | |
else: | |
# Direct mode (query + context) | |
messages = [ | |
{ | |
"role": "system", | |
"content": ( | |
f"{self.system_prompt}\n" | |
"Provide detailed 5-7 step guidance when applicable." | |
) | |
}, | |
{ | |
"role": "user", | |
"content": ( | |
f"Query: {query}\n" | |
f"Context: {context if context else 'No additional context'}\n\n" | |
"Provide comprehensive guidance with:\n" | |
"1. Detailed steps\n2. Alternative methods\n3. Gaza-specific adaptations" | |
) | |
} | |
] | |
max_tokens = 1500 # Slightly shorter for direct responses | |
temperature = 0.4 # Balanced between creativity and accuracy | |
# Make the API call | |
response = self.groq_client.chat.completions.create( | |
model="llama3-70b-8192", # Using latest model | |
messages=messages, | |
temperature=temperature, | |
max_tokens=max_tokens, | |
top_p=0.9 | |
) | |
result = response.choices[0].message.content | |
if not result: | |
raise ValueError("Empty response from Groq") | |
# Post-processing steps | |
if refined_prompt and "Context References:" in refined_prompt: | |
# Preserve RAG references in refined prompt mode | |
refs = refined_prompt.split("Context References:")[1] | |
result = f"{result}\n\n📚 Source References:{refs}" | |
elif context: | |
# Add basic context reference in direct mode | |
result += "\n\nℹ️ Context: Based on verified medical sources" | |
# Safety checks (applied to both modes) | |
if "syringe" in result.lower(): | |
result = "⚠️ SAFETY ALERT: Rejected dangerous suggestion\n\n" + result | |
if "Gaza Red Crescent" not in result: | |
result += "\n\n📞 Verify with Gaza Red Crescent (101) if condition worsens" | |
return result | |
except Exception as e: | |
logger.error(f"Groq generation failed: {str(e)}") | |
raise ValueError(f"Groq processing error: {str(e)}") | |
def _generate_response(self, query: str, context: str) -> str:
    """Run the RAG → FLAN → Groq pipeline with progressively simpler fallbacks.

    Order of attempts:
      1. Refine a prompt with ``_create_prompt_from_rag`` and, if a Groq
         client is already set, generate with ``_generate_with_groq``.
      2. Direct FLAN generation, only if this object provides a
         ``_generate_with_flan`` method and a model is loaded.
      3. The top knowledge-base hit formatted by ``_format_kb_response``.
      4. The canned advice from ``_generate_fallback_response``.

    Args:
        query: English user query.
        context: Caller-prepared context. NOTE: it is replaced by a context
            rebuilt from this method's own k=6 search, as before.

    Returns:
        The first non-empty response produced by the steps above.
    """
    # 1. Retrieve passages and rebuild the context from them.
    rag_results = self.knowledge_base.search(query, k=6) or []
    context = self._prepare_context(rag_results)
    top_score = rag_results[0].get('score', 0) if rag_results else 0
    logger.info(f"Found {len(rag_results)} RAG results (top score: {top_score:.2f})")

    # 2. Try RAG→FLAN→Groq pipeline
    try:
        refined_prompt = self._create_prompt_from_rag(query, rag_results)
        logger.info(f"Refined prompt: {refined_prompt[:100]}...")
        # Generate with Groq if available
        if getattr(self, 'groq_client', None):
            try:
                groq_response = self._generate_with_groq(query=query, refined_prompt=refined_prompt)
                if groq_response:
                    return groq_response
            except Exception as groq_error:
                logger.warning(f"Groq generation failed: {str(groq_error)}")
    except Exception as pipe_error:
        logger.warning(f"RAG→FLAN→Groq pipeline failed: {str(pipe_error)}")

    # 3. Fallback to direct FLAN generation. The helper is looked up
    # dynamically because it may not be defined on this class; calling it
    # unconditionally raised and logged an AttributeError on every request
    # that reached this step.
    flan_generate = getattr(self, "_generate_with_flan", None)
    if self.llm and self.tokenizer and callable(flan_generate):
        try:
            # Use the original context (not refined prompt) for fallback
            flan_response = flan_generate(query, context)
            if flan_response:
                return flan_response
        except Exception as flan_error:
            logger.error(f"FLAN generation failed: {str(flan_error)}")

    # 4. Fall back to the best knowledge-base entry, if it has any text.
    if rag_results:
        top_text = rag_results[0].get('text')
        if top_text:
            return self._format_kb_response(top_text)

    # 5. Final emergency fallback
    return self._generate_fallback_response(query, context)
def _format_final_response(self, response: str) -> str: | |
"""Ensure response meets all Gaza-specific requirements""" | |
clean_response = response.split("CONTEXT:")[0].strip() | |
for icon in ["🩹", "💡", "⚠️"]: | |
if icon not in clean_response: | |
clean_response = clean_response.replace("Immediate Actions", f"Immediate Actions {icon}", 1) | |
break | |
if "📍 Gaza Context:" not in clean_response: | |
clean_response += "\n\n📍 Gaza Context: This guidance considers resource limitations. Adapt based on available supplies and seek professional medical care when accessible." | |
return clean_response | |
def _get_error_response(self, query: str, error: Exception) -> str: | |
"""User-friendly error message with Gaza contacts""" | |
return f"""⚠️ We're unable to process your query about "{query}" | |
For immediate medical assistance: | |
📞 Palestinian Red Crescent: 101 | |
📞 Civil Defense: 102 | |
(Technical issue: {str(error)}...)""" | |
def _generate_fallback_response(self, query: str, context: str) -> str: | |
"""Enhanced fallback response with Gaza-specific guidance""" | |
gaza_guidance = { | |
"burn": "For burns: Use clean, cool water if available. If water is scarce, use clean cloth. Avoid ice. Seek medical help urgently.", | |
"bleeding": "For bleeding: Apply direct pressure with clean cloth. Elevate if possible. If severe, seek immediate medical attention.", | |
"wound": "For wounds: Clean hands if possible. Apply pressure to stop bleeding. Cover with clean material. Watch for infection signs.", | |
"infection": "Signs of infection: Redness, warmth, swelling, pus, fever. Seek medical care immediately if available.", | |
"pain": "For pain management: Rest, elevation, cold/warm compress as appropriate. Avoid aspirin in children." | |
} | |
query_lower = query.lower() | |
for condition, guidance in gaza_guidance.items(): | |
if condition in query_lower: | |
return f"{guidance}\n\nContext from medical sources:\n{context}..." | |
return f"Medical guidance for: {query}\n\nGeneral advice: Prioritize safety, seek professional help when available, consider resource limitations in Gaza.\n\nRelevant information:\n{context[:600]}..." | |
def _prepare_final_response( | |
self, | |
response: str, | |
search_results: List[Dict[str, Any]], | |
safety_check: Dict[str, Any], | |
response_time: float | |
) -> Dict[str, Any]: | |
def _get_groq_client(): | |
from groq import Groq | |
return Groq(api_key=os.getenv("GROQ_API_KEY")) | |
# Ensure response is a string | |
if response is None: | |
response = "Unable to generate response. Please try again." | |
elif not isinstance(response, str): | |
response = str(response) | |
"""Enhanced final response preparation with more metadata""" | |
# Ensure response is a string | |
if not isinstance(response, str): | |
response = "Unable to generate response. Please try again." | |
# Ensure safety_check has required fields | |
if not isinstance(safety_check, dict): | |
safety_check = { | |
"confidence_score": 0.5, | |
"issues": [], | |
"warnings": ["Response validation failed"], | |
"is_safe": False | |
} | |
# Add safety warnings if needed | |
if not safety_check["is_safe"]: | |
response = f"⚠️ MEDICAL CAUTION: {response}\n\n🚨 Please verify this guidance with a medical professional when possible." | |
if safety_check["is_safe"] and hasattr(self, 'groq_client') and self.groq_client: | |
try: | |
enhanced = self._enhance_response_with_groq(response, search_results) | |
if enhanced: | |
response = enhanced | |
except Exception as e: | |
logger.warning(f"Response enhancement failed: {e}") | |
# Add Gaza-specific disclaimer | |
# Extract unique sources | |
sources = list(set(res.get("source", "unknown") for res in search_results)) if search_results else [] | |
# Calculate confidence based on multiple factors | |
base_confidence = safety_check.get("confidence_score", 0.5) | |
context_bonus = 0.1 if search_results else 0.0 | |
safety_penalty = 0.2 if not safety_check.get("is_safe", True) else 0.0 | |
final_confidence = max(0.0, min(1.0, base_confidence + context_bonus - safety_penalty)) | |
return { | |
"response": response, | |
"confidence": final_confidence, | |
"sources": sources, | |
"search_results_count": len(search_results), | |
"safety_issues": safety_check.get("issues", []), | |
"safety_warnings": safety_check.get("warnings", []), | |
"response_time": round(response_time, 2), | |
"timestamp": datetime.now().isoformat()[:19], | |
"cached": False | |
} | |
if not hasattr(self, 'groq_client') or not self.groq_client: | |
return response | |
try: | |
groq_client = self._get_groq_client() | |
models = groq_client.models.list() | |
logger.info("🧪 Available Groq models:") | |
for m in models.data: | |
logger.info(f" - {m.id}") | |
messages = [ | |
{"role": "system", "content": self.system_prompt}, | |
{"role": "user", "content": f"Enhance this medical response:\n\n{response}"} | |
] | |
enhanced = self.groq_client.chat.completions.create( | |
model="deepseek-r1-distill-llama-70b", | |
messages=messages, | |
temperature=0.3, | |
max_tokens=1000 | |
) | |
if enhanced and enhanced.choices: | |
return enhanced.choices[0].message.content | |
return response | |
except Exception as e: | |
logger.warning(f"Response enhancement failed: {e}") | |
return response | |
def _enhance_response_with_groq(self, response: str, search_results: List[Dict[str, Any]]) -> str: | |
"""Enhance response using Groq's capabilities""" | |
def _get_groq_client(): | |
from groq import Groq | |
return Groq(api_key=os.getenv("GROQ_API_KEY")) | |
if not hasattr(self, 'groq_client') or not self.groq_client: | |
return response | |
try: | |
messages = [ | |
{"role": "system", "content": self.system_prompt}, | |
{"role": "user", "content": f"Enhance this medical response:\n\n{response}"} | |
] | |
enhanced = self.groq_client.chat.completions.create( | |
model="llama3-70b-8192", # Updated model name | |
messages=messages, | |
temperature=0.3, | |
max_tokens=2000 | |
) | |
if enhanced and enhanced.choices: | |
return enhanced.choices[0].message.content | |
return response | |
except Exception as e: | |
logger.warning(f"Response enhancement failed: {e}") | |
return response | |
def _create_error_response(self, error_msg: str) -> Dict[str, Any]: | |
"""Enhanced error response with helpful information""" | |
return { | |
"response": f"⚠️ System Error: Unable to process your medical query at this time.\n\nError: {error_msg}\n\n🚨 For immediate medical emergencies, seek professional help directly.\n\n📞 Gaza Emergency Numbers:\n- Palestinian Red Crescent: 101\n- Civil Defense: 102", | |
"confidence": 0.0, | |
"sources": [], | |
"search_results_count": 0, | |
"safety_issues": ["System error occurred"], | |
"safety_warnings": ["Unable to validate medical accuracy "], | |
"response_time": 0.0, | |
"timestamp": datetime.now().isoformat()[:19], | |
"cached": False, | |
"error": True | |
} | |