AxL95 commited on
Commit
7e3cbc2
·
verified ·
1 Parent(s): d80657c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +302 -28
app.py CHANGED
@@ -1,15 +1,22 @@
1
- from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from huggingface_hub import InferenceClient
6
- from sentence_transformers import SentenceTransformer
7
-
 
 
8
  from fastapi import Request
9
  import requests
10
  import numpy as np
11
  import argparse
12
  import os
 
 
 
 
 
13
 
14
  HOST = os.environ.get("API_URL", "0.0.0.0")
15
  PORT = os.environ.get("PORT", 7860)
@@ -21,6 +28,14 @@ parser.add_argument("--ssl_certfile")
21
  parser.add_argument("--ssl_keyfile")
22
  args = parser.parse_args()
23
 
 
 
 
 
 
 
 
 
24
  app = FastAPI()
25
  app.add_middleware(
26
  CORSMiddleware,
@@ -31,15 +46,137 @@ app.add_middleware(
31
  )
32
 
33
 
34
- app = FastAPI()
35
- embedder = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1')
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  @app.post("/api/embed")
38
  async def embed(request: Request):
39
  data = await request.json()
40
  texts = data.get("texts", [])
41
- embeddings = embedder.encode(texts)
42
- return {"embeddings": embeddings.tolist()}
 
 
 
 
 
 
 
 
 
43
 
44
  @app.get("/invert")
45
  async def invert(text: str):
@@ -48,41 +185,39 @@ async def invert(text: str):
48
  "inverted": text[::-1],
49
  }
50
 
51
- HF_TOKEN = os.getenv("REACT_APP_HF_TOKEN")
52
  if not HF_TOKEN:
53
  raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
54
 
55
  # Initialisation du client HF
56
- hf_client = InferenceClient(
57
- provider="novita",
58
- api_key=HF_TOKEN,
59
- )
60
 
61
  @app.post("/api/chat")
62
  async def chat(request: Request):
63
  data = await request.json()
64
  user_message = data.get("message", "").strip()
65
  if not user_message:
 
66
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
67
 
68
  try:
69
- # Appel au modèle en mode chat
70
- completion = hf_client.chat.completions.create(
71
  model="mistralai/Mistral-7B-Instruct-v0.3",
72
- messages=[
73
- {"role": "system", "content": "Tu es un assistant médical spécialisé en schizophrénie."},
74
- {"role": "user", "content": user_message}
75
- ],
76
- max_tokens=512,
77
- temperature=0.7,
78
  )
79
-
80
- bot_msg = completion.choices[0].message.content
81
- return {"response": bot_msg}
82
-
83
  except Exception as e:
84
- # En cas d'erreur d'inférence
85
- raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {e}")
 
 
86
 
87
 
88
  @app.get("/data")
@@ -91,8 +226,11 @@ async def get_data():
91
  return JSONResponse(data)
92
 
93
 
94
- app.mount("/", StaticFiles(directory="static", html=True), name="static")
95
 
 
 
 
 
96
  if __name__ == "__main__":
97
  import uvicorn
98
 
@@ -104,4 +242,140 @@ if __name__ == "__main__":
104
  reload=args.reload,
105
  ssl_certfile=args.ssl_certfile,
106
  ssl_keyfile=args.ssl_keyfile,
107
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, HTTPException,Depends, Response
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from huggingface_hub import InferenceClient
6
+ import secrets
7
+ from typing import Optional
8
+ from bson.objectid import ObjectId # Pour les _id MongoDB
9
+ from datetime import datetime, timedelta # Ajout de timedelta pour expires_at
10
  from fastapi import Request
11
  import requests
12
  import numpy as np
13
  import argparse
14
  import os
15
+ from pymongo import MongoClient
16
+ from datetime import datetime # Ajout pour createdAt
17
+ from passlib.hash import bcrypt # Ajout pour le hash du mot de passe
18
+
19
+ SECRET_KEY = secrets.token_hex(32) # Génère une clé aléatoire
20
 
21
  HOST = os.environ.get("API_URL", "0.0.0.0")
22
  PORT = os.environ.get("PORT", 7860)
 
28
  parser.add_argument("--ssl_keyfile")
29
  args = parser.parse_args()
30
 
