from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import List, Optional

from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor

# -------------------- config --------------------
MODEL_ID = "Salesforce/blip-image-captioning-base"
DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
DATA_DIR.mkdir(parents=True, exist_ok=True)  # safe if already exists

# light, built-in stopword list (keeps us NLTK-free)
_STOP = {
    "a", "an", "the", "and", "or", "of", "to", "in", "on", "with",
    "near", "at", "over", "under", "by", "from", "for", "into",
    "along", "through", "is", "are", "be", "being", "been", "it",
    "its", "this", "that", "as", "while", "than", "then", "there",
    "here",
}

# -------------------- model cache --------------------
_processor: Optional[BlipProcessor] = None
_model: Optional[BlipForConditionalGeneration] = None


def init_models() -> None:
    """Load BLIP once (idempotent)."""
    global _processor, _model
    if _processor is None or _model is None:
        _processor = BlipProcessor.from_pretrained(MODEL_ID)
        _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)


# -------------------- core functionality --------------------
def caption_image(img: Image.Image, max_len: int = 30) -> str:
    """Generate a short caption for the image."""
    assert _processor and _model, "Call init_models() first"
    inputs = _processor(images=img, return_tensors="pt")
    ids = _model.generate(**inputs, max_length=max_len)
    return _processor.decode(ids[0], skip_special_tokens=True)


_TAG_RE = re.compile(r"[a-z0-9-]+")


def caption_to_tags(caption: str, top_k: int = 5) -> List[str]:
    """
    Convert a caption into up to K simple tags:
      - normalize to lowercase alnum/hyphen tokens
      - remove tiny stopword list
      - keep order of appearance, dedup
    """
    tags: List[str] = []
    seen = set()
    for tok in _TAG_RE.findall(caption.lower()):
        if tok in _STOP or tok in seen:
            continue
        seen.add(tok)
        tags.append(tok)
        if len(tags) >= top_k:
            break
    return tags


def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
    write_sidecar: bool = True,
) -> List[str]:
    """
    Return ONLY the tags list.
    (We optionally persist a sidecar JSON with caption + tags.)
    """
    cap = caption_image(img)
    tags = caption_to_tags(cap, top_k=top_k)
    if write_sidecar:
        payload = {"caption": cap, "tags": tags}
        sidecar = DATA_DIR / f"{stem}.json"
        try:
            sidecar.write_text(json.dumps(payload, indent=2))
        except Exception:
            # best-effort; tagging should still succeed
            pass
    return tags
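

# ---------------------------------------------------------------
# Usage sketch (not part of the library surface): a minimal demo of
# the caption -> tags pipeline above. The image path below is a
# hypothetical placeholder; point it at any local image. Note that
# DATA_DIR defaults to /app/data, so set the DATA_DIR env var when
# running outside a container.
# ---------------------------------------------------------------
if __name__ == "__main__":
    init_models()  # loads BLIP weights once; later calls are no-ops
    demo_path = Path("example.jpg")  # hypothetical sample image
    if demo_path.exists():
        with Image.open(demo_path) as im:
            tags = tag_pil_image(im.convert("RGB"), stem=demo_path.stem)
        print("tags:", tags)
        # e.g. a caption like "a dog running on the beach" yields
        # tags such as ["dog", "running", "beach"] (stopwords dropped,
        # order of appearance kept, capped at top_k)
    else:
        print(f"Place a sample image at {demo_path} to run the demo.")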