#!/usr/bin/env python
# furniture_bbox_to_files.py  ────────────────────────────────────────
# Florence-2 + SAM-2 batch processor with retries *and* file-based images
# --------------------------------------------------------------------
import os, json, random, time
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import torch, supervision as sv
from PIL import Image, ImageDraw, ImageColor, ImageOps
from tqdm.auto import tqdm
from datasets import load_dataset, Image as HFImage, disable_progress_bar

# ───── global models ────────────────────────────────────────────────
from utils.florence import (
    load_florence_model, run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from utils.sam import load_sam_image_model, run_sam_inference

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL               = load_sam_image_model(device=DEVICE)
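# Both models are loaded once at import time and shared by every worker
# thread in the pool below.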

# annotators
_PALETTE = sv.ColorPalette.from_hex(
    ['#FF1493','#00BFFF','#FF6347','#FFD700','#32CD32','#8A2BE2'])
BOX_ANN  = sv.BoxAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
MASK_ANN = sv.MaskAnnotator(color=_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LBL_ANN  = sv.LabelAnnotator(
    color=_PALETTE, color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000"), border_radius=5)

# ───── config ───────────────────────────────────────────────────────
os.environ["TOKENIZERS_PARALLELISM"] = "false"
disable_progress_bar()

DATASET_NAME  = "fotographerai/furniture_captioned_segment_prompt"
SPLIT         = "train"
IMAGE_COL     = "img2"
PROMPT_COL    = "segmenting_prompt"

INFLATE_RANGE = (0.01, 0.05)
FILL_COLOR    = "#00FF00"
TARGET_SIDE   = 1500

QA_DIR        = Path("bbox_review_recaptioned")
GREEN_DIR     = QA_DIR / "green";  GREEN_DIR.mkdir(parents=True, exist_ok=True)
ANNO_DIR      = QA_DIR / "anno";   ANNO_DIR.mkdir(parents=True, exist_ok=True)
JSON_DIR      = QA_DIR / "json";   JSON_DIR.mkdir(parents=True, exist_ok=True)

MAX_WORKERS   = 100
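# NOTE: all threads share the two GPU models above; PyTorch modules make no
# thread-safety guarantee for concurrent forward passes, so dial this down
# (or serialise inference with a threading.Lock) if you hit CUDA errors.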
MAX_RETRIES   = 5
RETRY_SLEEP   = 0.3
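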
FAILED_LOG    = QA_DIR / "failed_rows.jsonl"

PROMPT_MAP: dict[str,str] = {}   # optional overrides
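# e.g. PROMPT_MAP = {"a beige fabric sofa, front view": "sofa"}
# (illustrative entry only: maps a raw dataset prompt to a cleaner phrase)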

# ───── helpers ──────────────────────────────────────────────────────
def make_square(img: Image.Image, side: int = TARGET_SIDE) -> Image.Image:
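    """Letterbox `img` into a side x side square, padding the borders with
    the colour of the top-left pixel."""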
    img = ImageOps.contain(img, (side, side))
    pad_w, pad_h = side - img.width, side - img.height
    return ImageOps.expand(img, border=(pad_w//2, pad_h//2,
                                        pad_w - pad_w//2, pad_h - pad_h//2),
                           fill=img.getpixel((0,0)))

def img_to_file(img: Image.Image, fname: str, folder: Path) -> dict:
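    """Save `img` to disk once and return the {"path", "bytes"} dict that
    the HF datasets `Image` feature accepts."""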
    path = folder / f"{fname}.png"
    if not path.exists():
        img.save(path)
    return {"path": str(path), "bytes": None}

# ───── core functions ───────────────────────────────────────────────
@torch.inference_mode()
@torch.autocast(device_type=DEVICE.type, dtype=torch.bfloat16)  # match the selected device so the script also runs on CPU
def detect_and_segment(img: Image.Image, prompts: str | List[str]) -> sv.Detections:
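    """Run Florence-2 open-vocabulary detection once per comma-separated
    prompt, refine each box into a mask with SAM-2, and merge the results."""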
    if isinstance(prompts, str):
        prompts = [p.strip() for p in prompts.split(",") if p.strip()]
    all_dets = []
    for p in prompts:
        _, res = run_florence_inference(
            model=FLORENCE_MODEL, processor=FLORENCE_PROC, device=DEVICE,
            image=img, task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK, text=p)
        d = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, res, img.size)
        all_dets.append(run_sam_inference(SAM_IMAGE_MODEL, img, d))
    return sv.Detections.merge(all_dets)

def fill_detected_bboxes(img: Image.Image, prompt: str,
                         inflate_pct: float) -> tuple[Image.Image, sv.Detections]:
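    """Paint solid FILL_COLOR rectangles over every detected box, inflated
    by `inflate_pct` and clamped to the image bounds."""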
    dets = detect_and_segment(img, prompt)
    filled = img.copy()
    draw = ImageDraw.Draw(filled)
    rgb   = ImageColor.getrgb(FILL_COLOR)
    w,h   = img.size
    for box in dets.xyxy:
        x1,y1,x2,y2 = box.astype(float)
        dw,dh = (x2-x1)*inflate_pct, (y2-y1)*inflate_pct
        draw.rectangle([max(0,x1-dw), max(0,y1-dh),
                        min(w,x2+dw), min(h,y2+dh)], fill=rgb)
    return filled, dets

# ───── threaded worker ──────────────────────────────────────────────
def process_row(idx: int, sample):
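    """Detect, fill, and annotate one dataset row, retrying up to MAX_RETRIES
    times; returns ("ok", filled, annotated, bbox_json) or ("fail", reason)."""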
    prompt = PROMPT_MAP.get(sample[PROMPT_COL],
                            sample[PROMPT_COL].split(",",1)[0].strip())
    img_sq = make_square(sample[IMAGE_COL].convert("RGB"))
    for attempt in range(1, MAX_RETRIES+1):
        try:
            filled, dets = fill_detected_bboxes(
                img_sq, prompt, inflate_pct=random.uniform(*INFLATE_RANGE))
            if len(dets.xyxy) == 0:
                raise ValueError("no detections")

            sid = f"{idx:06d}"
            json_p = JSON_DIR / f"{sid}_bbox.json"
            json_p.write_text(json.dumps({"xyxy": dets.xyxy.tolist()}))

            anno = img_sq.copy()
            for ann in (MASK_ANN, BOX_ANN, LBL_ANN):
                anno = ann.annotate(anno, dets)

            return ("ok",
                    img_to_file(filled, sid, GREEN_DIR),
                    img_to_file(anno,   sid, ANNO_DIR),
                    json_p.read_text())
        except Exception as e:
            if attempt < MAX_RETRIES:
                time.sleep(RETRY_SLEEP)
            else:
                return ("fail", str(e))

# ───── run batch ────────────────────────────────────────────────────
ds = load_dataset(DATASET_NAME, split=SPLIT, streaming=False)
N  = len(ds)
print("Rows:", N)

filled_col, anno_col, json_col = [None]*N, [None]*N, [None]*N
fails = 0

with ThreadPoolExecutor(MAX_WORKERS) as pool:
    fut2idx = {pool.submit(process_row, i, ds[i]): i for i in range(N)}
    for fut in tqdm(as_completed(fut2idx), total=N, desc="Florence+SAM"):
        idx = fut2idx[fut]
        status, *data = fut.result()
        if status == "ok":
            filled_col[idx], anno_col[idx], json_col[idx] = data
        else:
            fails += 1
            with FAILED_LOG.open("a") as fh:  # append: keep one JSON line per failed row
                fh.write(json.dumps({"idx": idx, "reason": data[0]}) + "\n")

print(f"❌ permanently failed rows: {fails}")

keep = [i for i,x in enumerate(filled_col) if x]
new_ds = ds.select(keep)
new_ds = new_ds.add_column("bbox_filled", [filled_col[i] for i in keep])
new_ds = new_ds.add_column("annotated",   [anno_col[i]   for i in keep])
new_ds = new_ds.add_column("bbox_json",   [json_col[i]   for i in keep])
new_ds = new_ds.cast_column("bbox_filled", HFImage())
new_ds = new_ds.cast_column("annotated",   HFImage())

print(f"βœ… successes: {len(new_ds)} / {N}")
print("Columns:", new_ds.column_names)
print("QA artefacts β†’", QA_DIR.resolve())

# optional push
new_ds.push_to_hub("fotographerai/surround_furniture_bboxfilled",
                    private=True, max_shard_size="500MB")
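
# Usage sketch (assumes utils/florence.py and utils/sam.py are importable and
# `huggingface-cli login` has been run for the final push):
#   python furniture_bbox_to_files.py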