Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import os | |
| import re | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from PIL import Image | |
| from transformers import BlipForConditionalGeneration, BlipProcessor | |
# -------------------- config --------------------
# Hugging Face model id for the BLIP captioning checkpoint.
MODEL_ID = "Salesforce/blip-image-captioning-base"
# Where sidecar JSON files are written; overridable via the DATA_DIR env var.
DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
DATA_DIR.mkdir(parents=True, exist_ok=True)  # safe if already exists
# light, built-in stopword list (keeps us NLTK-free)
# Tokens in this set are dropped by caption_to_tags() after lowercasing.
_STOP = {
    "a", "an", "the", "and", "or", "of", "to", "in", "on", "with", "near",
    "at", "over", "under", "by", "from", "for", "into", "along", "through",
    "is", "are", "be", "being", "been", "it", "its", "this", "that",
    "as", "while", "than", "then", "there", "here",
}
# -------------------- model cache --------------------
# Lazily-populated singletons; both are set together by init_models().
_processor: Optional[BlipProcessor] = None
_model: Optional[BlipForConditionalGeneration] = None
def init_models() -> None:
    """Load the BLIP processor/model pair once; repeat calls are no-ops."""
    global _processor, _model
    # Guard clause: both singletons already populated -> nothing to do.
    if _processor is not None and _model is not None:
        return
    _processor = BlipProcessor.from_pretrained(MODEL_ID)
    _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
| # -------------------- core functionality -------------------- | |
def caption_image(img: Image.Image, max_len: int = 30) -> str:
    """Generate a short caption for *img* using the cached BLIP model.

    Args:
        img: Input picture (any mode BLIP's processor accepts).
        max_len: ``max_length`` passed to ``generate`` (caps caption tokens).

    Returns:
        The decoded caption string, special tokens stripped.

    Raises:
        RuntimeError: If init_models() has not been called yet.
    """
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # which would turn a missing init into an opaque AttributeError on None.
    if _processor is None or _model is None:
        raise RuntimeError("Call init_models() first")
    inputs = _processor(images=img, return_tensors="pt")
    ids = _model.generate(**inputs, max_length=max_len)
    return _processor.decode(ids[0], skip_special_tokens=True)
# lowercase alnum/hyphen word tokens — the units caption_to_tags extracts
_TAG_RE = re.compile(r"[a-z0-9-]+")


def caption_to_tags(
    caption: str,
    top_k: int = 5,
    stopwords: Optional[Set[str]] = None,
) -> List[str]:
    """
    Convert a caption into up to *top_k* simple tags:
    - normalize to lowercase alnum/hyphen tokens
    - drop stopwords (module default list unless *stopwords* is given)
    - keep order of first appearance, dedup

    Args:
        caption: Free-text caption to tokenize.
        top_k: Maximum number of tags to return; <= 0 yields an empty list.
        stopwords: Optional replacement stopword set (defaults to _STOP).

    Returns:
        Up to *top_k* unique tags in order of appearance.
    """
    stop = _STOP if stopwords is None else stopwords
    tags: List[str] = []
    seen = set()
    for tok in _TAG_RE.findall(caption.lower()):
        if tok in stop or tok in seen:
            continue
        # Bug fix: cap check runs *before* appending, so top_k <= 0 returns
        # [] (the old post-append check was off by one at zero/negative).
        if len(tags) >= top_k:
            break
        seen.add(tok)
        tags.append(tok)
    return tags
def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
    write_sidecar: bool = True,
) -> List[str]:
    """
    Caption *img* and return ONLY the derived tags list.

    Optionally persists a sidecar JSON with the caption + tags.

    Args:
        img: Picture to tag (BLIP model must already be loaded).
        stem: Base filename (no extension) for the sidecar JSON.
            NOTE(review): assumed to contain no path separators — confirm
            with callers, otherwise the sidecar can escape DATA_DIR.
        top_k: Maximum number of tags to return.
        write_sidecar: When True, best-effort write {caption, tags} to
            DATA_DIR/<stem>.json; a failed write never blocks tagging.

    Returns:
        Up to *top_k* tags derived from the generated caption.
    """
    cap = caption_image(img)
    tags = caption_to_tags(cap, top_k=top_k)
    if write_sidecar:
        payload = {"caption": cap, "tags": tags}
        sidecar = DATA_DIR / f"{stem}.json"
        try:
            # Explicit UTF-8 + ensure_ascii=False: keeps non-ASCII captions
            # readable and independent of the platform's locale encoding
            # (write_text without encoding= uses the locale default).
            sidecar.write_text(
                json.dumps(payload, ensure_ascii=False, indent=2),
                encoding="utf-8",
            )
        except OSError:
            # Best-effort persistence: payload is always str/list, so only
            # filesystem errors can occur here; tagging still succeeds.
            pass
    return tags