AxL95 committed on
Commit
13164ef
·
verified ·
1 Parent(s): ab279c4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -849
app.py CHANGED
@@ -1,878 +1,67 @@
1
- from fastapi import FastAPI, Request, HTTPException,Depends,File, UploadFile, Response
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
- from huggingface_hub import InferenceClient
6
- import secrets
7
- from typing import Optional
8
- from sentence_transformers import SentenceTransformer
9
- from bson.objectid import ObjectId
10
- from datetime import datetime, timedelta
11
- from fastapi import Request
12
- import requests
13
- import numpy as np
14
  import argparse
15
- import os
16
- from pymongo import MongoClient
17
- from datetime import datetime
18
- from passlib.hash import bcrypt
19
- import PyPDF2
20
- from io import BytesIO
21
- import uuid
22
 
23
- from langchain_community.embeddings import HuggingFaceEmbeddings
24
- from sklearn.metrics.pairwise import cosine_similarity
25
- import time
26
 
27
- from fastapi.responses import StreamingResponse
28
- import json
29
- import asyncio
30
 
31
-
32
- from langchain_community.document_loaders import PyPDFDirectoryLoader
33
- from langchain_text_splitters import RecursiveCharacterTextSplitter
34
- from langchain_community.embeddings import HuggingFaceEmbeddings
35
-
36
-
37
- SECRET_KEY = secrets.token_hex(32)
38
-
39
- HOST = os.environ.get("API_URL", "0.0.0.0")
40
- PORT = os.environ.get("PORT", 7860)
41
- parser = argparse.ArgumentParser()
42
- parser.add_argument("--host", default=HOST)
43
- parser.add_argument("--port", type=int, default=PORT)
44
- parser.add_argument("--reload", action="store_true", default=True)
45
- parser.add_argument("--ssl_certfile")
46
- parser.add_argument("--ssl_keyfile")
47
- args = parser.parse_args()
48
-
49
- # Configuration MongoDB
50
- mongo_uri = os.environ.get("MONGODB_URI", "mongodb+srv://giffardaxel95:[email protected]/")
51
- db_name = os.environ.get("DB_NAME", "chatmed_schizo")
52
- mongo_client = MongoClient(mongo_uri)
53
- db = mongo_client[db_name]
54
-
55
- SAVE_FOLDER = "files"
56
- COLLECTION_NAME="connaissances"
57
- os.makedirs(SAVE_FOLDER, exist_ok=True)
58
-
59
-
60
- app = FastAPI()
61
  app.add_middleware(
62
  CORSMiddleware,
63
-
64
- allow_origins=[
65
- "https://axl95-medically.hf.space",
66
- "https://huggingface.co",
67
- "http://localhost:3000",
68
- "http://localhost:7860",
69
- "http://0.0.0.0:7860"
70
- ],
71
  allow_credentials=True,
72
  allow_methods=["*"],
73
  allow_headers=["*"],
74
  )
75
 
76
- def download_pdf(url, save_path, retries=2, delay=3):
77
- for attempt in range(retries):
78
- try:
79
- req = Request(url, headers={'User-Agent': 'Mozilla/5.0'})
80
- with urlopen(req) as response, open(save_path, 'wb') as f:
81
- f.write(response.read())
82
- print(f"Téléchargé : {save_path}")
83
- return
84
- except (HTTPError, URLError) as e:
85
- print(f"Erreur ({e}) pour {url}, tentative {attempt+1}/{retries}")
86
- time.sleep(delay)
87
- print(f"Échec du téléchargement : {url}")
88
 
89
- '''
90
- Le chargement automatique des PDFs est désactivé. La base de données utilise les embeddings existants.
91
- for url in PDF_URLS:
92
- file_name = url.split("/")[-1]
93
- file_path = os.path.join(SAVE_FOLDER, file_name)
94
- if not os.path.exists(file_path):
95
- download_pdf(url, file_path)
96
-
97
- loader = PyPDFDirectoryLoader(SAVE_FOLDER)
98
- docs = loader.load()
99
-
100
- splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
101
- chunks = splitter.split_documents(docs)
102
- print(f"{len(chunks)} morceaux extraits.")
103
-
104
- embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
105
-
106
- client = MongoClient(MONGO_URI)
107
- collection = client[DB_NAME][COLLECTION_NAME]
108
 
109
- collection.delete_many({})
110
 
111
- for chunk in chunks:
112
- text = chunk.page_content
113
- embedding = embedding_model.embed_query(text)
114
- collection.insert_one({
115
- "text": text,
116
- "embedding": embedding
117
- })
118
 
119
- print("Tous les embeddings ont été insérés dans la base MongoDB.")
120
  '''
