"""FastAPI service that serves headache-risk predictions from a pickled model.

The model artifact is downloaded from the Hugging Face Hub at startup. The
service exposes a single-day endpoint (``/predict``), a multi-day endpoint
(``/predict/batch``), a health probe (``/health``) and a self-describing
root endpoint (``/``).
"""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import xgboost as xgb
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
import sys
from typing import List, Union

app = FastAPI(title="Headache Predictor API")

# Populated by the startup hook. ``model`` stays None if loading fails, which
# the endpoints report as HTTP 503 and ``/health`` reports as not loaded.
model = None

# Decision threshold shipped alongside the model (dict-format artifacts only).
# None means "no threshold available": fall back to the model's own predict().
optimal_threshold = None


# NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of
# lifespan handlers; it is kept here so the application wiring is unchanged.
@app.on_event("startup")
async def load_model():
    """Download the model artifact and load it into the module globals.

    Failures are logged and swallowed on purpose: the app still starts, and
    ``/health`` exposes ``model_loaded: false`` instead of crashing the
    process.
    """
    global model, optimal_threshold
    try:
        # Set cache directory to writable location
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        # Get HF token from environment (set as Space secret)
        hf_token = os.environ.get("HF_TOKEN")

        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,  # Use token for private repo access
        )

        # SECURITY: unpickling executes arbitrary code embedded in the file.
        # This is only acceptable while the repository above is fully trusted
        # and access-controlled; never point it at an untrusted artifact.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        # Handle both dict format and raw model
        if isinstance(model_data, dict):
            model = model_data['model']
            # Keep the tuned threshold: ``model`` is the unwrapped estimator,
            # so the threshold would otherwise be lost after this function.
            optimal_threshold = model_data.get('optimal_threshold')
            print(f"✅ Model loaded successfully (threshold: {model_data.get('optimal_threshold', 0.5)})")
        else:
            model = model_data
            optimal_threshold = None
            print("✅ Model loaded successfully")
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()


class SinglePredictionRequest(BaseModel):
    """Request body for ``/predict``: one feature vector."""

    features: List[float]


class BatchPredictionRequest(BaseModel):
    """Request body for ``/predict/batch``: one feature vector per day."""

    instances: List[List[float]]


class DayPrediction(BaseModel):
    """Prediction for one day of a batch; ``day`` is 1-based."""

    day: int
    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction


class SinglePredictionResponse(BaseModel):
    """Response body for ``/predict``."""

    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction


class BatchPredictionResponse(BaseModel):
    """Response body for ``/predict/batch``."""

    predictions: List[DayPrediction]


def _classify(features, headache_probabilities):
    """Return one 0/1 label per row of ``features``.

    Uses the threshold loaded with the model when one is available
    (label is 1 when the class-1 probability is >= the threshold);
    otherwise defers to a single ``model.predict`` call over all rows.
    """
    if optimal_threshold is not None:
        return [
            1 if probability >= optimal_threshold else 0
            for probability in headache_probabilities
        ]
    return [int(label) for label in model.predict(features)]


@app.get("/")
def read_root():
    """Describe the API: status, available endpoints and example payloads."""
    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": {
            "predict": "/predict - Single day prediction",
            "predict_batch": "/predict/batch - 7-day forecast",
            "health": "/health"
        },
        "examples": {
            "single": {
                "url": "/predict",
                "body": {"features": [1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4, 7.0, 50.0, 60.0, 3.5, 1.5, 6.8]}
            },
            "batch": {
                "url": "/predict/batch",
                "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"], "..."]}
            }
        }
    }


@app.get("/health")
def health_check():
    """Report liveness and whether the model finished loading."""
    return {
        "status": "healthy",
        "model_loaded": model is not None
    }


@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the features
            cannot be scored (the underlying error text is included).
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # The model scores a 2D matrix, so wrap the vector as a single row.
        features = np.array(request.features).reshape(1, -1)

        # Always return probability of headache (class 1)
        headache_probability = float(model.predict_proba(features)[0][1])

        prediction = _classify(features, [headache_probability])[0]

        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}") from e


@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (7-day forecast).

    Days in the response are numbered from 1 in the order the instances
    were submitted.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the instances
            do not form a 2D array or cannot be scored.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # Convert all instances to numpy array
        features = np.array(request.instances)

        if features.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {features.shape}")

        # Get probabilities for all days; column 1 is the headache class.
        probabilities = model.predict_proba(features)
        headache_probabilities = [float(row[1]) for row in probabilities]

        # Label every day in one pass instead of one predict() call per row.
        labels = _classify(features, headache_probabilities)

        results = [
            DayPrediction(
                day=day,
                prediction=int(label),
                probability=headache_probability
            )
            for day, (label, headache_probability) in enumerate(
                zip(labels, headache_probabilities), 1
            )
        ]

        return BatchPredictionResponse(predictions=results)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}") from e