File size: 1,879 Bytes
86771a8 eaa1a7e 45675d3 ac9a037 86771a8 eaa1a7e ac9a037 86771a8 45675d3 86771a8 45675d3 86771a8 ac9a037 86771a8 45675d3 86771a8 45675d3 86771a8 45675d3 eaa1a7e ac9a037 86771a8 ac9a037 45675d3 dd7be29 86771a8 45675d3 dd7be29 86771a8 45675d3 86771a8 45675d3 86771a8 dd7be29 86771a8 45675d3 86771a8 ac9a037 86771a8 45675d3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
from fastapi import FastAPI, Response, HTTPException
from pydantic import BaseModel
from typing import Union, Dict, Any
import os
import io
import sys
from handler import EndpointHandler
# Add debug logging
def debug_log(message):
    """Write *message* to stdout with a ``DEBUG: `` prefix and flush at once.

    The explicit flush makes each line appear immediately even when stdout
    is block-buffered (for example when it is redirected to a file or pipe).
    """
    stream = sys.stdout
    stream.write(f"DEBUG: {message}\n")
    stream.flush()
# Trace startup progress on stdout; everything below runs at import time.
debug_log("Starting API initialization")
# Application object that the route decorators in this module register on.
app = FastAPI()
# Initialize the handler with the model directory.
# The MODEL_DIR environment variable overrides the default path.
model_dir = os.environ.get("MODEL_DIR", "/code/diffsketcher")
debug_log(f"Using model_dir: {model_dir}")
# Built once here and reused by every call to the POST "/" route.
# NOTE(review): presumably loads model assets from model_dir, which would
# make import slow — confirm against handler.py.
handler = EndpointHandler(model_dir)
debug_log("Handler initialized")
class TextRequest(BaseModel):
    # Request body accepted by the POST "/" route.
    # `inputs` is either a plain string or a dict with string keys and
    # arbitrary values.
    # NOTE(review): how each form is interpreted is decided by
    # EndpointHandler — confirm against handler.py.
    inputs: Union[str, Dict[str, Any]]
@app.get("/")
def read_root():
    # GET "/": log the call and answer with a fixed banner naming the service.
    # (Documented with comments rather than a docstring so the generated
    # OpenAPI description of this route stays as it was.)
    debug_log("Root endpoint called")
    banner = {"message": "DiffSketcher Vector Graphics Generation API"}
    return banner
@app.post("/")
async def generate(request: TextRequest):
    """Run the handler on the request body and return its output.

    A result that exposes a ``save`` method (e.g. a PIL Image) is encoded as
    PNG and returned with media type ``image/png``. Any other result is
    returned unchanged so FastAPI serializes it as JSON.

    Raises:
        HTTPException: status 500 carrying the error text as ``detail`` when
            the handler or the PNG encoding fails. An ``HTTPException``
            raised by the handler itself is passed through unchanged.
    """
    try:
        debug_log(f"Generate endpoint called with request: {request}")

        # Pydantic v2 deprecates ``.dict()`` in favour of ``.model_dump()``;
        # use the new method when it exists and fall back on v1 installs.
        to_payload = getattr(request, "model_dump", None) or request.dict

        # NOTE(review): this is a synchronous call inside an async route, so
        # the event loop is blocked until the handler returns. Left as-is
        # because moving it to a worker thread would let requests hit the
        # handler concurrently — confirm the handler tolerates that first.
        result = handler(to_payload())
        debug_log("Handler returned result")

        # Duck-type on ``save`` rather than importing PIL just for isinstance.
        if hasattr(result, "save"):
            debug_log("Result is a PIL Image, converting to bytes")
            png_buffer = io.BytesIO()
            result.save(png_buffer, format="PNG")
            # getvalue() returns the whole buffer regardless of the current
            # stream position, so no seek(0) is needed before reading it.
            debug_log("Returning image response")
            return Response(content=png_buffer.getvalue(), media_type="image/png")

        debug_log(f"Returning JSON response: {result}")
        return result
    except HTTPException:
        # Keep a deliberate HTTP error (status and detail) chosen further
        # down instead of rewriting it into a generic 500.
        raise
    except Exception as e:
        debug_log(f"Error in generate endpoint: {e}")
        import traceback
        debug_log(traceback.format_exc())
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(e)) from e