121
-
122
-
123
-
124
-
125
- def retrieve_relevant_context(query, embedding_model, mongo_collection, k=5):
126
- query_embedding = embedding_model.embed_query(query)
127
-
128
- docs = list(mongo_collection.find({}, {"text": 1, "embedding": 1}))
129
-
130
- print(f"[DEBUG] Recherche de contexte pour: '{query}'")
131
- print(f"[DEBUG] {len(docs)} documents trouvés dans la base de données")
132
-
133
- if not docs:
134
- print("[DEBUG] Aucun document dans la collection. RAG désactivé.")
135
- return ""
136
-
137
- # Calcul des similarités
138
- similarities = []
139
- for i, doc in enumerate(docs):
140
- if "embedding" not in doc or not doc["embedding"]:
141
- print(f"[DEBUG] Document {i} sans embedding")
142
- continue
143
-
144
- sim = cosine_similarity([query_embedding], [doc["embedding"]])[0][0]
145
- similarities.append((sim, i, doc["text"]))
146
-
147
- similarities.sort(reverse=True)
148
-
149
- # Afficher les top k documents avec leurs scores
150
- print("\n=== CONTEXTE SÉLECTIONNÉ ===")
151
- top_k_docs = []
152
- for i, (score, idx, text) in enumerate(similarities[:k]):
153
- doc_preview = text[:100] + "..." if len(text) > 100 else text
154
- print(f"Document #{i+1} (score: {score:.4f}): {doc_preview}")
155
- top_k_docs.append(text)
156
- print("==========================\n")
157
-
158
- return "\n\n".join(top_k_docs)
159
-
160
-
161
-
162
- async def get_admin_user(request: Request):
163
- user = await get_current_user(request)
164
- if user["role"] != "Administrateur":
165
- raise HTTPException(status_code=403, detail="Accès interdit: Droits d'administrateur requis")
166
- return user
167
-
168
-
169
- try:
170
- embedding_model = HuggingFaceEmbeddings(model_name="shtilev/medical_embedded_v2")
171
- print("✅ Modèle d'embedding médical chargé avec succès")
172
-
173
- except Exception as e:
174
- print(f"Erreur lors du chargement du modèle d'embedding: {str(e)}")
175
- embedding_model = None
176
-
177
- doc_count = db.connaissances.count_documents({})
178
- print(f"\n[DIAGNOSTIC] Collection 'connaissances': {doc_count} documents trouvés")
179
- if doc_count == 0:
180
- print("[AVERTISSEMENT] La collection est vide. Le système RAG ne fonctionnera pas!")
181
- print("[AVERTISSEMENT] Veuillez charger des documents via l'API admin ou exécuter le script d'initialisation.")
182
- else:
183
- sample_doc = db.connaissances.find_one({})
184
- has_embeddings = "embedding" in sample_doc and sample_doc["embedding"] is not None
185
- print(f"[DIAGNOSTIC] Les documents ont des embeddings: {'✅ Oui' if has_embeddings else '❌ Non'}")
186
- if not has_embeddings:
187
- print("[AVERTISSEMENT] Les documents n'ont pas d'embeddings valides!")
188
- @app.post("/api/admin/knowledge/upload")
189
- async def upload_pdf(
190
- file: UploadFile = File(...),
191
- title: str = None,
192
- tags: str = None,
193
- current_user: dict = Depends(get_admin_user)
194
- ):
195
- try:
196
- if not file.filename.endswith('.pdf'):
197
- raise HTTPException(status_code=400, detail="Le fichier doit être un PDF")
198
-
199
- contents = await file.read()
200
- pdf_file = BytesIO(contents)
201
-
202
- pdf_reader = PyPDF2.PdfReader(pdf_file)
203
- text_content = ""
204
- for page_num in range(len(pdf_reader.pages)):
205
- text_content += pdf_reader.pages[page_num].extract_text() + "\n"
206
-
207
- embedding = None
208
- if embedding_model:
209
- try:
210
- # Limiter la taille du texte si nécessaire
211
- max_length = 5000
212
- truncated_text = text_content[:max_length]
213
- embedding = embedding_model.embed_query(truncated_text)
214
- except Exception as e:
215
- print(f"Erreur lors de la génération de l'embedding: {str(e)}")
216
-
217
- doc_id = ObjectId()
218
-
219
- pdf_path = f"files/{str(doc_id)}.pdf"
220
- os.makedirs("files", exist_ok=True)
221
- with open(pdf_path, "wb") as f:
222
- pdf_file.seek(0)
223
- f.write(contents)
224
-
225
- document = {
226
- "_id": doc_id,
227
- "text": text_content,
228
- "embedding": embedding,
229
- "title": title or file.filename,
230
- "tags": tags.split(",") if tags else [],
231
- "uploaded_by": str(current_user["_id"]),
232
- "upload_date": datetime.utcnow()
233
- }
234
-
235
- print(f"Tentative d'insertion du document avec ID: {doc_id}")
236
- result = db.connaissances.insert_one(document)
237
- print(f"Document inséré avec ID: {result.inserted_id}")
238
-
239
- # Vérification de l'insertion
240
- verification = db.connaissances.find_one({"_id": doc_id})
241
- if verification:
242
- print(f"Document vérifié et trouvé dans la base de données")
243
- return {"success": True, "document_id": str(doc_id)}
244
- else:
245
- print(f"ERREUR: Document non trouvé après insertion")
246
- return {"success": False, "error": "Document non trouvé après insertion"}
247
-
248
- except Exception as e:
249
- import traceback
250
- print(f"Erreur lors de l'upload du PDF: {traceback.format_exc()}")
251
- raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}")
252
-
253
- @app.get("/api/admin/knowledge")
254
- async def list_documents(current_user: dict = Depends(get_admin_user)):
255
- try:
256
- documents = list(db.connaissances.find().sort("upload_date", -1))
257
-
258
- result = []
259
- for doc in documents:
260
- doc_safe = {
261
- "id": str(doc["_id"]),
262
- "title": doc.get("title", "Sans titre"),
263
- "tags": doc.get("tags", []),
264
- "date": doc.get("upload_date").isoformat() if "upload_date" in doc else None,
265
- "text_preview": doc.get("text", "")[:100] + "..." if len(doc.get("text", "")) > 100 else doc.get("text", "")
266
- }
267
- result.append(doc_safe)
268
-
269
- return {"documents": result}
270
- except Exception as e:
271
- print(f"Erreur lors de la liste des documents: {str(e)}")
272
- raise HTTPException(status_code=500, detail=f"Erreur: {str(e)}")
273
-
274
-
275
-
276
- @app.delete("/api/admin/knowledge/{document_id}")
277
- async def delete_document(document_id: str, current_user: dict = Depends(get_admin_user)):
278
- try:
279
- try:
280
- doc_id = ObjectId(document_id)
281
- except Exception:
282
- raise HTTPException(status_code=400, detail="ID de document invalide")
283
-
284
- # Vérifier si le document existe
285
- document = db.connaissances.find_one({"_id": doc_id})
286
- if not document:
287
- raise HTTPException(status_code=404, detail="Document non trouvé")
288
-
289
- # Supprimer le document de la base de données
290
- result = db.connaissances.delete_one({"_id": doc_id})
291
-
292
- if result.deleted_count == 0:
293
- raise HTTPException(status_code=500, detail="Échec de la suppression du document")
294
-
295
- # Supprimer le fichier PDF associé s'il existe
296
- pdf_path = f"files/{document_id}.pdf"
297
- if os.path.exists(pdf_path):
298
- try:
299
- os.remove(pdf_path)
300
- print(f"Fichier supprimé: {pdf_path}")
301
- except Exception as e:
302
- print(f"Erreur lors de la suppression du fichier: {str(e)}")
303
-
304
- return {"success": True, "message": "Document supprimé avec succès"}
305
-
306
- except HTTPException as he:
307
- raise he
308
- except Exception as e:
309
- raise HTTPException(status_code=500, detail=f"Erreur lors de la suppression: {str(e)}")
310
-
311
-
312
- @app.post("/api/login")
313
- async def login(request: Request, response: Response):
314
- try:
315
- data = await request.json()
316
- email = data.get("email")
317
- password = data.get("password")
318
-
319
- user = db.users.find_one({"email": email})
320
- if not user or not bcrypt.verify(password, user["password"]):
321
- raise HTTPException(status_code=401, detail="Email ou mot de passe incorrect")
322
-
323
- session_id = secrets.token_hex(16)
324
- user_id = str(user["_id"])
325
- username = f"{user['prenom']} {user['nom']}"
326
-
327
- db.sessions.insert_one({
328
- "session_id": session_id,
329
- "user_id": user_id,
330
- "created_at": datetime.utcnow(),
331
- "expires_at": datetime.utcnow() + timedelta(days=7)
332
- })
333
-
334
- response.set_cookie(
335
- key="session_id",
336
- value=session_id,
337
- httponly=False,
338
- max_age=7*24*60*60,
339
- samesite="none",
340
- secure=True,
341
- path="/"
342
- )
343
-
344
- # Log pour débogage
345
- print(f"Session créée: {session_id} pour l'utilisateur {user_id}")
346
-
347
- return {
348
- "success": True,
349
- "username": username,
350
- "user_id": user_id,
351
- "session_id": session_id,
352
- "role": user.get("role", "user")
353
-
354
- }
355
-
356
- except Exception as e:
357
- print(f"Erreur login: {str(e)}")
358
- raise HTTPException(status_code=500, detail=str(e))
359
-
360
-
361
- async def get_current_user(request: Request):
362
- session_id = request.cookies.get("session_id")
363
- print(f"Cookie de session reçu: {session_id[:5] if session_id else 'None'}")
364
-
365
- if not session_id:
366
- auth_header = request.headers.get("Authorization")
367
- if auth_header and auth_header.startswith("Bearer "):
368
- session_id = auth_header.replace("Bearer ", "")
369
- print(f"Session d'autorisation reçue: {session_id[:5]}...")
370
-
371
- if not session_id:
372
- session_id = request.query_params.get("session_id")
373
- if session_id:
374
- print(f"Session des paramètres de requête: {session_id[:5]}...")
375
-
376
- if not session_id:
377
- raise HTTPException(status_code=401, detail="Non authentifié - Aucune session trouvée")
378
-
379
- session = db.sessions.find_one({
380
- "session_id": session_id,
381
- "expires_at": {"$gt": datetime.utcnow()}
382
- })
383
-
384
- if not session:
385
- raise HTTPException(status_code=401, detail="Session expirée ou invalide")
386
-
387
- user = db.users.find_one({"_id": ObjectId(session["user_id"])})
388
- if not user:
389
- raise HTTPException(status_code=401, detail="Utilisateur non trouvé")
390
-
391
- return user
392
-
393
- @app.post("/api/logout")
394
- async def logout(request: Request, response: Response):
395
- session_id = request.cookies.get("session_id")
396
- if session_id:
397
- db.sessions.delete_one({"session_id": session_id})
398
-
399
- response.delete_cookie(key="session_id")
400
- return {"success": True}
401
- @app.post("/api/register")
402
- async def register(request: Request):
403
- try:
404
- data = await request.json()
405
-
406
- required_fields = ["prenom", "nom", "email", "password"]
407
- for field in required_fields:
408
- if not data.get(field):
409
- raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")
410
-
411
- existing_user = db.users.find_one({"email": data["email"]})
412
- if existing_user:
413
- raise HTTPException(status_code=409, detail="Cet email est déjà utilisé")
414
-
415
- hashed_password = bcrypt.hash(data["password"])
416
-
417
- user = {
418
- "prenom": data["prenom"],
419
- "nom": data["nom"],
420
- "email": data["email"],
421
- "password": hashed_password,
422
- "createdAt": datetime.utcnow(),
423
- "role": data.get("role", "user"),
424
-
425
- }
426
-
427
- result = db.users.insert_one(user)
428
-
429
- return {"message": "Utilisateur créé avec succès", "userId": str(result.inserted_id)}
430
-
431
- except HTTPException as he:
432
- raise he
433
-
434
- except Exception as e:
435
- import traceback
436
- print(f"Erreur lors de l'inscription: {str(e)}")
437
- print(traceback.format_exc())
438
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
439
- @app.post("/api/embed")
440
- async def embed(request: Request):
441
- data = await request.json()
442
- texts = data.get("texts", [])
443
-
444
- try:
445
-
446
- dummy_embedding = [[0.1, 0.2, 0.3] for _ in range(len(texts))]
447
-
448
- return {"embeddings": dummy_embedding}
449
- except Exception as e:
450
- return {"error": str(e)}
451
-
452
- @app.get("/invert")
453
- async def invert(text: str):
454
  return {
455
- "original": text,
456
- "inverted": text[::-1],
 
 
 
 
 
 
 
 
 
 
457
  }
