Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import io | |
| import os | |
| from pathlib import Path | |
| from typing import List | |
| import gradio as gr | |
| from fastapi import FastAPI, File, HTTPException, Query, UploadFile | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel, Field | |
| from PIL import Image | |
| import tagger as tg | |
# -------------------- FastAPI --------------------
# Application object that the HTTP handlers and the Gradio UI are attached to.
app = FastAPI(
    title="Image Tagger API",
    version="1.0.0",
    description="Generate a caption with BLIP, then return top-K tags derived from that caption.",
)

# Sidecar JSON files are written unless the WRITE_SIDECAR environment
# variable is set to exactly "0" (unset defaults to "1", i.e. enabled).
WRITE_SIDECAR = os.environ.get("WRITE_SIDECAR", "1") != "0"
# Payload describing one tagged image: the uploaded file's name, the
# generated caption, and the tags derived from that caption.
# (Kept as `#` comments: a class docstring would be picked up by pydantic
# as the schema description and change the generated API docs.)
class TagResponse(BaseModel):
    # All three fields are required; `examples` only feeds the JSON schema.
    filename: str = Field(examples=["photo.jpg"])
    caption: str = Field(examples=["a lion rests on a rock in the wild"])
    tags: List[str] = Field(examples=[["lion", "rests", "rock", "wild"]])
# NOTE(review): nothing in this file calls `_load_once`, so the models were
# never initialised through it. It is registered as a startup hook here on the
# assumption that this was the intent -- confirm. `on_event` is deprecated in
# recent FastAPI releases in favour of a `lifespan` handler, but it keeps this
# block self-contained (no change to how `app` is constructed).
@app.on_event("startup")
def _load_once() -> None:
    """Initialise the tagger by calling ``tg.init_models()`` at application startup."""
    tg.init_models()
# The handler was defined but never registered on `app`, making it unreachable.
# NOTE(review): the "/healthz" path is taken from the function name -- confirm
# it matches whatever probe is configured for the deployment.
@app.get("/healthz")
def healthz():
    """Health check: always returns ``{"ok": true}``."""
    return {"ok": True}
