#!/usr/bin/env python
# furniture_bbox_to_files.py ────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# --------------------------------------------------------------------
import os, json, random, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar
# ───── global models ────────────────────────────────────────────────
from utils.florence import (
load_florence_model, run_florence_inference,
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
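# NOTE: all three models are loaded once at import time and shared by the
# worker threads below; this assumes their forward passes are safe to call
# concurrently (read-only weights), which generally holds for plain
# PyTorch inference.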
# annotators
_PALETTE = sv.ColorPalette.from_hex(
['#FF1493','#00BFFF','#FF6347','#FFD700','#32CD32','#8A2BE2'])
BOX_ANN = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LBL_ANN = sv.LabelAnnotator(
color=_PALETTE, color_lookup=sv.ColorLookup.INDEX,
text_position=sv.Position.CENTER_OF_MASS,
text_color=sv.Color.from_hex("#000"), border_radius=5)
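# The three annotators are applied in sequence in process_row below:
# masks first, then boxes, then centre-of-mass labels, all indexed into
# the same six-colour palette per detection.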
# ───── config ───────────────────────────────────────────────────────
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()
DATASET_NAME = "fotographerai/furniture_captioned_segment_prompt"
SPLIT = "train"
IMAGE_COL = "img2"
PROMPT_COL = "segmenting_prompt"
INFLATE_RANGE = (0.01, 0.05)
FILL_COLOR = "#00FF00"
TARGET_SIDE = 1500
QA_DIR = Path("bbox_review_recaptioned")
GREEN_DIR = QA_DIR / "green"; GREEN_DIR.mkdir(parents=True, exist_ok=True)
ANNO_DIR = QA_DIR / "anno"; ANNO_DIR.mkdir(parents=True, exist_ok=True)
JSON_DIR = QA_DIR / "json"; JSON_DIR.mkdir(parents=True, exist_ok=True)
MAX_WORKERS = 100
MAX_RETRIES = 5
RETRY_SLEEP = .3
FAILED_LOG = QA_DIR / "failed_rows.jsonl"
PROMPT_MAP: dict[str,str] = {} # optional overrides
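# A hypothetical override would map a raw dataset prompt to a simpler
# phrase for Florence-2, e.g.:
#   PROMPT_MAP = {"a modern grey sofa, side table": "sofa"}
# Unmapped prompts fall back to their first comma-separated term (see
# process_row below).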
# ───── helpers ──────────────────────────────────────────────────────
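# make_square letterboxes an image into a TARGET_SIDE square: shrink with
# ImageOps.contain (aspect ratio preserved), then pad the borders, using
# the top-left pixel colour as fill so the padding tends to blend with
# flat backgrounds.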
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
img = ImageOps.contain(img, (side, side))
pad_w, pad_h = side - img.width, side - img.height
return ImageOps.expand(img, border=(pad_w//2, pad_h//2,
pad_w - pad_w//2, pad_h - pad_h//2),
fill=img.getpixel((0,0)))
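# img_to_file saves the PNG once (skipped when the file already exists)
# and returns {"path": ..., "bytes": None}, the dict layout accepted by
# datasets.Image(), so pixels are re-read lazily from disk when the
# column is cast below instead of being held in memory.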
def img_to_file(img: Image.Image, fname: str, folder: Path) -> dict:
path = folder / f"{fname}.png"
if not path.exists():
img.save(path)
return {"path": str(path), "bytes": None}
# ───── core functions ───────────────────────────────────────────────
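# detect_and_segment runs one Florence-2 open-vocabulary detection pass
# per comma-separated phrase, converts the resulting boxes to masks with
# SAM-2, and concatenates the per-phrase results via sv.Detections.merge.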
@torch.inference_mode()
@torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16)  # follow the active device so CPU-only runs don't request CUDA autocast
def detect_and_segment(img: Image.Image, prompts: str|List[str]) -> sv.Detections:
if isinstance(prompts, str):
prompts = [p.strip() for p in prompts.split(",") if p.strip()]
all_dets = []
for p in prompts:
_, res = run_florence_inference(
model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, res, resolution_wh=img.size)
all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
return sv.Detections.merge(all_dets)
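# fill_detected_bboxes paints an opaque FILL_COLOR rectangle over every
# detected box, inflated by inflate_pct on each side and clamped to the
# image bounds, and returns both the painted copy and the raw detections.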
def fill_detected_bboxes(img: Image.Image, prompt: str,
inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
dets = detect_and_segment(img, prompt)
filled = img.copy()
draw = ImageDraw.Draw(filled)
rgb = ImageColor.getrgb(FILL_COLOR)
w,h = img.size
for box in dets.xyxy:
x1,y1,x2,y2 = box.astype(float)
dw,dh = (x2-x1)*inflate_pct, (y2-y1)*inflate_pct
draw.rectangle([max(0,x1-dw), max(0,y1-dh),
min(w,x2+dw), min(h,y2+dh)], fill=rgb)
return filled, dets
# ───── threaded worker ──────────────────────────────────────────────
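# process_row is the unit of work for the thread pool. It returns either
#   ("ok", filled_file_dict, anno_file_dict, bbox_json_str)
# or
#   ("fail", reason)
# after up to MAX_RETRIES attempts; an empty detection set is raised and
# retried like any other transient error.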
def process_row(idx: int, sample):
prompt = PROMPT_MAP.get(sample[PROMPT_COL],
sample[PROMPT_COL].split(",",1)[0].strip())
img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
for attempt in range(1, MAX_RETRIES+1):
try:
filled, dets = fill_detected_bboxes(
img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
if len(dets.xyxy) == 0:
raise ValueError("no detections")
sid = f"{idx:06d}"
json_p = JSON_DIR / f"{sid}_bbox.json"
json_p.write_text(json.dumps({"xyxy": dets.xyxy.tolist()}))
anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LBL_ANN):
anno = ann.annotate(anno, dets)
return ("ok",
img_to_file(filled, sid, GREEN_DIR),
img_to_file(anno, sid, ANNO_DIR),
json_p.read_text())
except Exception as e:
if attempt < MAX_RETRIES:
time.sleep(RETRY_SLEEP)
else:
return ("fail", str(e))
# ───── run batch ────────────────────────────────────────────────────
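# Submit every row up front and collect results as they complete. The
# three result lists are index-aligned with the dataset, so failed rows
# stay None and are dropped by the `keep` filter below.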
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N = len(ds)
print("Rows:", N)
filled_col, anno_col, json_col = [None]*N, [None]*N, [None]*N
fails = 0
with ThreadPoolExecutor(MAX_WORKERS) as pool:
fut2idx = {pool.submit(process_row, i, ds[i]): i for i in range(N)}
for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
idx = fut2idx[fut]
status, *data = fut.result()
if status == "ok":
filled_col[idx], anno_col[idx], json_col[idx] = data
else:
fails += 1
            with FAILED_LOG.open("a") as f:  # append, so earlier failures aren't overwritten
                f.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")
print(f"β permanently failed rows: {fails}")
keep = [i for i,x in enumerate(filled_col) if x]
new_ds = ds.select(keep)
new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep])
new_ds = new_ds.add_column("annotated", [anno_col[i] for i in keep])
new_ds = new_ds.add_column("bbox_json", [json_col[i] for i in keep])
new_ds = new_ds.cast_column("bbox_filled", HFImage())
new_ds = new_ds.cast_column("annotated", HFImage())
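# Casting to HFImage() lets datasets decode the {"path": ..., "bytes": None}
# dicts lazily from disk, so full-resolution pixels never sit in RAM.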
print(f"β
successes: {len(new_ds)} / {N}")
print("Columns:", new_ds.column_names)
print("QA artefacts β", QA_DIR.resolve())
# optional push
new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled",
private=True, max_shard_size="500MB")