# Image_Tagger / app.py
# Author: stephenebert
# Last commit: d3e85eb "Update app.py" (verified)
from __future__ import annotations

import asyncio
import io
import logging
import os
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List

import gradio as gr
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from PIL import Image

import tagger as tg
# -------------------- FastAPI --------------------
# Human-readable summary shown in the OpenAPI / Swagger docs.
_API_DESCRIPTION = (
    "Generate a caption with BLIP, then return top-K tags derived from that caption."
)

# The ASGI application object served by the web server.
app = FastAPI(
    title="Image Tagger API",
    description=_API_DESCRIPTION,
    version="1.0.0",
)
# Sidecar JSON files are written for every upload unless the WRITE_SIDECAR
# environment variable is set to exactly "0" (unset means enabled).
WRITE_SIDECAR = os.environ.get("WRITE_SIDECAR", "1") != "0"
class TagResponse(BaseModel):
    # Payload returned by POST /upload; the same model is serialized to the
    # optional on-disk sidecar JSON. Every field is required (default=...).
    # Fields serialize in declaration order, so keep this order stable.
    filename: str = Field(default=..., examples=["photo.jpg"])
    caption: str = Field(default=..., examples=["a lion rests on a rock in the wild"])
    tags: List[str] = Field(default=..., examples=[["lion", "rests", "rock", "wild"]])
@app.on_event("startup")
def _load_once() -> None:
    """Initialise the tagger's models (``tg.init_models``) once at application startup."""
    # NOTE(review): FastAPI deprecates `on_event` in favour of a `lifespan`
    # handler passed to FastAPI(...); consider migrating both together.
    tg.init_models()
@app.get("/healthz")
def healthz():
    """Health check: returns a constant payload without touching the models."""
    status = dict(ok=True)
    return status
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve a minimal landing page linking to /docs and /ui, with a quick-upload form."""
    # `/upload` reads `top_k` from the query string (a FastAPI `Query` param),
    # but a plain HTML form posts its fields in the multipart body, where that
    # value is ignored. The onsubmit handler copies the chosen K into the
    # action URL. Without JavaScript, the endpoint's default of 5 applies.
    # `required` on the number input prevents submitting an empty K, which
    # would otherwise fail query validation.
    return """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Image Tagger API</title>
<style>
body{font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, sans-serif; max-width: 820px; margin: 48px auto; padding: 0 16px;}
.card{border:1px solid #e5e7eb; border-radius:12px; padding:20px;}
.btn{background:#111; color:#fff; padding:.6rem 1rem; border-radius:10px; text-decoration:none;}
.btn:focus,.btn:hover{opacity:.9}
input[type=number]{width:80px;}
</style>
</head>
<body>
<h2>🖼️ Image Tagger API</h2>
<p>Use <a href="/docs">/docs</a> for Swagger or try the simple UI at <a class="btn" href="/ui">/ui</a>.</p>
<div class="card">
<h3>Quick upload</h3>
<form action="/upload" method="post" enctype="multipart/form-data"
      onsubmit="this.action = '/upload?top_k=' + encodeURIComponent(this.elements['top_k'].value);">
<p><input type="file" name="file" accept="image/png,image/jpeg,image/webp" required></p>
<p>Top K tags: <input type="number" name="top_k" min="1" max="20" value="5" required></p>
<p><button class="btn" type="submit">Upload</button></p>
</form>
</div>
</body>
</html>"""
# Single worker: BLIP inference stays serialized, as it effectively was when
# it ran directly on the event loop (the models' thread-safety isn't
# established here). It is moved off the loop so /healthz, the Gradio UI and
# other requests keep being served while a caption is generated.
_INFERENCE_POOL = ThreadPoolExecutor(max_workers=1, thread_name_prefix="blip")


def _write_sidecar(stem: str, response: TagResponse) -> None:
    """Best-effort: write *response* as UTF-8 JSON to ``$DATA_DIR/<stem>.json``.

    Filesystem errors are logged and swallowed so they never fail a request.
    """
    data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
    try:
        data_dir.mkdir(parents=True, exist_ok=True)
        (data_dir / f"{stem}.json").write_text(
            response.model_dump_json(indent=2), encoding="utf-8"
        )
    except OSError:
        logging.getLogger(__name__).warning(
            "Could not write sidecar %s.json to %s", stem, data_dir, exc_info=True
        )


@app.post("/upload", response_model=TagResponse)
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="How many tags to return"),
):
    """Caption an uploaded image with BLIP and return top-K tags derived from the caption.

    Responds with 400 if the upload cannot be decoded as an image.
    """
    try:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        # Pillow raises several unrelated exception types for bad input, so
        # any decoding failure is reported as a client error.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # Caption with BLIP off the event loop (see _INFERENCE_POOL).
    loop = asyncio.get_running_loop()
    caption = await loop.run_in_executor(_INFERENCE_POOL, tg.caption_image, img)
    tags = tg.caption_to_tags(caption, top_k=top_k)

    # UploadFile.filename may be None or empty; don't let that become a 500.
    filename = file.filename or ""
    response = TagResponse(filename=filename, caption=caption, tags=tags)

    if WRITE_SIDECAR:
        # Path(...).stem drops any directory part the client sent, so the
        # sidecar always lands directly inside DATA_DIR.
        _write_sidecar(Path(filename).stem or "upload", response)
    return response
# -------------------- Gradio (mounted at /ui) --------------------
def _infer(image: Image.Image, top_k: int):
"""Wraps the same logic used by the API, but returns simple types
so the schema is trivial for Gradio (avoids JSON/dict outputs)."""
if image is None:
return "", ""
cap = tg.caption_image(image)
tags = tg.caption_to_tags(cap, top_k=top_k)
return cap, ", ".join(tags)
# Gradio UI. Components appear on screen in the order they are created inside
# the Row/Column contexts, so statement order here defines the layout.
with gr.Blocks(title="Image Tagger UI") as demo:
    gr.Markdown("### 🔍 Image → Caption → Tags\nUpload an image → BLIP generates a caption → we extract up to **K** simple tags.")
    with gr.Row():
        # Left column: inputs and action buttons.
        with gr.Column(scale=3):
            # type="pil" makes Gradio hand the upload to _infer as a PIL image.
            in_img = gr.Image(type="pil", label="Upload image", height=480)
            # Same 1..20 range (default 5) as the /upload `top_k` query param.
            k = gr.Slider(1, 20, value=5, step=1, label="Number of tags (K)")
            submit = gr.Button("Submit", variant="primary")
            clear = gr.Button("Clear")
        # Right column: outputs.
        with gr.Column(scale=2):
            out_cap = gr.Textbox(label="Generated Caption", lines=2)
            out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)
    # _infer returns (caption, tags); outputs map to the two textboxes positionally.
    submit.click(_infer, inputs=[in_img, k], outputs=[out_cap, out_tags])
    # Clear the image and both textboxes, and reset the slider to its default of 5.
    clear.click(lambda: (None, 5, "", ""), outputs=[in_img, k, out_cap, out_tags])
# Serve the Gradio Blocks UI under /ui from the same FastAPI app; the returned
# app replaces the module-level `app` that the server runs.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/ui")