|
from fastapi import FastAPI, Request, HTTPException,Depends, Response |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.responses import JSONResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from huggingface_hub import InferenceClient |
|
import secrets |
|
from typing import Optional |
|
from bson.objectid import ObjectId |
|
from datetime import datetime, timedelta |
|
from fastapi import Request |
|
import requests |
|
import numpy as np |
|
import argparse |
|
import os |
|
from pymongo import MongoClient |
|
from datetime import datetime |
|
from passlib.hash import bcrypt |
|
|
|
# Application secret. It is not referenced by any code visible in this file.
# An explicit SECRET_KEY environment variable takes precedence so the value can
# stay stable across restarts and be shared between worker processes; when it
# is unset or empty, a fresh random 32-byte (64 hex chars) key is generated per
# process, exactly as before.
SECRET_KEY = os.environ.get("SECRET_KEY") or secrets.token_hex(32)
|
|
|
# Defaults for the command-line options below. API_URL is used as the bind
# host. PORT may arrive from the environment as a string; argparse applies
# type=int to string defaults, so args.port is always an int.
HOST = os.environ.get("API_URL", "0.0.0.0")
PORT = os.environ.get("PORT", 7860)

parser = argparse.ArgumentParser()
parser.add_argument("--host", default=HOST)
parser.add_argument("--port", type=int, default=PORT)
# Auto-reload stays enabled by default (original behaviour). A store_true flag
# whose default is already True can never turn it off, hence --no-reload.
parser.add_argument("--reload", action="store_true", default=True)
parser.add_argument("--no-reload", dest="reload", action="store_false")
parser.add_argument("--ssl_certfile")
parser.add_argument("--ssl_keyfile")

# parse_known_args rather than parse_args: this module is parsed at import time,
# including when it is imported by uvicorn's own CLI or a test runner, whose
# argv would otherwise make argparse exit with "unrecognized arguments".
args, _unknown_args = parser.parse_known_args()
|
|
|
|
|
# MongoDB connection settings. The connection URI must come from the
# environment: connection strings embed credentials and must not be committed
# to source control. A credential was previously hard-coded here as the
# default; it should be considered leaked and rotated.
mongo_uri = os.environ.get("MONGODB_URI")
if not mongo_uri:
    raise RuntimeError(
        "La variable d'environnement MONGODB_URI n'est pas définie."
    )

db_name = os.environ.get("DB_NAME", "chatmed_schizo")

# A single client is created at import time; every handler below uses `db`.
mongo_client = MongoClient(mongo_uri)
db = mongo_client[db_name]
|
|
|
|
|
|
|
app = FastAPI()

# Allowed CORS origins, taken from a comma-separated CORS_ORIGINS environment
# variable. It defaults to "*", the previous hard-coded value.
# NOTE(review): browsers do not accept a wildcard origin together with
# credentials. Depending on the Starlette version, "*" + allow_credentials
# either breaks credentialed requests or makes the middleware reflect any
# caller's Origin. Set CORS_ORIGINS to the real front-end origin(s) in
# production.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("CORS_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
# Lifetime of a login session. It drives both the stored expiry and the
# cookie max-age, so the two cannot drift apart.
_SESSION_LIFETIME = timedelta(days=7)