31
+ # Configuration MongoDB
32
+ mongo_uri = os.environ.get("MONGODB_URI", "mongodb+srv://giffardaxel95:[email protected]/")
33
+ db_name = os.environ.get("DB_NAME", "chatmed_schizo")
34
+ mongo_client = MongoClient(mongo_uri)
35
+ db = mongo_client[db_name]
36
+
37
+
38
+
39
  app = FastAPI()
40
  app.add_middleware(
41
  CORSMiddleware,
 
46
  )
47
 
48
 
 
 
49
 
50
+ # Fonction pour créer une session
51
+ @app.post("/api/login")
52
+ async def login(request: Request, response: Response):
53
+ try:
54
+ data = await request.json()
55
+ email = data.get("email")
56
+ password = data.get("password")
57
+
58
+ # Vérifier les identifiants
59
+ user = db.users.find_one({"email": email})
60
+ if not user or not bcrypt.verify(password, user["password"]):
61
+ raise HTTPException(status_code=401, detail="Email ou mot de passe incorrect")
62
+
63
+ # Créer une session
64
+ session_id = secrets.token_hex(16)
65
+ user_id = str(user["_id"])
66
+ username = f"{user['prenom']} {user['nom']}"
67
+
68
+ # Stocker la session en base de données
69
+ db.sessions.insert_one({
70
+ "session_id": session_id,
71
+ "user_id": user_id,
72
+ "created_at": datetime.utcnow(),
73
+ "expires_at": datetime.utcnow() + timedelta(days=7)
74
+ })
75
+
76
+ # Définir le cookie de session
77
+ response.set_cookie(
78
+ key="session_id",
79
+ value=session_id,
80
+ httponly=True,
81
+ max_age=7*24*60*60, # 7 jours
82
+ samesite="lax"
83
+ )
84
+
85
+ return {"success": True, "username": username, "user_id": user_id}
86
+
87
+ except HTTPException as he:
88
+ raise he
89
+ except Exception as e:
90
+ raise HTTPException(status_code=500, detail=str(e))
91
+
92
+
93
+ async def get_current_user(request: Request):
94
+ session_id = request.cookies.get("session_id")
95
+ if not session_id:
96
+ raise HTTPException(status_code=401, detail="Non authentifié")
97
+
98
+ # Vérifier si la session existe et n'est pas expirée
99
+ session = db.sessions.find_one({
100
+ "session_id": session_id,
101
+ "expires_at": {"$gt": datetime.utcnow()}
102
+ })
103
+
104
+ if not session:
105
+ raise HTTPException(status_code=401, detail="Session expirée ou invalide")
106
+
107
+ user = db.users.find_one({"_id": ObjectId(session["user_id"])})
108
+ if not user:
109
+ raise HTTPException(status_code=401, detail="Utilisateur non trouvé")
110
+
111
+ return user
112
+
113
+ # Endpoint pour déconnexion
114
+ @app.post("/api/logout")
115
+ async def logout(request: Request, response: Response):
116
+ session_id = request.cookies.get("session_id")
117
+ if session_id:
118
+ db.sessions.delete_one({"session_id": session_id})
119
+
120
+ response.delete_cookie(key="session_id")
121
+ return {"success": True}
122
+ @app.post("/api/register")
123
+ async def register(request: Request):
124
+ try:
125
+ data = await request.json()
126
+
127
+ # Validation
128
+ required_fields = ["prenom", "nom", "email", "password"]
129
+ for field in required_fields:
130
+ if not data.get(field):
131
+ raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")
132
+
133
+ # Vérifier si l'email existe déjà
134
+ existing_user = db.users.find_one({"email": data["email"]})
135
+ if existing_user:
136
+ raise HTTPException(status_code=409, detail="Cet email est déjà utilisé")
137
+
138
+ # Hash du mot de passe (installer passlib avec pip install passlib[bcrypt])
139
+ hashed_password = bcrypt.hash(data["password"])
140
+
141
+ # Insérer l'utilisateur
142
+ user = {
143
+ "prenom": data["prenom"],
144
+ "nom": data["nom"],
145
+ "email": data["email"],
146
+ "password": hashed_password,
147
+ "createdAt": datetime.utcnow()
148
+ }
149
+
150
+ result = db.users.insert_one(user)
151
+
152
+ return {"message": "Utilisateur créé avec succès", "userId": str(result.inserted_id)}
153
+
154
+ except HTTPException as he:
155
+ # Re-lever les HTTPException pour FastAPI les traite
156
+ raise he
157
+
158
+ except Exception as e:
159
+ # Logging détaillé pour débogage
160
+ import traceback
161
+ print(f"Erreur lors de l'inscription: {str(e)}")
162
+ print(traceback.format_exc())
163
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
164
+ # Par cette solution utilisant l'API d'inférence directement:
165
  @app.post("/api/embed")