# The handler was defined but never registered on `app`. It is served with
# `response_class=HTMLResponse` so the returned string is sent as text/html
# instead of being JSON-encoded.
# NOTE(review): the "/" path is inferred from the function name -- confirm.
#
# The form's `onsubmit` copies the chosen K into the query string: the upload
# handler declares `top_k` as a query parameter, so a `top_k` field sent only
# in the multipart body is ignored and the default would always be used.
@app.get("/", response_class=HTMLResponse)
def root():
    """Landing page with links to /docs and /ui and a quick upload form."""
    return """
<!doctype html>
<html>
<head>
<meta charset="utf-8" />
<title>Image Tagger API</title>
<style>
body{font-family: system-ui, -apple-system, Segoe UI, Roboto, Ubuntu, sans-serif; max-width: 820px; margin: 48px auto; padding: 0 16px;}
.card{border:1px solid #e5e7eb; border-radius:12px; padding:20px;}
.btn{background:#111; color:#fff; padding:.6rem 1rem; border-radius:10px; text-decoration:none;}
.btn:focus,.btn:hover{opacity:.9}
input[type=number]{width:80px;}
</style>
</head>
<body>
<h2>🖼️ Image Tagger API</h2>
<p>Use <a href="/docs">/docs</a> for Swagger or try the simple UI at <a class="btn" href="/ui">/ui</a>.</p>
<div class="card">
<h3>Quick upload</h3>
<form action="/upload" method="post" enctype="multipart/form-data" onsubmit="this.action='/upload?top_k='+encodeURIComponent(this.elements.top_k.value||5)">
<p><input type="file" name="file" accept="image/png,image/jpeg,image/webp" required></p>
<p>Top K tags: <input type="number" name="top_k" min="1" max="20" value="5"></p>
<p><button class="btn" type="submit">Upload</button></p>
</form>
</div>
</body>
</html>"""
# The handler was defined but never registered on `app`, although the landing
# page's form posts to "/upload"; the route is registered here.
@app.post("/upload", response_model=TagResponse)
async def upload_image(
    file: UploadFile = File(...),
    top_k: int = Query(5, ge=1, le=20, description="How many tags to return"),
):
    """Caption an uploaded image and return up to `top_k` tags derived from the caption."""
    # Implementation notes (kept out of the docstring, which FastAPI publishes
    # as the operation description):
    #   * `top_k` is a *query* parameter (1..20, default 5); a multipart form
    #     field with the same name is not read.
    #   * Unreadable or non-image uploads are rejected with HTTP 400.
    #   * When WRITE_SIDECAR is enabled, the response is also written to
    #     "<DATA_DIR>/<filename stem>.json" (DATA_DIR defaults to /app/data).
    #     This is best-effort: failures are logged, never raised.
    # NOTE(review): the `tg` calls below run synchronously inside this async
    # handler; if they are slow they will block the event loop -- consider a
    # threadpool.
    import logging  # local import keeps this block self-contained

    try:
        content = await file.read()
        img = Image.open(io.BytesIO(content)).convert("RGB")
    except Exception as e:
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=400, detail=f"Invalid image: {e}") from e

    # Clients may omit the multipart filename; without a fallback both
    # Path(None) and TagResponse(filename=None) would fail with a 500.
    filename = file.filename or "upload"
    # `.stem` keeps only the final path component, so directory parts in a
    # client-supplied name cannot steer the sidecar outside DATA_DIR.
    stem = Path(filename).stem or "upload"

    # caption with BLIP
    caption = tg.caption_image(img)
    # top-K tags (ensure tagger returns ONLY the list)
    tags = tg.caption_to_tags(caption, top_k=top_k)

    # Build the payload once; it is both the HTTP response and the sidecar content.
    response = TagResponse(filename=filename, caption=caption, tags=tags)

    # optional sidecar (same content shape as JSON response)
    if WRITE_SIDECAR:
        try:
            data_dir = Path(os.getenv("DATA_DIR", "/app/data"))
            data_dir.mkdir(parents=True, exist_ok=True)
            # Explicit UTF-8 so the output does not depend on the host locale.
            (data_dir / f"{stem}.json").write_text(
                response.model_dump_json(indent=2), encoding="utf-8"
            )
        except Exception:
            # Do not fail the request over a sidecar problem, but leave a trace.
            logging.getLogger(__name__).warning(
                "Could not write sidecar JSON for %r", filename, exc_info=True
            )

    return response
| # -------------------- Gradio (mounted at /ui) -------------------- | |
| def _infer(image: Image.Image, top_k: int): | |
| """Wraps the same logic used by the API, but returns simple types | |
| so the schema is trivial for Gradio (avoids JSON/dict outputs).""" | |
| if image is None: | |
| return "", "" | |
| cap = tg.caption_image(image) | |
| tags = tg.caption_to_tags(cap, top_k=top_k) | |
| return cap, ", ".join(tags) | |
# Gradio front end: image + K slider on the left, caption/tags text on the right.
with gr.Blocks(title="Image Tagger UI") as demo:
    gr.Markdown(
        "### 🔍 Image → Caption → Tags\n"
        "Upload an image → BLIP generates a caption → we extract up to **K** simple tags."
    )
    with gr.Row():
        # Input column (wider).
        with gr.Column(scale=3):
            in_img = gr.Image(type="pil", label="Upload image", height=480)
            k = gr.Slider(1, 20, value=5, step=1, label="Number of tags (K)")
            submit = gr.Button("Submit", variant="primary")
            clear = gr.Button("Clear")
        # Output column (narrower).
        with gr.Column(scale=2):
            out_cap = gr.Textbox(label="Generated Caption", lines=2)
            out_tags = gr.Textbox(label="Tags (comma-separated)", lines=2)

    # Submit runs the captioning/tagging helper on the current inputs.
    submit.click(
        fn=_infer,
        inputs=[in_img, k],
        outputs=[out_cap, out_tags],
    )
    # Clear resets every widget: no image, K back to 5, empty text boxes.
    clear.click(
        fn=lambda: (None, 5, "", ""),
        outputs=[in_img, k, out_cap, out_tags],
    )

# mount Gradio under FastAPI
app = gr.mount_gradio_app(app, demo, path="/ui")