@app.post("/api/login")
async def login(request: Request, response: Response):
    """Authenticate a user by email and password and open a session.

    Expects a JSON object body with non-empty ``email`` and ``password``
    strings. On success, a random session id is stored in ``db.sessions``
    (expiring after ``_SESSION_LIFETIME``). It is also returned to the client
    as an HTTP-only ``session_id`` cookie.

    Returns:
        ``{"success": True, "username": "<prenom> <nom>", "user_id": "<id>"}``

    Raises:
        HTTPException: 400 for an invalid JSON body or a missing/non-string
            credential, 401 for an unknown email or a wrong password, 500 for
            any other unexpected error.
    """
    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Corps de requête JSON invalide") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Corps de requête JSON invalide")

    # Validate up front: bcrypt.verify() raises TypeError on a None/non-str
    # password, which used to surface as a 500 carrying the raw exception text.
    for field in ("email", "password"):
        value = data.get(field)
        if not isinstance(value, str) or not value:
            raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")
    email = data["email"]
    password = data["password"]

    try:
        user = db.users.find_one({"email": email})
        if not user or not bcrypt.verify(password, user["password"]):
            raise HTTPException(status_code=401, detail="Email ou mot de passe incorrect")

        session_id = secrets.token_hex(16)
        user_id = str(user["_id"])
        username = f"{user['prenom']} {user['nom']}"

        # Single timestamp so created_at and expires_at are exactly one
        # session lifetime apart.
        now = datetime.utcnow()
        db.sessions.insert_one({
            "session_id": session_id,
            "user_id": user_id,
            "created_at": now,
            "expires_at": now + _SESSION_LIFETIME,
        })

        # NOTE(review): the cookie is not flagged `secure`, so browsers will
        # also send it over plain HTTP; consider enabling it behind HTTPS.
        response.set_cookie(
            key="session_id",
            value=session_id,
            httponly=True,
            max_age=int(_SESSION_LIFETIME.total_seconds()),
            samesite="lax",
        )

        return {"success": True, "username": username, "user_id": user_id}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
async def get_current_user(request: Request):
    """Resolve the ``session_id`` cookie to the logged-in user's document.

    Used as a FastAPI dependency (``Depends(get_current_user)``). The session
    must exist in ``db.sessions`` with an ``expires_at`` later than the current
    naive-UTC time.

    Returns:
        The matching user document from ``db.users``.

    Raises:
        HTTPException: 401 when the cookie is missing, when the session is
            unknown or expired, or when the referenced user no longer exists.
    """
    sid = request.cookies.get("session_id")
    if not sid:
        raise HTTPException(status_code=401, detail="Non authentifié")

    live_session_filter = {"session_id": sid, "expires_at": {"$gt": datetime.utcnow()}}
    active_session = db.sessions.find_one(live_session_filter)
    if not active_session:
        raise HTTPException(status_code=401, detail="Session expirée ou invalide")

    owner_oid = ObjectId(active_session["user_id"])
    owner = db.users.find_one({"_id": owner_oid})
    if not owner:
        raise HTTPException(status_code=401, detail="Utilisateur non trouvé")

    return owner
|
|
|
|
|
@app.post("/api/logout")
async def logout(request: Request, response: Response):
    """Terminate the caller's session.

    When a ``session_id`` cookie is present, its record is removed from
    ``db.sessions``. The browser is always told to drop the cookie, and
    ``{"success": True}`` is returned.
    """
    sid = request.cookies.get("session_id")

    # Server-side revocation only makes sense when the client sent a cookie.
    if sid:
        db.sessions.delete_one({"session_id": sid})

    response.delete_cookie(key="session_id")
    outcome = {"success": True}
    return outcome
