Shami96 committed
Commit a84ea93 · verified · 1 Parent(s): 1bf0ddc

Create vision.py

Files changed (1):
  1. vision.py  +37 -0
vision.py ADDED
@@ -0,0 +1,37 @@
+ import json
+ from pathlib import Path
+ from typing import Dict
+ from PIL import Image
+ import torch
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+ from config import VISION_MODEL
+
+ _processor = None
+ _model = None
+ _device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ def _load_blip():
+     global _processor, _model
+     if _processor is None or _model is None:
+         _processor = BlipProcessor.from_pretrained(VISION_MODEL)
+         _model = BlipForConditionalGeneration.from_pretrained(VISION_MODEL).to(_device)
+         _model.eval()
+     return _processor, _model
+
+ def caption_image(img_path: Path) -> str:
+     processor, model = _load_blip()
+     img = Image.open(str(img_path)).convert("RGB")
+     inputs = processor(img, return_tensors="pt").to(_device)
+     with torch.inference_mode():
+         out_ids = model.generate(**inputs, max_new_tokens=40)
+     return processor.decode(out_ids[0], skip_special_tokens=True)
+
+ def caption_folder(frames_dir: Path) -> Dict[str, str]:
+     results = {}
+     for p in sorted(frames_dir.glob("*.jpg")):
+         results[p.name] = caption_image(p)
+     return results
+
+ def dump_json(data, out_path: Path):
+     with open(out_path, "w", encoding="utf-8") as f:
+         json.dump(data, f, ensure_ascii=False, indent=2)
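
For context, a minimal usage sketch of the new module. The frames directory, output filename, and the value of VISION_MODEL in config.py (e.g. a BLIP checkpoint such as "Salesforce/blip-image-captioning-base") are assumptions for illustration, not part of this commit:

    # Hypothetical driver script (not part of this commit): captions every
    # frames/*.jpg with BLIP and writes the results to captions.json.
    # Assumes config.py defines VISION_MODEL as a BLIP checkpoint, e.g.
    # "Salesforce/blip-image-captioning-base".
    from pathlib import Path

    from vision import caption_folder, dump_json

    captions = caption_folder(Path("frames"))   # {"frame_0001.jpg": "a ...", ...}
    dump_json(captions, Path("captions.json"))  # UTF-8 JSON, indented for review

Design note: the module caches the processor and model in module-level globals (_processor, _model), so the BLIP weights are loaded once on first use and reused across all subsequent caption_image calls.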