# emp-admin's picture
# Upload app.py with huggingface_hub
# 47e0ff5 verified
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import xgboost as xgb
import numpy as np
import pickle
from huggingface_hub import hf_hub_download
import os
import sys
from typing import List, Union
app = FastAPI(title="Headache Predictor API")

# Estimator served by the prediction endpoints. It stays None until the
# startup hook below has loaded it; while it is None the endpoints answer
# 503 and /health reports model_loaded=False.
model = None


class _ThresholdedModel:
    """Proxy around a binary classifier whose ``predict`` uses a tuned threshold.

    The pickled bundle may carry an ``optimal_threshold`` next to the
    estimator. The endpoints only ever call ``predict_proba`` and ``predict``
    on the module-level ``model``, so applying the threshold inside
    ``predict`` is what makes it take effect.
    """

    def __init__(self, estimator, threshold):
        self._estimator = estimator
        self._threshold = threshold

    def predict_proba(self, features):
        """Return the wrapped estimator's class probabilities unchanged."""
        return self._estimator.predict_proba(features)

    def predict(self, features):
        """Return 1 where P(class 1) >= threshold, else 0, one label per row."""
        class_one = np.asarray(self._estimator.predict_proba(features))[:, 1]
        return (class_one >= self._threshold).astype(int)

    def __getattr__(self, name):
        # Only called for attributes the proxy itself lacks. The guard stops
        # infinite recursion if the proxy is inspected before __init__ ran.
        if name in ("_estimator", "_threshold"):
            raise AttributeError(name)
        return getattr(self._estimator, name)


@app.on_event("startup")
async def load_model():
    """Download the pickled model from the Hugging Face Hub and publish it.

    Sets the module-level ``model``. The pickle may be either a dict with a
    ``'model'`` entry (and optionally ``'optimal_threshold'``) or the bare
    estimator. Any failure is printed with its traceback and leaves ``model``
    as ``None``, so the app still starts.
    """
    global model
    try:
        # Set cache directory to writable location
        cache_dir = "/tmp/hf_cache"
        os.makedirs(cache_dir, exist_ok=True)

        # Get HF token from environment (set as Space secret)
        hf_token = os.environ.get("HF_TOKEN")
        model_path = hf_hub_download(
            repo_id="emp-admin/headache-predictor-xgboost",
            filename="model.pkl",
            cache_dir=cache_dir,
            token=hf_token,  # Use token for private repo access
        )

        # SECURITY: pickle.load runs arbitrary code from the file. This is
        # only acceptable while the repo above is fully trusted.
        with open(model_path, 'rb') as f:
            model_data = pickle.load(f)

        # Handle both dict format and raw model
        if isinstance(model_data, dict):
            estimator = model_data['model']
            threshold = model_data.get('optimal_threshold')
            if threshold is None:
                model = estimator
            else:
                # Previously the threshold was printed and then discarded.
                # Wrapping makes model.predict() apply it.
                model = _ThresholdedModel(estimator, threshold)
            print(f"✅ Model loaded successfully (threshold: {model_data.get('optimal_threshold', 0.5)})")
        else:
            model = model_data
            print("✅ Model loaded successfully")
    except Exception as e:
        # Best effort by design: report the failure and leave model as None.
        print(f"❌ Error loading model: {e}")
        import traceback
        traceback.print_exc()
# Request body for POST /predict.
class SinglePredictionRequest(BaseModel):
    # Feature vector for one day; the endpoint reshapes it into a one-row
    # array before handing it to the model.
    features: List[float]
# Request body for POST /predict/batch.
class BatchPredictionRequest(BaseModel):
    # One feature vector per day. Results are numbered from 1 in the same
    # order as the rows given here.
    instances: List[List[float]]
# One entry of a batch response: the result for a single day.
class DayPrediction(BaseModel):
    # 1-based position of the day within the submitted batch.
    day: int
    # Predicted class label as an integer (class 1 is headache).
    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction
# Response body of POST /predict.
class SinglePredictionResponse(BaseModel):
    # Predicted class label as an integer (class 1 is headache).
    prediction: int
    probability: float  # Probability of HEADACHE (class 1), regardless of prediction
