# vision.py — BLIP-based image captioning utilities for YT_Video frame folders.
import json
from pathlib import Path
from typing import Dict
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
from config import VISION_MODEL
# Lazy-loaded BLIP singletons: populated on first use by _load_blip() so
# importing this module does not download/instantiate the model.
_processor = None
_model = None
# Run on GPU when available, otherwise fall back to CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"
def _load_blip():
    """Return the shared (processor, model) pair, loading BLIP lazily.

    The first call downloads/instantiates the model named by VISION_MODEL,
    moves it to the module device, and switches it to eval mode; later calls
    reuse the cached singletons.
    """
    global _processor, _model
    if _model is None or _processor is None:
        proc = BlipProcessor.from_pretrained(VISION_MODEL)
        net = BlipForConditionalGeneration.from_pretrained(VISION_MODEL)
        net = net.to(_device)
        net.eval()
        _processor, _model = proc, net
    return _processor, _model
def caption_image(img_path: Path) -> str:
    """Generate a BLIP caption for one image file.

    Args:
        img_path: Path to an image in any format PIL can read.

    Returns:
        The decoded caption string (special tokens stripped).
    """
    processor, model = _load_blip()
    # Fix: open the image in a context manager so the file handle is closed.
    # PIL loads lazily; convert("RGB") forces a full decode, so the pixel
    # data remains valid after the file is closed.
    with Image.open(img_path) as im:
        img = im.convert("RGB")
    inputs = processor(img, return_tensors="pt").to(_device)
    with torch.inference_mode():
        out_ids = model.generate(**inputs, max_new_tokens=40)
    return processor.decode(out_ids[0], skip_special_tokens=True)
def caption_folder(frames_dir: Path) -> Dict[str, str]:
    """Caption every ``*.jpg`` in *frames_dir*.

    Returns a dict mapping each frame's filename to its caption, with
    frames processed in sorted filename order.
    """
    frames = sorted(frames_dir.glob("*.jpg"))
    return {frame.name: caption_image(frame) for frame in frames}
def dump_json(data, out_path: Path):
    """Serialize *data* as pretty-printed UTF-8 JSON to *out_path*.

    Non-ASCII characters are written verbatim (ensure_ascii=False) and the
    output is indented two spaces.
    """
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    out_path.write_text(payload, encoding="utf-8")