458
-
459
- HF_TOKEN = os.getenv('REACT_APP_HF_TOKEN')
460
- if not HF_TOKEN:
461
- raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
462
- conversation_history = {}
463
- hf_client = InferenceClient(token=HF_TOKEN)
464
- @app.post("/api/chat")
465
- async def chat(request: Request):
466
- global conversation_history
467
-
468
- # ① Lecture du JSON et extraction des champs
469
- data = await request.json()
470
- user_message = data.get("message", "").strip()
471
- conversation_id = data.get("conversation_id")
472
-
473
- if not user_message:
474
- raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
475
-
476
- current_user = None
477
- try:
478
- current_user = await get_current_user(request)
479
- except HTTPException:
480
- pass
481
-
482
- current_tokens = 0
483
- message_tokens = 0
484
- if current_user and conversation_id:
485
- conv = db.conversations.find_one({
486
- "_id": ObjectId(conversation_id),
487
- "user_id": str(current_user["_id"])
488
- })
489
- if conv:
490
- current_tokens = conv.get("token_count", 0)
491
- message_tokens = int(len(user_message.split()) * 1.3)
492
- MAX_TOKENS = 2000
493
- if current_tokens + message_tokens > MAX_TOKENS:
494
- return JSONResponse({
495
- "error": "token_limit_exceeded",
496
- "message": "Cette conversation a atteint sa limite de taille. Veuillez en créer une nouvelle.",
497
- "tokens_used": current_tokens,
498
- "tokens_limit": MAX_TOKENS
499
- }, status_code=403)
500
-
501
- if conversation_id and current_user:
502
- db.messages.insert_one({
503
- "conversation_id": conversation_id,
504
- "user_id": str(current_user["_id"]),
505
- "sender": "user",
506
- "text": user_message,
507
- "timestamp": datetime.utcnow()
508
- })
509
-
510
- is_history_question = any(
511
- phrase in user_message.lower()
512
- for phrase in [
513
- "ma première question", "ma précédente question", "ma dernière question",
514
- "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
515
- "c'était quoi ma", "quelle était ma", "mes questions"
516
- ]
517
- )
518
-
519
- if conversation_id not in conversation_history:
520
- conversation_history[conversation_id] = []
521
- # If there's existing conversation in DB, load it to memory
522
- if current_user and conversation_id:
523
- previous_messages = list(db.messages.find(
524
- {"conversation_id": conversation_id}
525
- ).sort("timestamp", 1))
526
-
527
- for msg in previous_messages:
528
- if msg["sender"] == "user":
529
- conversation_history[conversation_id].append(f"Question : {msg['text']}")
530
- else:
531
- conversation_history[conversation_id].append(f"Réponse : {msg['text']}")
532
-
533
- if is_history_question:
534
- actual_questions = []
535
-
536
- if conversation_id in conversation_history:
537
- for msg in conversation_history[conversation_id]:
538
- if msg.startswith("Question : "):
539
- q_text = msg.replace("Question : ", "")
540
- # Ignorer les méta-questions qui parlent déjà de l'historique
541
- is_meta = any(phrase in q_text.lower() for phrase in [
542
- "ma première question", "ma précédente question", "ma dernière question",
543
- "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
544
- "c'était quoi ma", "quelle était ma", "mes questions"
545
- ])
546
- if not is_meta:
547
- actual_questions.append(q_text)
548
-
549
- if not actual_questions:
550
- return JSONResponse({
551
- "response": "Vous n'avez pas encore posé de question dans cette conversation. C'est notre premier échange."
552
- })
553
-
554
- question_number = None
555
-
556
- if any(p in user_message.lower() for p in ["première question", "1ère question", "1ere question"]):
557
- question_number = 1
558
- elif any(p in user_message.lower() for p in ["deuxième question", "2ème question", "2eme question", "seconde question"]):
559
- question_number = 2
560
- else:
561
- import re
562
- match = re.search(r'(\d+)[eèiéê]*m*e* question', user_message.lower())
563
- if match:
564
- try:
565
- question_number = int(match.group(1))
566
- except:
567
- pass
568
-
569
- if question_number is not None:
570
- if 0 < question_number <= len(actual_questions):
571
- suffix = "ère" if question_number == 1 else "ème"
572
- return JSONResponse({
573
- "response": f"Votre {question_number}{suffix} question était : \"{actual_questions[question_number-1]}\""
574
- })
575
- else:
576
- return JSONResponse({
577
- "response": f"Vous n'avez pas encore posé {question_number} questions dans cette conversation."
578
- })
579
-
580
- else:
581
- if len(actual_questions) == 1:
582
- return JSONResponse({
583
- "response": f"Vous avez posé une seule question jusqu'à présent : \"{actual_questions[0]}\""
584
- })
585
- else:
586
- question_list = "\n".join([f"{i+1}. {q}" for i, q in enumerate(actual_questions)])
587
- return JSONResponse({
588
- "response": f"Voici les questions que vous avez posées dans cette conversation :\n\n{question_list}"
589
- })
590
-
591
- context = None
592
- if not is_history_question and embedding_model:
593
- context = retrieve_relevant_context(user_message, embedding_model, db.connaissances, k=5)
594
- if context and conversation_id:
595
- conversation_history[conversation_id].append(f"Contexte : {context}")
596
-
597
- if conversation_id:
598
- conversation_history[conversation_id].append(f"Question : {user_message}")
599
-
600
- system_prompt = (
601
- "Tu es un chatbot spécialisé dans la santé mentale, et plus particulièrement la schizophrénie. "
602
- "Tu réponds de façon fiable, claire et empathique, en t'appuyant uniquement sur des sources médicales et en français. "
603
- )
604
-
605
- enriched_context = ""
606
-
607
- if conversation_id in conversation_history:
608
- actual_questions = []
609
- for msg in conversation_history[conversation_id]:
610
- if msg.startswith("Question : "):
611
- q_text = msg.replace("Question : ", "")
612
- # Ignorer les méta-questions
613
- is_meta = any(phrase in q_text.lower() for phrase in [
614
- "ma première question", "ma précédente question", "ma dernière question",
615
- "ce que j'ai demandé", "j'ai dit quoi", "quelles questions",
616
- "c'était quoi ma", "quelle était ma", "mes questions"
617
- ])
618
- if not is_meta and q_text != user_message:
619
- actual_questions.append(q_text)
620
-
621
- if actual_questions:
622
- recent_questions = actual_questions[-5:] # 3 dernières questions
623
- enriched_context += "Historique récent des questions:\n"
624
- for i, q in enumerate(recent_questions):
625
- enriched_context += f"- Question précédente {len(recent_questions)-i}: {q}\n"
626
- enriched_context += "\n"
627
-
628
- if context:
629
- enriched_context += "Contexte médical pertinent:\n"
630
- enriched_context += context
631
- enriched_context += "\n\n"
632
-
633
- if enriched_context:
634
- system_prompt += (
635
- f"\n\n{enriched_context}\n\n"
636
- "Utilise ces informations pour répondre de manière plus précise et contextuelle. "
637
- "Ne pas inventer d'informations. Si tu ne sais pas, redirige vers un professionnel de santé."
638
- )
639
- else:
640
- system_prompt += (
641
- "Tu dois répondre uniquement à partir de connaissances médicales factuelles. "
642
- "Si tu ne sais pas répondre, indique-le clairement et suggère de consulter un professionnel de santé."
643
- )
644
-
645
- messages = [{"role": "system", "content": system_prompt}]
646
-
647
- if conversation_id and len(conversation_history.get(conversation_id, [])) > 0:
648
- history = conversation_history[conversation_id]
649
- for i in range(0, min(20, len(history)-1), 2):
650
- if i+1 < len(history):
651
- if history[i].startswith("Question :"):
652
- user_text = history[i].replace("Question : ", "")
653
- messages.append({"role": "user", "content": user_text})
654
-
655
- if history[i+1].startswith("Réponse :"):
656
- assistant_text = history[i+1].replace("Réponse : ", "")
657
- messages.append({"role": "assistant", "content": assistant_text})
658
-
659
- messages.append({"role": "user", "content": user_message})
660
-
661
- try:
662
- completion = hf_client.chat.completions.create(
663
- model="mistralai/Mistral-7B-Instruct-v0.3",
664
- messages=messages,
665
- max_tokens=400,
666
- temperature=0.7,
667
- timeout=15,
668
- )
669
- bot_response = completion.choices[0].message["content"].strip()
670
- except Exception:
671
- fallback = hf_client.text_generation(
672
- model="mistralai/Mistral-7B-Instruct-v0.3",
673
- prompt=f"<s>[INST] {system_prompt}\n\nQuestion: {user_message} [/INST]",
674
- max_new_tokens=512,
675
- temperature=0.7
676
- )
677
- bot_response = fallback
678
-
679
- if conversation_id:
680
- conversation_history[conversation_id].append(f"Réponse : {bot_response}")
681
-
682
- if len(conversation_history[conversation_id]) > 50: # 25 exchanges
683
- conversation_history[conversation_id] = conversation_history[conversation_id][-50:]
684
-
685
- if conversation_id and current_user:
686
- db.messages.insert_one({
687
- "conversation_id": conversation_id,
688
- "user_id": str(current_user["_id"]),
689
- "sender": "assistant",
690
- "text": bot_response,
691
- "timestamp": datetime.utcnow()
692
- })
693
- response_tokens = int(len(bot_response.split()) * 1.3)
694
- total_tokens = current_tokens + message_tokens + response_tokens
695
- db.conversations.update_one(
696
- {"_id": ObjectId(conversation_id)},
697
- {"$set": {
698
- "last_message": bot_response,
699
- "updated_at": datetime.utcnow(),
700
- "token_count": total_tokens
701
- }}
702
- )
703
-
704
- return {"response": bot_response}
705
-
706
-
707
- def simulate_token_count(text):
708
- """
709
- Simule le comptage de tokens sans appeler d'API externe.
710
- """
711
- if not text:
712
- return 0
713
-
714
- text = text.replace('\n', ' \n ')
715
-
716
- spaces_and_punct = sum(1 for c in text if c.isspace() or c in ',.;:!?()[]{}"\'`-_=+<>/@#$%^&*|\\')
717
-
718
- digits = sum(1 for c in text if c.isdigit())
719
-
720
- words = text.split()
721
- short_words = sum(1 for w in words if len(w) <= 2)
722
-
723
- # Les URLs et codes consomment plus de tokens
724
- code_blocks = len(re.findall(r'```[\s\S]*?```', text))
725
- urls = len(re.findall(r'https?://\S+', text))
726
-
727
- adjusted_length = len(text) - spaces_and_punct - digits - short_words
728
-
729
- token_count = (
730
- adjusted_length / 4 +
731
- spaces_and_punct * 0.25 +
732
- digits * 0.5 +
733
- short_words * 0.5 +
734
- code_blocks * 5 +
735
- urls * 4
736
- )
737
-
738
- return int(token_count * 1.1) + 1
739
- @app.get("/data")
740
- async def get_data():
741
- data = {"data": np.random.rand(100).tolist()}
742
- return JSONResponse(data)
743
-
744
- @app.get("/api/conversations")
745
- async def get_conversations(current_user: dict = Depends(get_current_user)):
746
- try:
747
- user_id = str(current_user["_id"])
748
- conversations = list(db.conversations.find(
749
- {"user_id": user_id},
750
- {"_id": 1, "title": 1, "date": 1, "time": 1, "last_message": 1, "created_at": 1}
751
- ).sort("created_at", -1))
752
-
753
- for conv in conversations:
754
- conv["_id"] = str(conv["_id"])
755
-
756
- return {"conversations": conversations}
757
- except Exception as e:
758
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
759
-
760
- @app.post("/api/conversations")
761
- async def create_conversation(request: Request, current_user: dict = Depends(get_current_user)):
762
- try:
763
- data = await request.json()
764
- user_id = str(current_user["_id"])
765
-
766
- conversation = {
767
- "user_id": user_id,
768
- "title": data.get("title", "Nouvelle conversation"),
769
- "date": data.get("date"),
770
- "time": data.get("time"),
771
- "last_message": data.get("message", ""),
772
- "created_at": datetime.utcnow()
773
- }
774
-
775
- result = db.conversations.insert_one(conversation)
776
-
777
- return {"conversation_id": str(result.inserted_id)}
778
- except Exception as e:
779
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
780
-
781
- @app.post("/api/conversations/{conversation_id}/messages")
782
- async def add_message(conversation_id: str, request: Request, current_user: dict = Depends(get_current_user)):
783
- try:
784
- data = await request.json()
785
- user_id = str(current_user["_id"])
786
-
787
- print(f"Ajout message: conversation_id={conversation_id}, sender={data.get('sender')}, text={data.get('text')[:20]}...")
788
-
789
- conversation = db.conversations.find_one({
790
- "_id": ObjectId(conversation_id),
791
- "user_id": user_id
792
- })
793
-
794
- if not conversation:
795
- raise HTTPException(status_code=404, detail="Conversation non trouvée")
796
-
797
- message = {
798
- "conversation_id": conversation_id,
799
- "user_id": user_id,
800
- "sender": data.get("sender", "user"),
801
- "text": data.get("text", ""),
802
- "timestamp": datetime.utcnow()
803
- }
804
-
805
- db.messages.insert_one(message)
806
-
807
- db.conversations.update_one(
808
- {"_id": ObjectId(conversation_id)},
809
- {"$set": {"last_message": data.get("text", ""), "updated_at": datetime.utcnow()}}
810
- )
811
-
812
- return {"success": True}
813
- except Exception as e:
814
- print(f"Erreur lors de l'ajout d'un message: {str(e)}")
815
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
816
-
817
- @app.get("/api/conversations/{conversation_id}/messages")
818
- async def get_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
819
- try:
820
- user_id = str(current_user["_id"])
821
-
822
- conversation = db.conversations.find_one({
823
- "_id": ObjectId(conversation_id),
824
- "user_id": user_id
825
- })
826
-
827
- if not conversation:
828
- raise HTTPException(status_code=404, detail="Conversation non trouvée")
829
-
830
- messages = list(db.messages.find(
831
- {"conversation_id": conversation_id}
832
- ).sort("timestamp", 1))
833
-
834
- for msg in messages:
835
- msg["_id"] = str(msg["_id"])
836
- if "timestamp" in msg:
837
- msg["timestamp"] = msg["timestamp"].isoformat()
838
-
839
- return {"messages": messages}
840
- except Exception as e:
841
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
842
-
843
- @app.delete("/api/conversations/{conversation_id}")
844
- async def delete_conversation(conversation_id: str, current_user: dict = Depends(get_current_user)):
845
- try:
846
- user_id = str(current_user["_id"])
847
-
848
- result = db.conversations.delete_one({
849
- "_id": ObjectId(conversation_id),
850
- "user_id": user_id
851
- })
852
-
853
- if result.deleted_count == 0:
854
- raise HTTPException(status_code=404, detail="Conversation non trouvée")
855
-
856
- db.messages.delete_many({"conversation_id": conversation_id})
857
-
858
- return {"success": True}
859
- except Exception as e:
860
- raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
861
-
862
-
863
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
864
-
865
  if __name__ == "__main__":