|
@app.post("/api/register")
async def register(request: Request):
    """Create a user account.

    Expects a JSON object body whose ``prenom``, ``nom``, ``email`` and
    ``password`` fields are non-empty strings. The password is stored as a
    bcrypt hash, and ``createdAt`` is set to the current naive-UTC time.

    Returns:
        ``{"message": "Utilisateur créé avec succès", "userId": "<id>"}``

    Raises:
        HTTPException: 400 for an invalid JSON body or a missing/non-string
            field, 409 when the email is already registered, 500 for any other
            unexpected error (the traceback is printed to stdout).
    """
    import traceback

    from pymongo.errors import DuplicateKeyError

    try:
        data = await request.json()
    except ValueError:
        # Malformed JSON is a client error, not a server failure.
        raise HTTPException(status_code=400, detail="Corps de requête JSON invalide") from None
    if not isinstance(data, dict):
        raise HTTPException(status_code=400, detail="Corps de requête JSON invalide")

    try:
        for field in ("prenom", "nom", "email", "password"):
            value = data.get(field)
            # Reject non-strings as well as empty values. bcrypt.hash() raises
            # on a non-str password, and names/emails should not be stored as
            # numbers or lists.
            if not isinstance(value, str) or not value:
                raise HTTPException(status_code=400, detail=f"Le champ {field} est requis")

        if db.users.find_one({"email": data["email"]}):
            raise HTTPException(status_code=409, detail="Cet email est déjà utilisé")

        user = {
            "prenom": data["prenom"],
            "nom": data["nom"],
            "email": data["email"],
            "password": bcrypt.hash(data["password"]),
            "createdAt": datetime.utcnow(),
        }
        result = db.users.insert_one(user)

        return {"message": "Utilisateur créé avec succès", "userId": str(result.inserted_id)}

    except HTTPException:
        raise
    except DuplicateKeyError:
        # The find_one check above is racy: two concurrent registrations can
        # both pass it. If a unique index on users.email exists (none is
        # created in the code visible here), MongoDB rejects the second insert,
        # and that case is reported as a conflict rather than a 500.
        raise HTTPException(status_code=409, detail="Cet email est déjà utilisé") from None
    except Exception as e:
        print(f"Erreur lors de l'inscription: {str(e)}")
        print(traceback.format_exc())
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
@app.post("/api/embed")
async def embed(request: Request):
    """Return placeholder embeddings, one per input text.

    This is a stub: every entry of ``texts`` maps to the constant vector
    ``[0.1, 0.2, 0.3]`` regardless of its content. A missing ``texts`` field
    yields ``{"embeddings": []}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object or ``texts`` is
            not a list. Such input used to produce an HTTP 200 response with an
            ``"error"`` key, which clients could mistake for success.
    """
    data = await request.json()
    texts = data.get("texts", []) if isinstance(data, dict) else None
    if not isinstance(texts, list):
        raise HTTPException(status_code=400, detail="Le champ 'texts' doit être une liste.")

    # TODO(review): placeholder vector; replace with a real embedding model.
    return {"embeddings": [[0.1, 0.2, 0.3] for _ in texts]}
|
|
|
@app.get("/invert")
async def invert(text: str):
    """Echo ``text`` back together with its character-reversed form."""
    reversed_text = "".join(reversed(text))
    return {"original": text, "inverted": reversed_text}
|
|
|
# Hugging Face API token. REACT_APP_HF_TOKEN (the variable originally read) is
# still honoured first. HF_TOKEN, the name the error message already referred
# to, is now accepted as a fallback.
# NOTE(review): REACT_APP_* variables are conventionally embedded into
# Create React App front-end bundles; make sure this token is not also
# exposed client-side.
HF_TOKEN = os.getenv("REACT_APP_HF_TOKEN") or os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError(
        "Le token Hugging Face (REACT_APP_HF_TOKEN ou HF_TOKEN) n'est pas défini "
        "dans les variables d'environnement."
    )

# Inference client created once at import time; used by the /api/chat handler.
hf_client = InferenceClient(token=HF_TOKEN)
|
|
|
@app.post("/api/chat")
async def chat(request: Request):
    """Answer a user question with the hosted Mistral-7B-Instruct model.

    Expects a JSON body ``{"message": "<question>"}``. The question is wrapped
    in a French instruction prompt (Mistral ``[INST]`` format) and sent to the
    Hugging Face Inference API via ``hf_client.text_generation``.

    Returns:
        ``{"response": <generated text>}``

    Raises:
        HTTPException: 400 when ``message`` is missing, blank or not a string;
            502 when the inference call fails.
    """
    import traceback

    from fastapi.concurrency import run_in_threadpool

    data = await request.json()
    raw_message = data.get("message", "") if isinstance(data, dict) else ""
    # A null/non-string message used to crash on .strip() with a 500; treat it
    # like a missing message instead.
    user_message = raw_message.strip() if isinstance(raw_message, str) else ""
    if not user_message:
        raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")

    # NOTE(review): this endpoint does not depend on get_current_user, so any
    # caller can consume the Hugging Face quota. Confirm that this is intended.
    prompt = (
        "<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. "
        f"Réponds à cette question: {user_message} [/INST]"
    )

    try:
        # text_generation is a synchronous HTTP call. Running it in the
        # threadpool keeps it from blocking the event loop, and with it every
        # other request, while the model generates.
        response = await run_in_threadpool(
            hf_client.text_generation,
            model="mistralai/Mistral-7B-Instruct-v0.3",
            prompt=prompt,
            max_new_tokens=512,
            temperature=0.7,
        )
    except Exception as e:
        print(f"Erreur détaillée: {traceback.format_exc()}")
        raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {str(e)}")

    return {"response": response}
