AxL95 commited on
Commit
b01bbf5
·
verified ·
1 Parent(s): 42580a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -54
app.py CHANGED
@@ -5,18 +5,18 @@ from fastapi.staticfiles import StaticFiles
5
  from huggingface_hub import InferenceClient
6
  import secrets
7
  from typing import Optional
8
- from bson.objectid import ObjectId # Pour les _id MongoDB
9
- from datetime import datetime, timedelta # Ajout de timedelta pour expires_at
10
  from fastapi import Request
11
  import requests
12
  import numpy as np
13
  import argparse
14
  import os
15
  from pymongo import MongoClient
16
- from datetime import datetime # Ajout pour createdAt
17
- from passlib.hash import bcrypt # Ajout pour le hash du mot de passe
18
 
19
- SECRET_KEY = secrets.token_hex(32) # Génère une clé aléatoire
20
 
21
  HOST = os.environ.get("API_URL", "0.0.0.0")
22
  PORT = os.environ.get("PORT", 7860)
@@ -39,9 +39,9 @@ db = mongo_client[db_name]
39
  app = FastAPI()
40
  app.add_middleware(
41
  CORSMiddleware,
42
- # Spécifiez les origines exactes au lieu de "*"
43
  allow_origins=[
44
- "https://axl95-medically.hf.space", # URL de votre espace HF
45
  "https://huggingface.co",
46
  "http://localhost:3000",
47
  "http://localhost:7860",
@@ -54,7 +54,6 @@ app.add_middleware(
54
 
55
 
56
 
57
- # Fonction pour créer une session
58
  @app.post("/api/login")
59
  async def login(request: Request, response: Response):
60
  try:
@@ -62,12 +61,10 @@ async def login(request: Request, response: Response):
62
  email = data.get("email")
63
  password = data.get("password")
64
 
65
- # Vérifier les identifiants
66
  user = db.users.find_one({"email": email})
67
  if not user or not bcrypt.verify(password, user["password"]):
68
  raise HTTPException(status_code=401, detail="Email ou mot de passe incorrect")
69
 
70
- # Créer une session
71
  session_id = secrets.token_hex(16)
72
  user_id = str(user["_id"])
73
  username = f"{user['prenom']} {user['nom']}"
@@ -84,22 +81,21 @@ async def login(request: Request, response: Response):
84
  response.set_cookie(
85
  key="session_id",
86
  value=session_id,
87
- httponly=False, # Permettre à JavaScript d'accéder au cookie
88
- max_age=7*24*60*60, # 7 jours
89
- samesite="none", # Obligatoire pour le cross-site
90
- secure=True, # Obligatoire avec samesite=none
91
- path="/" # Accessible sur tout le domaine
92
  )
93
 
94
  # Log pour débogage
95
  print(f"Session créée: {session_id} pour l'utilisateur {user_id}")
96
 
97
- # Renvoyer le session_id dans la réponse pour permettre le stockage manuel si nécessaire
98
  return {
99
  "success": True,
100
  "username": username,
101
  "user_id": user_id,
102
- "session_id": session_id # Permet de stocker la session côté client
103
  }
104
 
105
  except Exception as e:
@@ -108,18 +104,15 @@ async def login(request: Request, response: Response):
108
 
109
 
110
  async def get_current_user(request: Request):
111
- # 1. Essayer d'obtenir la session du cookie
112
  session_id = request.cookies.get("session_id")
113
  print(f"Cookie de session reçu: {session_id[:5] if session_id else 'None'}")
114
 
115
- # 2. Si pas de cookie, vérifier l'en-tête d'autorisation
116
  if not session_id:
117
  auth_header = request.headers.get("Authorization")
118
  if auth_header and auth_header.startswith("Bearer "):
119
  session_id = auth_header.replace("Bearer ", "")
120
  print(f"Session d'autorisation reçue: {session_id[:5]}...")
121
 
122
- # 3. Si toujours pas de session, vérifier les paramètres de requête
123
  if not session_id:
124
  session_id = request.query_params.get("session_id")
125
  if session_id:
@@ -128,7 +121,6 @@ async def get_current_user(request: Request):
128
  if not session_id:
129
  raise HTTPException(status_code=401, detail="Non authentifié - Aucune session trouvée")
130
 
131
- # Vérifier si la session existe et n'est pas expirée
132
  session = db.sessions.find_one({
133
  "session_id": session_id,
134
  "expires_at": {"$gt": datetime.utcnow()}
@@ -157,21 +149,17 @@ async def register(request: Request):
157
  try:
158
  data = await request.json()
159
 
160
- # Validation
161
  required_fields = ["prenom", "nom", "email", "password"]
162
  for field in required_fields:
163
  if not data.get(field):
164
  raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")
165
 
166
- # Vérifier si l'email existe déjà
167
  existing_user = db.users.find_one({"email": data["email"]})
168
  if existing_user:
169
  raise HTTPException(status_code=409, detail="Cet email est déjà utilisé")
170
 
171
- # Hash du mot de passe (installer passlib avec pip install passlib[bcrypt])
172
  hashed_password = bcrypt.hash(data["password"])
173
 
174
- # Insérer l'utilisateur
175
  user = {
176
  "prenom": data["prenom"],
177
  "nom": data["nom"],
@@ -185,26 +173,20 @@ async def register(request: Request):
185
  return {"message": "Utilisateur créé avec succès", "userId": str(result.inserted_id)}
186
 
187
  except HTTPException as he:
188
- # Re-lever les HTTPException pour FastAPI les traite
189
  raise he
190
 
191
  except Exception as e:
192
- # Logging détaillé pour débogage
193
  import traceback
194
  print(f"Erreur lors de l'inscription: {str(e)}")
195
  print(traceback.format_exc())
196
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
197
- # Par cette solution utilisant l'API d'inférence directement:
198
  @app.post("/api/embed")
199
  async def embed(request: Request):
200
  data = await request.json()
201
  texts = data.get("texts", [])
202
 
203
  try:
204
- # Si vous utilisez un embedder personnalisé
205
- # embeddings = embedder.encode(texts).tolist()
206
 
207
- # Pour déboguer, renvoyez simplement un embedding fictif
208
  dummy_embedding = [[0.1, 0.2, 0.3] for _ in range(len(texts))]
209
 
210
  return {"embeddings": dummy_embedding}
@@ -222,9 +204,6 @@ HF_TOKEN = os.getenv('REACT_APP_HF_TOKEN')
222
  if not HF_TOKEN:
223
  raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
224
 
225
- # Initialisation du client HF
226
-
227
- # Par cette version correcte
228
  hf_client = InferenceClient(token=HF_TOKEN)
229
 
230
  @app.post("/api/chat")
@@ -235,13 +214,12 @@ async def chat(request: Request):
235
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
236
 
237
  try:
238
- # Utiliser le provider novita comme demandé
239
  response = hf_client.text_generation(
240
  model="mistralai/Mistral-7B-Instruct-v0.3",
241
  prompt=f"<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. Réponds à cette question: {user_message} [/INST]",
242
  max_new_tokens=512,
243
  temperature=0.7,
244
- provider="novita" # Spécifier le provider novita ici
245
  )
246
 
247
  return {"response": response}
@@ -258,18 +236,15 @@ async def get_data():
258
  data = {"data": np.random.rand(100).tolist()}
259
  return JSONResponse(data)
260
 
261
- # Endpoint pour récupérer toutes les conversations d'un utilisateur
262
  @app.get("/api/conversations")
263
  async def get_conversations(current_user: dict = Depends(get_current_user)):
264
  try:
265
  user_id = str(current_user["_id"])
266
- # Récupération des conversations triées par date (les plus récentes d'abord)
267
  conversations = list(db.conversations.find(
268
  {"user_id": user_id},
269
  {"_id": 1, "title": 1, "date": 1, "time": 1, "last_message": 1, "created_at": 1}
270
  ).sort("created_at", -1))
271
 
272
- # Convertir les ObjectId en strings pour la sérialisation JSON
273
  for conv in conversations:
274
  conv["_id"] = str(conv["_id"])
275
 
@@ -277,14 +252,12 @@ async def get_conversations(current_user: dict = Depends(get_current_user)):
277
  except Exception as e:
278
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
279
 
280
- # Endpoint pour sauvegarder une nouvelle conversation
281
  @app.post("/api/conversations")
282
  async def create_conversation(request: Request, current_user: dict = Depends(get_current_user)):
283
  try:
284
  data = await request.json()
285
  user_id = str(current_user["_id"])
286
 
287
- # Créer la nouvelle conversation
288
  conversation = {
289
  "user_id": user_id,
290
  "title": data.get("title", "Nouvelle conversation"),
@@ -296,22 +269,18 @@ async def create_conversation(request: Request, current_user: dict = Depends(get
296
 
297
  result = db.conversations.insert_one(conversation)
298
 
299
- # Retourner l'ID de la conversation créée
300
  return {"conversation_id": str(result.inserted_id)}
301
  except Exception as e:
302
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
303
 
304
- # Endpoint pour sauvegarder un message dans une conversation
305
  @app.post("/api/conversations/{conversation_id}/messages")
306
  async def add_message(conversation_id: str, request: Request, current_user: dict = Depends(get_current_user)):
307
  try:
308
  data = await request.json()
309
  user_id = str(current_user["_id"])
310
 
311
- # Debug pour vérifier les données
312
  print(f"Ajout message: conversation_id={conversation_id}, sender={data.get('sender')}, text={data.get('text')[:20]}...")
313
 
314
- # Vérifier que la conversation appartient à l'utilisateur
315
  conversation = db.conversations.find_one({
316
  "_id": ObjectId(conversation_id),
317
  "user_id": user_id
@@ -320,7 +289,6 @@ async def add_message(conversation_id: str, request: Request, current_user: dict
320
  if not conversation:
321
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
322
 
323
- # Ajouter le message
324
  message = {
325
  "conversation_id": conversation_id,
326
  "user_id": user_id,
@@ -331,7 +299,6 @@ async def add_message(conversation_id: str, request: Request, current_user: dict
331
 
332
  db.messages.insert_one(message)
333
 
334
- # Mettre à jour la dernière activité de la conversation
335
  db.conversations.update_one(
336
  {"_id": ObjectId(conversation_id)},
337
  {"$set": {"last_message": data.get("text", ""), "updated_at": datetime.utcnow()}}
@@ -342,13 +309,11 @@ async def add_message(conversation_id: str, request: Request, current_user: dict
342
  print(f"Erreur lors de l'ajout d'un message: {str(e)}")
343
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
344
 
345
- # Endpoint pour récupérer les messages d'une conversation
346
  @app.get("/api/conversations/{conversation_id}/messages")
347
  async def get_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
348
  try:
349
  user_id = str(current_user["_id"])
350
 
351
- # Vérifier que la conversation appartient à l'utilisateur
352
  conversation = db.conversations.find_one({
353
  "_id": ObjectId(conversation_id),
354
  "user_id": user_id
@@ -357,12 +322,10 @@ async def get_messages(conversation_id: str, current_user: dict = Depends(get_cu
357
  if not conversation:
358
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
359
 
360
- # Récupérer les messages
361
  messages = list(db.messages.find(
362
  {"conversation_id": conversation_id}
363
- ).sort("timestamp", 1)) # Du plus ancien au plus récent
364
 
365
- # Convertir les ObjectId en strings pour la sérialisation JSON
366
  for msg in messages:
367
  msg["_id"] = str(msg["_id"])
368
  if "timestamp" in msg:
@@ -377,7 +340,6 @@ async def delete_conversation(conversation_id: str, current_user: dict = Depends
377
  try:
378
  user_id = str(current_user["_id"])
379
 
380
- # Vérifier que la conversation appartient à l'utilisateur
381
  result = db.conversations.delete_one({
382
  "_id": ObjectId(conversation_id),
383
  "user_id": user_id
@@ -386,7 +348,6 @@ async def delete_conversation(conversation_id: str, current_user: dict = Depends
386
  if result.deleted_count == 0:
387
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
388
 
389
- # Supprimer également tous les messages associés
390
  db.messages.delete_many({"conversation_id": conversation_id})
391
 
392
  return {"success": True}
 
5
  from huggingface_hub import InferenceClient
6
  import secrets
7
  from typing import Optional
8
+ from bson.objectid import ObjectId
9
+ from datetime import datetime, timedelta
10
  from fastapi import Request
11
  import requests
12
  import numpy as np
13
  import argparse
14
  import os
15
  from pymongo import MongoClient
16
+ from datetime import datetime
17
+ from passlib.hash import bcrypt
18
 
19
+ SECRET_KEY = secrets.token_hex(32)
20
 
21
  HOST = os.environ.get("API_URL", "0.0.0.0")
22
  PORT = os.environ.get("PORT", 7860)
 
39
  app = FastAPI()
40
  app.add_middleware(
41
  CORSMiddleware,
42
+
43
  allow_origins=[
44
+ "https://axl95-medically.hf.space",
45
  "https://huggingface.co",
46
  "http://localhost:3000",
47
  "http://localhost:7860",
 
54
 
55
 
56
 
 
57
  @app.post("/api/login")
58
  async def login(request: Request, response: Response):
59
  try:
 
61
  email = data.get("email")
62
  password = data.get("password")
63
 
 
64
  user = db.users.find_one({"email": email})
65
  if not user or not bcrypt.verify(password, user["password"]):
66
  raise HTTPException(status_code=401, detail="Email ou mot de passe incorrect")
67
 
 
68
  session_id = secrets.token_hex(16)
69
  user_id = str(user["_id"])
70
  username = f"{user['prenom']} {user['nom']}"
 
81
  response.set_cookie(
82
  key="session_id",
83
  value=session_id,
84
+ httponly=False,
85
+ max_age=7*24*60*60,
86
+ samesite="none",
87
+ secure=True,
88
+ path="/"
89
  )
90
 
91
  # Log pour débogage
92
  print(f"Session créée: {session_id} pour l'utilisateur {user_id}")
93
 
 
94
  return {
95
  "success": True,
96
  "username": username,
97
  "user_id": user_id,
98
+ "session_id": session_id
99
  }
100
 
101
  except Exception as e:
 
104
 
105
 
106
  async def get_current_user(request: Request):
 
107
  session_id = request.cookies.get("session_id")
108
  print(f"Cookie de session reçu: {session_id[:5] if session_id else 'None'}")
109
 
 
110
  if not session_id:
111
  auth_header = request.headers.get("Authorization")
112
  if auth_header and auth_header.startswith("Bearer "):
113
  session_id = auth_header.replace("Bearer ", "")
114
  print(f"Session d'autorisation reçue: {session_id[:5]}...")
115
 
 
116
  if not session_id:
117
  session_id = request.query_params.get("session_id")
118
  if session_id:
 
121
  if not session_id:
122
  raise HTTPException(status_code=401, detail="Non authentifié - Aucune session trouvée")
123
 
 
124
  session = db.sessions.find_one({
125
  "session_id": session_id,
126
  "expires_at": {"$gt": datetime.utcnow()}
 
149
  try:
150
  data = await request.json()
151
 
 
152
  required_fields = ["prenom", "nom", "email", "password"]
153
  for field in required_fields:
154
  if not data.get(field):
155
  raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")
156
 
 
157
  existing_user = db.users.find_one({"email": data["email"]})
158
  if existing_user:
159
  raise HTTPException(status_code=409, detail="Cet email est déjà utilisé")
160
 
 
161
  hashed_password = bcrypt.hash(data["password"])
162
 
 
163
  user = {
164
  "prenom": data["prenom"],
165
  "nom": data["nom"],
 
173
  return {"message": "Utilisateur créé avec succès", "userId": str(result.inserted_id)}
174
 
175
  except HTTPException as he:
 
176
  raise he
177
 
178
  except Exception as e:
 
179
  import traceback
180
  print(f"Erreur lors de l'inscription: {str(e)}")
181
  print(traceback.format_exc())
182
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
 
183
  @app.post("/api/embed")
184
  async def embed(request: Request):
185
  data = await request.json()
186
  texts = data.get("texts", [])
187
 
188
  try:
 
 
189
 
 
190
  dummy_embedding = [[0.1, 0.2, 0.3] for _ in range(len(texts))]
191
 
192
  return {"embeddings": dummy_embedding}
 
204
  if not HF_TOKEN:
205
  raise RuntimeError("Le token Hugging Face (HF_TOKEN) n'est pas défini dans les variables d'environnement.")
206
 
 
 
 
207
  hf_client = InferenceClient(token=HF_TOKEN)
208
 
209
  @app.post("/api/chat")
 
214
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
215
 
216
  try:
 
217
  response = hf_client.text_generation(
218
  model="mistralai/Mistral-7B-Instruct-v0.3",
219
  prompt=f"<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. Réponds à cette question: {user_message} [/INST]",
220
  max_new_tokens=512,
221
  temperature=0.7,
222
+ provider="novita"
223
  )
224
 
225
  return {"response": response}
 
236
  data = {"data": np.random.rand(100).tolist()}
237
  return JSONResponse(data)
238
 
 
239
  @app.get("/api/conversations")
240
  async def get_conversations(current_user: dict = Depends(get_current_user)):
241
  try:
242
  user_id = str(current_user["_id"])
 
243
  conversations = list(db.conversations.find(
244
  {"user_id": user_id},
245
  {"_id": 1, "title": 1, "date": 1, "time": 1, "last_message": 1, "created_at": 1}
246
  ).sort("created_at", -1))
247
 
 
248
  for conv in conversations:
249
  conv["_id"] = str(conv["_id"])
250
 
 
252
  except Exception as e:
253
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
254
 
 
255
  @app.post("/api/conversations")
256
  async def create_conversation(request: Request, current_user: dict = Depends(get_current_user)):
257
  try:
258
  data = await request.json()
259
  user_id = str(current_user["_id"])
260
 
 
261
  conversation = {
262
  "user_id": user_id,
263
  "title": data.get("title", "Nouvelle conversation"),
 
269
 
270
  result = db.conversations.insert_one(conversation)
271
 
 
272
  return {"conversation_id": str(result.inserted_id)}
273
  except Exception as e:
274
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
275
 
 
276
  @app.post("/api/conversations/{conversation_id}/messages")
277
  async def add_message(conversation_id: str, request: Request, current_user: dict = Depends(get_current_user)):
278
  try:
279
  data = await request.json()
280
  user_id = str(current_user["_id"])
281
 
 
282
  print(f"Ajout message: conversation_id={conversation_id}, sender={data.get('sender')}, text={data.get('text')[:20]}...")
283
 
 
284
  conversation = db.conversations.find_one({
285
  "_id": ObjectId(conversation_id),
286
  "user_id": user_id
 
289
  if not conversation:
290
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
291
 
 
292
  message = {
293
  "conversation_id": conversation_id,
294
  "user_id": user_id,
 
299
 
300
  db.messages.insert_one(message)
301
 
 
302
  db.conversations.update_one(
303
  {"_id": ObjectId(conversation_id)},
304
  {"$set": {"last_message": data.get("text", ""), "updated_at": datetime.utcnow()}}
 
309
  print(f"Erreur lors de l'ajout d'un message: {str(e)}")
310
  raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
311
 
 
312
  @app.get("/api/conversations/{conversation_id}/messages")
313
  async def get_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
314
  try:
315
  user_id = str(current_user["_id"])
316
 
 
317
  conversation = db.conversations.find_one({
318
  "_id": ObjectId(conversation_id),
319
  "user_id": user_id
 
322
  if not conversation:
323
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
324
 
 
325
  messages = list(db.messages.find(
326
  {"conversation_id": conversation_id}
327
+ ).sort("timestamp", 1))
328
 
 
329
  for msg in messages:
330
  msg["_id"] = str(msg["_id"])
331
  if "timestamp" in msg:
 
340
  try:
341
  user_id = str(current_user["_id"])
342
 
 
343
  result = db.conversations.delete_one({
344
  "_id": ObjectId(conversation_id),
345
  "user_id": user_id
 
348
  if result.deleted_count == 0:
349
  raise HTTPException(status_code=404, detail="Conversation non trouvée")
350
 
 
351
  db.messages.delete_many({"conversation_id": conversation_id})
352
 
353
  return {"success": True}