from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from pydantic import BaseModel
import os

# Hugging Face cache directory
CACHE_DIR = "/tmp/huggingface"
os.makedirs(CACHE_DIR, exist_ok=True)
os.environ["HF_HOME"] = CACHE_DIR
# TRANSFORMERS_CACHE is deprecated in recent transformers releases in favor of
# HF_HOME; it is kept here for compatibility with older versions.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

# Primary model (FP8 weights, requires a GPU)
MODEL_ID = "deepseek-ai/DeepSeek-R1"
FALLBACK_MODEL_ID = "gpt2"  # CPU-friendly fallback

# Detect GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, cache_dir=CACHE_DIR).to(device)
except Exception as e:
    print(f"⚠️ Failed to load GPU FP8 model: {e}")
    print(f"🔹 Falling back to CPU-friendly model: {FALLBACK_MODEL_ID}")
    tokenizer = AutoTokenizer.from_pretrained(FALLBACK_MODEL_ID, cache_dir=CACHE_DIR)
    model = AutoModelForCausalLM.from_pretrained(FALLBACK_MODEL_ID, cache_dir=CACHE_DIR).to(device)

# FastAPI app
app = FastAPI(title="QA GPT API", description="Hugging Face model served via FastAPI")

# Request schema for the POST endpoint
class QueryRequest(BaseModel):
    question: str
    max_new_tokens: int = 50
    temperature: float = 0.7
    top_p: float = 0.9

@app.get("/")
def home():
    return {"message": "Welcome to QA GPT API 🚀"}

@app.get("/ask")
def ask(question: str, max_new_tokens: int = 50):
    # Greedy decoding for quick GET-based queries
    inputs = tokenizer(question, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"question": question, "answer": answer}

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/predict")
def predict(request: QueryRequest):
    # Sampled decoding with client-controlled temperature and top-p
    inputs = tokenizer(request.question, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=request.max_new_tokens,
        do_sample=True,
        temperature=request.temperature,
        top_p=request.top_p,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )
    answer = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True)
    return {"question": request.question, "answer": answer}
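
To exercise the service end to end, here is a minimal client sketch. It assumes the script above is saved as main.py, served with uvicorn on the default port 8000, and reachable at localhost; the filename, base URL, and example questions are assumptions for illustration, not part of the original code.

# Start the server first (assumed filename "main.py"):
#   uvicorn main:app --host 0.0.0.0 --port 8000
import requests

BASE_URL = "http://localhost:8000"  # assumed local deployment

# Greedy completion via the GET /ask endpoint
resp = requests.get(
    f"{BASE_URL}/ask",
    params={"question": "What is FastAPI?", "max_new_tokens": 40},
)
print(resp.json())

# Sampled completion via the POST /predict endpoint, mirroring QueryRequest
payload = {
    "question": "Explain top-p sampling in one sentence.",
    "max_new_tokens": 60,
    "temperature": 0.7,
    "top_p": 0.9,
}
resp = requests.post(f"{BASE_URL}/predict", json=payload)
print(resp.json())

The GET endpoint is convenient for quick browser or health-style checks, while the POST endpoint exposes the sampling parameters defined in the QueryRequest schema.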