# Response body of POST /predict/batch.
class BatchPredictionResponse(BaseModel):
    # One DayPrediction per submitted row, in request order.
    predictions: List[DayPrediction]
@app.get("/")
def read_root():
    # Landing payload: service status, a short guide to the routes, and a
    # sample request body for each prediction endpoint. Documented with
    # comments rather than a docstring so the generated API schema is unchanged.
    sample_features = [
        1, 0, 0, 0, 1, 0, 1005.0, -9.5, 85.0, 15.5, 64.0, 5.5, 41.0, 0.0,
        1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 10, 40, 4,
        7.0, 50.0, 60.0, 3.5, 1.5, 6.8,
    ]

    endpoint_guide = {
        "predict": "/predict - Single day prediction",
        "predict_batch": "/predict/batch - 7-day forecast",
        "health": "/health",
    }

    single_example = {
        "url": "/predict",
        "body": {"features": sample_features},
    }
    batch_example = {
        "url": "/predict/batch",
        "body": {"instances": [["array of 37 features for day 1"], ["array for day 2"], "..."]},
    }

    return {
        "message": "Headache Predictor API",
        "status": "running",
        "endpoints": endpoint_guide,
        "examples": {"single": single_example, "batch": batch_example},
    }
@app.get("/health")
def health_check():
    # Always reports "healthy"; model_loaded tells the caller whether the
    # module-level model has been populated yet.
    model_ready = model is not None
    return dict(status="healthy", model_loaded=model_ready)
@app.post("/predict", response_model=SinglePredictionResponse)
def predict(request: SinglePredictionRequest):
    """Predict headache risk for a single day"""
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description in the generated OpenAPI schema.
    #
    # Returns the model's class label plus the probability of class 1
    # (headache). Answers 503 while no model is loaded and 400 when the
    # features cannot be scored.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # The model scores 2-D input, so the single day becomes a one-row matrix.
        features = np.array(request.features).reshape(1, -1)

        # Always return probability of headache (class 1)
        headache_probability = float(model.predict_proba(features)[0][1])

        # `model` is the estimator itself (the loader unwraps the pickled
        # dict), so the former `isinstance(model, dict)` threshold branch
        # could never run: predict_proba above would already have failed on a
        # dict. The label therefore always comes from the model's own predict().
        prediction = model.predict(features)[0]

        return SinglePredictionResponse(
            prediction=int(prediction),
            probability=headache_probability,
        )
    except Exception as e:
        # Any scoring failure (wrong feature count, bad values, ...) is
        # reported to the client as a 400 with the underlying message.
        raise HTTPException(status_code=400, detail=f"Prediction error: {str(e)}") from e
@app.post("/predict/batch", response_model=BatchPredictionResponse)
def predict_batch(request: BatchPredictionRequest):
    """Predict headache risk for multiple days (7-day forecast)"""
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description in the generated OpenAPI schema.
    #
    # Returns one DayPrediction per submitted row, numbered from 1 in request
    # order. Answers 503 while no model is loaded and 400 when the batch is
    # not a rectangular 2-D array or cannot be scored.
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    try:
        # dtype=float makes ragged rows fail here with a ValueError instead of
        # producing an object array on some NumPy versions.
        features = np.asarray(request.instances, dtype=float)
        if features.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {features.shape}")

        # Score the whole batch once. The previous code called model.predict()
        # separately for every row. Its `isinstance(model, dict)` threshold
        # branch is dropped because it could never run: predict_proba would
        # already have failed on a dict.
        probabilities = model.predict_proba(features)
        labels = model.predict(features)

        results = [
            DayPrediction(
                day=day,
                prediction=int(label),
                # Always use probability of headache (class 1)
                probability=float(prob_row[1]),
            )
            for day, (prob_row, label) in enumerate(zip(probabilities, labels), 1)
        ]
        return BatchPredictionResponse(predictions=results)
    except Exception as e:
        # Shape errors raised above and model failures alike are reported to
        # the client as a 400 with the underlying message.
        raise HTTPException(status_code=400, detail=f"Batch prediction error: {str(e)}") from e