import base64
import io
import json
import logging
import os
import sys
from typing import Any, Dict, List, Optional, Union

import torch
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from pydantic import BaseModel

# Import the handler
from handler import EndpointHandler
# Initialize the app
app = FastAPI()

# Initialize the model once, at import time; the module-level `model` is the
# callable that the request handler invokes with the normalized payload.
# NOTE(review): presumably "/code" is the directory holding the model files in
# the deployment image — confirm against EndpointHandler.
model = EndpointHandler(model_dir="/code")
def _to_model_payload(data: Any) -> Dict[str, Any]:
    """Coerce a decoded request body into the ``{"inputs": ...}`` shape.

    * A dict that already has an ``"inputs"`` key is returned unchanged.
    * A dict with a ``"text"`` key has that value moved under ``"inputs"``.
    * Anything else (other dicts, lists, numbers, strings, ``None``) is
      converted with ``str()`` and wrapped.
    """
    if isinstance(data, dict):
        if "inputs" in data:
            return data
        if "text" in data:
            return {"inputs": data["text"]}
    # NOTE: str() yields the Python repr for dicts/lists, not JSON text.
    return {"inputs": str(data)}


async def process_request(request: Request):
    """Run the model on the body of *request* and return the model's result.

    The raw body is parsed as JSON when possible; otherwise it is decoded as
    UTF-8 and used as a plain-text input. The decoded value is normalized to
    ``{"inputs": ...}`` and passed to the module-level ``model`` callable.

    Args:
        request: The incoming request whose raw body is read.

    Returns:
        Whatever ``model`` returns for the normalized payload.

    Raises:
        HTTPException: 400 when the body is neither valid JSON nor valid
            UTF-8 text; 500 for any other failure, with the exception message
            as ``detail``. An ``HTTPException`` raised while handling the
            request is propagated unchanged.
    """
    # NOTE(review): no route decorator is visible on this function in this
    # view — confirm it is registered on `app` elsewhere.
    try:
        body = await request.body()
        try:
            data = json.loads(body)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError
            # subclasses, so this covers "not JSON" without a bare except.
            try:
                data = {"inputs": body.decode("utf-8")}
            except UnicodeDecodeError as exc:
                # A body the client sent in an unusable encoding is a client
                # error, not a server fault.
                raise HTTPException(
                    status_code=400,
                    detail="Request body must be valid JSON or UTF-8 text",
                ) from exc
        return model(_to_model_payload(data))
    except HTTPException:
        # Keep the status code chosen by whoever raised it.
        raise
    except Exception as e:
        # Top-level boundary: record the traceback before converting to HTTP.
        logging.getLogger(__name__).exception("Request processing failed")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Add a health check endpoint
# NOTE(review): no route decorator is visible on this function in this view —
# confirm it is registered on `app` elsewhere.
async def health() -> Dict[str, str]:
    """Health check: always return ``{"status": "ok"}``."""
    return {"status": "ok"}