File size: 1,879 Bytes
86771a8
eaa1a7e
45675d3
ac9a037
86771a8
eaa1a7e
 
ac9a037
86771a8
 
45675d3
86771a8
 
45675d3
86771a8
ac9a037
 
86771a8
45675d3
86771a8
 
45675d3
86771a8
 
 
 
 
 
45675d3
 
eaa1a7e
ac9a037
86771a8
ac9a037
45675d3
dd7be29
86771a8
 
45675d3
dd7be29
86771a8
 
45675d3
86771a8
 
 
 
 
45675d3
86771a8
dd7be29
86771a8
45675d3
86771a8
ac9a037
86771a8
 
 
45675d3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import io
import os
import sys
import threading
import traceback
from typing import Any, Dict, Union

from fastapi import FastAPI, HTTPException, Response
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

from handler import EndpointHandler

def debug_log(message):
    """Print *message* to stdout with a ``DEBUG:`` prefix, flushing at once.

    The explicit flush makes the line show up promptly in process/container
    logs even when stdout is buffered.
    """
    stream = sys.stdout
    stream.write(f"DEBUG: {message}\n")
    stream.flush()

debug_log("Starting API initialization")

# The ASGI application object; routes below register on it via decorators.
app = FastAPI()

# Initialize the handler with the model directory.
# MODEL_DIR overrides the default "/code/diffsketcher" path. The handler is
# constructed once, at import time, and kept in a module-level name; if its
# constructor raises, importing this module fails and the app never starts.
# NOTE(review): EndpointHandler presumably loads the model here (handler.py
# is not visible) — confirm the expected directory layout there.
model_dir = os.environ.get("MODEL_DIR", "/code/diffsketcher")
debug_log(f"Using model_dir: {model_dir}")
handler = EndpointHandler(model_dir)
debug_log("Handler initialized")

class TextRequest(BaseModel):
    """JSON request body for ``POST /``.

    ``inputs`` may be either a plain string or a JSON object with string keys.
    NOTE(review): presumably a text prompt, or a dict of prompt plus
    generation options — the accepted keys are defined by EndpointHandler,
    which is not visible here.
    """

    inputs: Union[str, Dict[str, Any]]

@app.get("/")
def read_root():
    """Return a short JSON description identifying this service."""
    debug_log("Root endpoint called")
    service_name = "DiffSketcher Vector Graphics Generation API"
    return dict(message=service_name)

# Serializes calls into the model handler. While the handler ran directly on
# the event loop, requests were implicitly processed one at a time; keep that
# guarantee now that it runs in worker threads, since the handler's
# thread-safety (and memory use under concurrency) is unknown.
_handler_lock = threading.Lock()


def _run_handler(payload):
    """Call the shared model handler with *payload*, one call at a time."""
    with _handler_lock:
        return handler(payload)


def _encode_png(image) -> bytes:
    """Serialize *image* (any object with a PIL-style ``save``) to PNG bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    # getvalue() returns the whole buffer regardless of the stream position,
    # so no seek(0) is needed.
    return buffer.getvalue()


@app.post("/")
async def generate(request: TextRequest):
    """Run the model handler on the request body and return its result.

    If the handler returns an image-like object (anything with a ``save``
    method, e.g. a PIL image) it is encoded as PNG and returned with
    ``media_type="image/png"``; any other result is returned as-is for
    FastAPI to serialize as JSON.

    The handler call and PNG encoding are synchronous, so they run in
    FastAPI's threadpool; calling them directly in this ``async def`` would
    stall the event loop and no other request could be served meanwhile.

    Raises:
        HTTPException: status 500, with the error message as ``detail``, if
            the handler or the image encoding fails.
    """
    try:
        debug_log(f"Generate endpoint called with request: {request}")

        result = await run_in_threadpool(_run_handler, request.dict())
        debug_log("Handler returned result")

        if hasattr(result, "save"):
            debug_log("Result is a PIL Image, converting to bytes")
            png_bytes = await run_in_threadpool(_encode_png, result)
            debug_log("Returning image response")
            return Response(content=png_bytes, media_type="image/png")

        debug_log(f"Returning JSON response: {result}")
        return result
    except Exception as e:
        debug_log(f"Error in generate endpoint: {e}")
        debug_log(traceback.format_exc())
        # Chain the cause so server-side tracebacks show the original error.
        raise HTTPException(status_code=500, detail=str(e)) from e