166
  async def embed(request: Request):
167
  data = await request.json()
168
  texts = data.get("texts", [])
169
+
170
+ try:
171
+ # Si vous utilisez un embedder personnalisé
172
+ # embeddings = embedder.encode(texts).tolist()
173
+
174
+ # Pour déboguer, renvoyez simplement un embedding fictif
175
+ dummy_embedding = [[0.1, 0.2, 0.3] for _ in range(len(texts))]
176
+
177
+ return {"embeddings": dummy_embedding}
178
+ except Exception as e:
179
+ return {"error": str(e)}
180
 
181
  @app.get("/invert")
182
  async def invert(text: str):
 
185
  "inverted": text[::-1],
186
  }
187
 
188
+ HF_TOKEN = os.getenv('REACT_APP_HF_TOKEN')
189
  if not HF_TOKEN:
190
  raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
191
 
192
  # Initialisation du client HF
193
+
194
+ # Par cette version correcte
195
+ hf_client = InferenceClient(token=HF_TOKEN)
 
196
 
197
  @app.post("/api/chat")
198
  async def chat(request: Request):
199
  data = await request.json()
200
  user_message = data.get("message", "").strip()
201
  if not user_message:
202
+ from fastapi import HTTPException
203
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
204
 
205
  try:
206
+ # Utiliser l'API HF standard sans provider
207
+ response = hf_client.text_generation(
208
  model="mistralai/Mistral-7B-Instruct-v0.3",
209
+ prompt=f"<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. Réponds à cette question: {user_message} [/INST]",
210
+ max_new_tokens=512,
211
+ temperature=0.7
 
 
 
212
  )
213
+
214
+ return {"response": response}
215
+
 
216
  except Exception as e:
217
+ from fastapi import HTTPException
218
+ import traceback
219
+ print(f"Erreur détaillée: {traceback.format_exc()}") # Log détaillé de l'erreur
220
+ raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {str(e)}")
221
 
222
 
223
  @app.get("/data")
 
226
  return JSONResponse(data)
227
 
228
 
 
229
 
230
+ # Ou ajoutez un index minimal
231
+ @app.get("/")
232
+ def read_root():
233
+ return {"message": "API Medically fonctionnelle", "endpoints": ["/api/chat", "/invert", "/data"]}
234
  if __name__ == "__main__":
235
  import uvicorn
236
 
 
242
  reload=args.reload,
243
  ssl_certfile=args.ssl_certfile,
244
  ssl_keyfile=args.ssl_keyfile,