866
- import uvicorn
 
 
 
 
 
 
867
 
868
- print(args)
869
  uvicorn.run(
870
  "app:app",
871
  host=args.host,
872
  port=args.port,
873
  reload=args.reload,
874
-
875
  ssl_certfile=args.ssl_certfile,
876
  ssl_keyfile=args.ssl_keyfile,
877
- )
878
-
 
1
+ import config
2
+ from fastapi import FastAPI
3
  from fastapi.middleware.cors import CORSMiddleware
 
4
  from fastapi.staticfiles import StaticFiles
5
+ import uvicorn
 
 
 
 
 
 
 
 
6
  import argparse
 
 
 
 
 
 
 
7
 
8
+ from database import init_mongodb
9
+ import auth, chat, conversations, admin
 
10
 
11
+ app = FastAPI(title="Medic.ial", description="Assistant IA spécialisé sur la maladie de la schizophrénie")
 
 
12
 
13
+ # Configuration CORS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
+ allow_origins=config.CORS_ORIGINS,
 
 
 
 
 
 
 
17
  allow_credentials=True,
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
+ init_mongodb()
 
 
 
 
 
 
 
 
 
 
 
23
 
24
+ app.include_router(auth.router)
25
+ app.include_router(chat.router)
26
+ app.include_router(conversations.router)
27
+ app.include_router(admin.router)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
29
 
 
 
 
 
 
 
 
30
 