|
|
|
|
|
@app.get("/data")
async def get_data():
    """Return 100 random floats drawn uniformly from [0, 1) under ``"data"``."""
    samples = np.random.rand(100)
    payload = {"data": samples.tolist()}
    return JSONResponse(payload)
|
|
|
|
|
|
|
|
|
@app.get("/")
def read_root():
    """Landing endpoint confirming the API is up and listing a few routes."""
    advertised_routes = ["/api/chat", "/invert", "/data"]
    return {
        "message": "API Medically fonctionnelle",
        "endpoints": advertised_routes,
    }
|
if __name__ == "__main__":
    import uvicorn

    # uvicorn is given an import string rather than the app object, which
    # reload requires. It therefore re-imports this file as a regular module
    # and serves *that* module's `app`. That is why the routes defined further
    # down are still registered even though this guard sits before them.
    # The module path is derived from how the script was launched instead of
    # being hard-coded to "app". This keeps the launcher working when the file
    # is renamed or run via `python -m package.module`.
    if __spec__ is not None:
        _module_path = __spec__.name
    else:
        _module_path = os.path.splitext(os.path.basename(__file__))[0]

    print(args)

    uvicorn.run(
        f"{_module_path}:app",
        host=args.host,
        port=args.port,
        reload=args.reload,
        ssl_certfile=args.ssl_certfile,
        ssl_keyfile=args.ssl_keyfile,
    )
|
|
|
|
|
|
|
@app.get("/api/conversations")
async def get_conversations(current_user: dict = Depends(get_current_user)):
    """List the caller's conversations, newest first.

    Only summary fields are projected. Each ``_id`` is converted to a string,
    since ObjectId is not natively JSON-serialisable.

    Returns:
        ``{"conversations": [...]}``

    Raises:
        HTTPException: 500 on any error while querying.
    """
    summary_projection = {
        "_id": 1,
        "title": 1,
        "date": 1,
        "time": 1,
        "last_message": 1,
        "created_at": 1,
    }
    try:
        owner_id = str(current_user["_id"])
        newest_first = db.conversations.find(
            {"user_id": owner_id}, summary_projection
        ).sort("created_at", -1)
        rows = [{**row, "_id": str(row["_id"])} for row in newest_first]
        return {"conversations": rows}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
