# anfdock / main.py
# Author: Ashrafb — commit 71911c0 ("Update main.py"), 1.3 kB
from fastapi import FastAPI, File, UploadFile,Form
from fastapi.responses import FileResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image
import torch
from io import BytesIO
app = FastAPI()

# All three hub entry points below come from the same torch.hub repository
# and are pinned to the CPU.
_ANIMEGAN_REPO = "AK391/animegan2-pytorch:main"

# model2: AnimeGANv2 "generator" loaded with pretrained=True, i.e. whatever
# default checkpoint the hub repo provides (presumably face_paint_512_v2
# upstream -- confirm). Download progress output is turned off for this load.
model2 = torch.hub.load(
    _ANIMEGAN_REPO,
    "generator",
    pretrained=True,
    device="cpu",
    progress=False,
)

# model1: the same generator, explicitly loaded with the "face_paint_512_v1"
# checkpoint; /predict/ uses it for any version other than "version2".
model1 = torch.hub.load(
    _ANIMEGAN_REPO,
    "generator",
    pretrained="face_paint_512_v1",
    device="cpu",
)

# face2paint: hub helper that yields a callable (model, PIL image) -> PIL
# image, configured for a 512 px working size and output-only results
# (side_by_side=False).
face2paint = torch.hub.load(
    _ANIMEGAN_REPO,
    "face2paint",
    size=512,
    device="cpu",
    side_by_side=False,
)
@app.post("/predict/")
async def predict(
    file: UploadFile = File(...),
    version: str = Form(...),
):
    """Stylise an uploaded image with AnimeGANv2 and return it as a PNG.

    Form fields:
        file: the image to convert (any format Pillow can decode).
        version: ``"version2"`` selects ``model2``; any other value selects
            ``model1``.

    Returns:
        A ``StreamingResponse`` whose body is the stylised image as PNG.

    Raises:
        HTTPException: status 400 if the upload cannot be decoded as an image.
    """
    # Local imports keep this block self-contained; both modules belong to
    # FastAPI, which this file already depends on.
    from fastapi import HTTPException
    from fastapi.concurrency import run_in_threadpool

    contents = await file.read()
    try:
        image = Image.open(BytesIO(contents))
        # The generator expects 3-channel input. RGBA, palette ("P"),
        # grayscale and CMYK uploads would otherwise fail inside the model.
        # convert() also forces a full decode, so truncated files fail here.
        image = image.convert("RGB")
    except OSError as exc:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from exc

    model = model2 if version == "version2" else model1

    # Inference is synchronous and CPU-bound. Running it in the thread pool
    # keeps the event loop free to serve other requests in the meantime.
    out = await run_in_threadpool(face2paint, model, image)

    buffer = BytesIO()
    out.save(buffer, format="PNG")
    buffer.seek(0)  # rewind so the response streams from the start
    return StreamingResponse(buffer, media_type="image/png")
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page ``AB/index.html``.

    This route is registered before the static mount. Starlette matches
    routes in registration order, and a mount at ``"/"`` matches every
    path, so any route added after it would never be reached. The path is
    relative, so it resolves against the same working directory the
    ``StaticFiles(directory="AB")`` mount uses.
    """
    return FileResponse(path="AB/index.html", media_type="text/html")


# Serve the remaining front-end assets from the "AB" directory
# (html=True also maps directory requests to their index.html).
# This mount must stay last because it catches every path not handled above.
app.mount("/", StaticFiles(directory="AB", html=True), name="static")