31
+ app.mount("/", StaticFiles(directory="static", html=True), name="static")
32
  '''
33
+ @app.get("/")
34
+ async def root():
35
+ """Page d'accueil de l'API Medic.ial."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  return {
37
+ "app_name": "Medic.ial - Assistant IA sur la schizophrénie",
38
+ "version": "1.0.0",
39
+ "api_endpoints": [
40
+ {"path": "/api/login", "method": "POST", "description": "Connexion utilisateur"},
41
+ {"path": "/api/register", "method": "POST", "description": "Création d'un compte"},
42
+ {"path": "/api/chat", "method": "POST", "description": "Poser une question à l'assistant"},
43
+ {"path": "/api/conversations", "method": "GET", "description": "Liste des conversations"},
44
+ {"path": "/api/conversations/{id}/messages", "method": "GET", "description": "Messages d'une conversation"}
45
+ ],
46
+ "documentation": "/docs",
47
+ "status": "En ligne",
48
+ "environment": "Développement"
49
  }
50
+ '''
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  if __name__ == "__main__":
52
+ parser = argparse.ArgumentParser()
53
+ parser.add_argument("--host", default=config.HOST)
54
+ parser.add_argument("--port", type=int, default=config.PORT)
55
+ parser.add_argument("--reload", action="store_true", default=True)
56
+ parser.add_argument("--ssl_certfile")
57
+ parser.add_argument("--ssl_keyfile")
58
+ args = parser.parse_args()
59
 
 
60
  uvicorn.run(
61
  "app:app",
62
  host=args.host,
63
  port=args.port,
64
  reload=args.reload,
 
65
  ssl_certfile=args.ssl_certfile,
66
  ssl_keyfile=args.ssl_keyfile,
67
+ )