245
+ )
246
+
247
+
248
+ # Endpoint pour récupérer toutes les conversations d'un utilisateur
249
+ @app.get("/api/conversations")
250
+ async def get_conversations(current_user: dict = Depends(get_current_user)):
251
+ try:
252
+ user_id = str(current_user["_id"])
253
+ # Récupération des conversations triées par date (les plus récentes d'abord)
254
+ conversations = list(db.conversations.find(
255
+ {"user_id": user_id},
256
+ {"_id": 1, "title": 1, "date": 1, "time": 1, "last_message": 1, "created_at": 1}
257
+ ).sort("created_at", -1))
258
+
259
+ # Convertir les ObjectId en strings pour la sérialisation JSON
260
+ for conv in conversations:
261
+ conv["_id"] = str(conv["_id"])
262
+
263
+ return {"conversations": conversations}
264
+ except Exception as e:
265
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
266
+
267
+ # Endpoint pour sauvegarder une nouvelle conversation
268
+ @app.post("/api/conversations")
269
+ async def create_conversation(request: Request, current_user: dict = Depends(get_current_user)):
270
+ try:
271
+ data = await request.json()
272
+ user_id = str(current_user["_id"])
273
+
274
+ # Créer la nouvelle conversation
275
+ conversation = {
276
+ "user_id": user_id,
277
+ "title": data.get("title", "Nouvelle conversation"),
278
+ "date": data.get("date"),
279
+ "time": data.get("time"),
280
+ "last_message": data.get("message", ""),
281
+ "created_at": datetime.utcnow()
282
+ }
283
+
284
+ result = db.conversations.insert_one(conversation)
285
+
286
+ # Retourner l'ID de la conversation créée
287
+ return {"conversation_id": str(result.inserted_id)}
288
+ except Exception as e:
289
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
290
+
291
+ # Endpoint pour sauvegarder un message dans une conversation
292
+ @app.post("/api/conversations/{conversation_id}/messages")
293
+ async def add_message(conversation_id: str, request: Request, current_user: dict = Depends(get_current_user)):
294
+ try:
295
+ data = await request.json()
296
+ user_id = str(current_user["_id"])
297
+
298
+ # Debug pour vérifier les données
299
+ print(f"Ajout message: conversation_id={conversation_id}, sender={data.get('sender')}, text={data.get('text')[:20]}...")
300
+
301
+ # Vérifier que la conversation appartient à l'utilisateur
302
+ conversation = db.conversations.find_one({
303
+ "_id": ObjectId(conversation_id),
304
+ "user_id": user_id
305
+ })
306
+
307
+ if not conversation:
308
+ raise HTTPException(status_code=404, detail="Conversation non trouvée")
309
+
310
+ # Ajouter le message
311
+ message = {
312
+ "conversation_id": conversation_id,
313
+ "user_id": user_id,
314
+ "sender": data.get("sender", "user"),
315
+ "text": data.get("text", ""),
316
+ "timestamp": datetime.utcnow()
317
+ }
318
+
319
+ db.messages.insert_one(message)
320
+
321
+ # Mettre à jour la dernière activité de la conversation
322
+ db.conversations.update_one(
323
+ {"_id": ObjectId(conversation_id)},
324
+ {"$set": {"last_message": data.get("text", ""), "updated_at": datetime.utcnow()}}
325
+ )
326
+
327
+ return {"success": True}
328
+ except Exception as e:
329
+ print(f"Erreur lors de l'ajout d'un message: {str(e)}")
330
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
331
+
332
+ # Endpoint pour récupérer les messages d'une conversation
333
+ @app.get("/api/conversations/{conversation_id}/messages")
334
+ async def get_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
335
+ try:
336
+ user_id = str(current_user["_id"])
337
+
338
+ # Vérifier que la conversation appartient à l'utilisateur
339
+ conversation = db.conversations.find_one({
340
+ "_id": ObjectId(conversation_id),
341
+ "user_id": user_id
342
+ })
343
+
344
+ if not conversation:
345
+ raise HTTPException(status_code=404, detail="Conversation non trouvée")
346
+
347
+ # Récupérer les messages
348
+ messages = list(db.messages.find(
349
+ {"conversation_id": conversation_id}
350
+ ).sort("timestamp", 1)) # Du plus ancien au plus récent
351
+
352
+ # Convertir les ObjectId en strings pour la sérialisation JSON
353
+ for msg in messages:
354
+ msg["_id"] = str(msg["_id"])
355
+ if "timestamp" in msg:
356
+ msg["timestamp"] = msg["timestamp"].isoformat()
357
+
358
+ return {"messages": messages}
359
+ except Exception as e:
360
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
361
+
362
+ @app.delete("/api/conversations/{conversation_id}")
363
+ async def delete_conversation(conversation_id: str, current_user: dict = Depends(get_current_user)):
364
+ try:
365
+ user_id = str(current_user["_id"])
366
+
367
+ # Vérifier que la conversation appartient à l'utilisateur
368
+ result = db.conversations.delete_one({
369
+ "_id": ObjectId(conversation_id),
370
+ "user_id": user_id
371
+ })
372
+
373
+ if result.deleted_count == 0:
374
+ raise HTTPException(status_code=404, detail="Conversation non trouvée")
375
+
376
+ # Supprimer également tous les messages associés
377
+ db.messages.delete_many({"conversation_id": conversation_id})
378
+
379
+ return {"success": True}
380
+ except Exception as e:
381
+ raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")