# Image_Tagger / tagger.py
from __future__ import annotations

import json
import os
import re
from pathlib import Path
from typing import List, Optional

from PIL import Image
from transformers import BlipForConditionalGeneration, BlipProcessor
# -------------------- config --------------------
MODEL_ID = "Salesforce/blip-image-captioning-base"
DATA_DIR = Path(os.getenv("DATA_DIR", "/app/data"))
DATA_DIR.mkdir(parents=True, exist_ok=True) # safe if already exists
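# e.g. point the sidecar directory somewhere writable at runtime:
#   DATA_DIR=/tmp/tags python tagger.py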
# light, built-in stopword list (keeps us NLTK-free)
_STOP = {
    "a", "an", "the", "and", "or", "of", "to", "in", "on", "with", "near",
    "at", "over", "under", "by", "from", "for", "into", "along", "through",
    "is", "are", "be", "being", "been", "it", "its", "this", "that",
    "as", "while", "than", "then", "there", "here",
}
# -------------------- model cache --------------------
_processor: Optional[BlipProcessor] = None
_model: Optional[BlipForConditionalGeneration] = None


def init_models() -> None:
    """Load BLIP once (idempotent)."""
    global _processor, _model
    if _processor is None or _model is None:
        _processor = BlipProcessor.from_pretrained(MODEL_ID)
        _model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
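# Note (assumption, not in the original): the model stays on CPU as loaded.
# A minimal sketch of optional GPU placement, assuming torch is available
# (transformers already requires it for this model):
#
#   import torch
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   _model.to(device)  # move weights once, right after loading
#   # ...and move each `inputs` batch to the same device before generate()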


# -------------------- core functionality --------------------
def caption_image(img: Image.Image, max_len: int = 30) -> str:
    """Generate a short caption for the image."""
    assert _processor and _model, "Call init_models() first"
    inputs = _processor(images=img, return_tensors="pt")
    ids = _model.generate(**inputs, max_length=max_len)
    return _processor.decode(ids[0], skip_special_tokens=True)
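# Illustrative only: for a photo of a cat this might return a string like
# "a cat sitting on a window sill" (actual output depends on the model).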


_TAG_RE = re.compile(r"[a-z0-9-]+")


def caption_to_tags(caption: str, top_k: int = 5) -> List[str]:
    """
    Convert a caption into up to `top_k` simple tags:
    - normalize to lowercase alphanumeric/hyphen tokens
    - drop tokens in the small built-in stopword list
    - keep order of first appearance, deduplicated
    """
    tags: List[str] = []
    seen = set()
    for tok in _TAG_RE.findall(caption.lower()):
        if tok in _STOP or tok in seen:
            continue
        seen.add(tok)
        tags.append(tok)
        if len(tags) >= top_k:
            break
    return tags
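# Illustrative example:
#   caption_to_tags("a dog running on the beach")
#   -> ["dog", "running", "beach"]   ("a", "on", "the" are stopwords)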


def tag_pil_image(
    img: Image.Image,
    stem: str,
    *,
    top_k: int = 5,
    write_sidecar: bool = True,
) -> List[str]:
    """
    Return ONLY the tags list.
    (We optionally persist a sidecar JSON with caption + tags.)
    """
    cap = caption_image(img)
    tags = caption_to_tags(cap, top_k=top_k)
    if write_sidecar:
        payload = {"caption": cap, "tags": tags}
        sidecar = DATA_DIR / f"{stem}.json"
        try:
            sidecar.write_text(json.dumps(payload, indent=2))
        except Exception:
            # best-effort; tagging should still succeed
            pass
    return tags
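

# A minimal smoke test, not part of the original module's API: assumes a local
# image at a hypothetical path (swap in any real file), and that DATA_DIR is
# set to a writable directory.
if __name__ == "__main__":
    init_models()
    demo_path = Path("sample.jpg")  # hypothetical example file
    if demo_path.exists():
        with Image.open(demo_path) as im:
            print(tag_pil_image(im.convert("RGB"), demo_path.stem))
    else:
        print(f"Put an image at {demo_path} to try the tagger.")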