from fastapi import FastAPI, UploadFile, File, Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from PIL import Image
import io
import base64
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration

# Initialize FastAPI app
app = FastAPI()

# Set up static and templates directories
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

# Load BLIP model and processor with local caching
cache_dir = "./model_cache"
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=cache_dir)
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", cache_dir=cache_dir)


@app.get("/", response_class=HTMLResponse)
async def main(request: Request):
    # Render the upload form
    return templates.TemplateResponse("index.html", {"request": request})


@app.post("/", response_class=HTMLResponse)
async def caption(request: Request, file: UploadFile = File(...)):
    # Read the uploaded file into memory and normalize to RGB
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")

    # Generate caption using BLIP (no gradients needed for inference)
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        out = model.generate(**inputs)
    caption = processor.decode(out[0], skip_special_tokens=True)

    # Convert image to base64 for inline preview in the template
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

    return templates.TemplateResponse("index.html", {
        "request": request,
        "caption": caption,
        "image_data": img_str,
    })
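
# A minimal usage sketch, not part of the app above: exercising the POST endpoint
# with FastAPI's TestClient. It assumes the app lives in main.py, that a
# templates/index.html exists which renders the "caption" and "image_data"
# variables, and that a local sample.jpg is available; adjust names as needed.
#
#     from fastapi.testclient import TestClient
#     from main import app
#
#     client = TestClient(app)
#
#     with open("sample.jpg", "rb") as f:
#         response = client.post("/", files={"file": ("sample.jpg", f, "image/jpeg")})
#
#     print(response.status_code)  # expect 200 when captioning succeeds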