|
|
@app.post("/api/conversations")
async def create_conversation(request: Request, current_user: dict = Depends(get_current_user)):
    """Create a conversation owned by the caller.

    Reads ``title`` (default ``"Nouvelle conversation"``), ``date``, ``time``
    and ``message`` (stored as ``last_message``, default ``""``) from the JSON
    body. ``created_at`` is set to the current naive-UTC time.

    Returns:
        ``{"conversation_id": "<new id>"}``

    Raises:
        HTTPException: 500 on any error, including a malformed body.
    """
    try:
        payload = await request.json()
        owner_id = str(current_user["_id"])

        inserted = db.conversations.insert_one(
            {
                "user_id": owner_id,
                "title": payload.get("title", "Nouvelle conversation"),
                "date": payload.get("date"),
                "time": payload.get("time"),
                "last_message": payload.get("message", ""),
                "created_at": datetime.utcnow(),
            }
        )
        return {"conversation_id": str(inserted.inserted_id)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
|
|
@app.post("/api/conversations/{conversation_id}/messages")
async def add_message(conversation_id: str, request: Request, current_user: dict = Depends(get_current_user)):
    """Append a message to one of the caller's conversations.

    Reads ``sender`` (default ``"user"``) and ``text`` (default ``""``) from the
    JSON body. The message is stored in ``db.messages``, and the conversation's
    ``last_message`` and ``updated_at`` fields are refreshed.

    Returns:
        ``{"success": True}``

    Raises:
        HTTPException: 404 when ``conversation_id`` is not a valid ObjectId or
            does not belong to the caller, 500 for any other unexpected error.
    """
    try:
        data = await request.json()
        user_id = str(current_user["_id"])
        text = data.get("text", "")

        # The old `data.get('text')[:20]` raised TypeError (and a 500) whenever
        # "text" was absent, even though the stored text defaults to "".
        print(f"Ajout message: conversation_id={conversation_id}, sender={data.get('sender')}, text={str(text)[:20]}...")

        # ObjectId() raises InvalidId on malformed input. Such an id cannot
        # match any conversation, so answer 404 rather than 500.
        if not ObjectId.is_valid(conversation_id):
            raise HTTPException(status_code=404, detail="Conversation non trouvée")
        conversation_oid = ObjectId(conversation_id)

        conversation = db.conversations.find_one({
            "_id": conversation_oid,
            "user_id": user_id,
        })
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation non trouvée")

        now = datetime.utcnow()
        db.messages.insert_one({
            "conversation_id": conversation_id,
            "user_id": user_id,
            "sender": data.get("sender", "user"),
            "text": text,
            "timestamp": now,
        })

        # Ownership was verified above, so filtering on _id alone is safe here.
        db.conversations.update_one(
            {"_id": conversation_oid},
            {"$set": {"last_message": text, "updated_at": now}},
        )

        return {"success": True}
    except HTTPException:
        # Re-raise deliberate HTTP errors. The generic handler below would
        # otherwise turn the 404s above into 500s.
        raise
    except Exception as e:
        print(f"Erreur lors de l'ajout d'un message: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
|
|
@app.get("/api/conversations/{conversation_id}/messages")
async def get_messages(conversation_id: str, current_user: dict = Depends(get_current_user)):
    """Return the messages of one of the caller's conversations, oldest first.

    Each message's ``_id`` is converted to a string. A datetime ``timestamp``
    is converted to an ISO-8601 string.

    Returns:
        ``{"messages": [...]}``

    Raises:
        HTTPException: 404 when ``conversation_id`` is not a valid ObjectId or
            does not belong to the caller, 500 for any other unexpected error.
    """
    try:
        user_id = str(current_user["_id"])

        # ObjectId() raises InvalidId on malformed input. Such an id cannot
        # match any conversation, so answer 404 rather than 500.
        if not ObjectId.is_valid(conversation_id):
            raise HTTPException(status_code=404, detail="Conversation non trouvée")

        conversation = db.conversations.find_one({
            "_id": ObjectId(conversation_id),
            "user_id": user_id,
        })
        if not conversation:
            raise HTTPException(status_code=404, detail="Conversation non trouvée")

        messages = list(db.messages.find(
            {"conversation_id": conversation_id}
        ).sort("timestamp", 1))

        for msg in messages:
            msg["_id"] = str(msg["_id"])
            # Only datetimes have .isoformat(). Leave any other value untouched
            # instead of failing the whole listing.
            if isinstance(msg.get("timestamp"), datetime):
                msg["timestamp"] = msg["timestamp"].isoformat()

        return {"messages": messages}
    except HTTPException:
        # Re-raise deliberate HTTP errors. The generic handler below would
        # otherwise turn the 404s above into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
@app.delete("/api/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str, current_user: dict = Depends(get_current_user)):
    """Delete one of the caller's conversations together with its messages.

    Returns:
        ``{"success": True}``

    Raises:
        HTTPException: 404 when ``conversation_id`` is not a valid ObjectId or
            no conversation with that id belongs to the caller, 500 for any
            other unexpected error.
    """
    try:
        user_id = str(current_user["_id"])

        # ObjectId() raises InvalidId on malformed input. Such an id cannot
        # match any conversation, so answer 404 rather than 500.
        if not ObjectId.is_valid(conversation_id):
            raise HTTPException(status_code=404, detail="Conversation non trouvée")

        # Filtering on user_id too means callers can only delete their own
        # conversations.
        result = db.conversations.delete_one({
            "_id": ObjectId(conversation_id),
            "user_id": user_id,
        })
        if result.deleted_count == 0:
            raise HTTPException(status_code=404, detail="Conversation non trouvée")

        # Messages are removed only after ownership is confirmed by the delete.
        db.messages.delete_many({"conversation_id": conversation_id})

        return {"success": True}
    except HTTPException:
        # Re-raise deliberate HTTP errors. The generic handler below would
        # otherwise turn the 404s above into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")