# detect_and_segment.py
import torch
import supervision as sv
from typing import List, Tuple, Optional

# ==== 1. One-time global model loading  =====================================
from .utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
)
from .utils.sam import load_sam_image_model, run_sam_inference

from PIL import Image, ImageDraw, ImageColor
import numpy as np

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load models once – they stay in memory for repeated calls
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL               = load_sam_image_model(device=DEVICE)

# quick annotators
COLORS          = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE   = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR   = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
)
MASK_ANNOTATOR  = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)

# ==== 2. Inference function  ===============================================

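# NOTE: the autocast decorator below is hard-wired to CUDA + bfloat16. On a
# CPU-only machine recent PyTorch builds typically just warn and disable it,
# but drop or adapt the decorator if you target CPU explicitly.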
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def detect_and_segment(
    image          : Image.Image,
    text_prompts   : str | List[str],
    return_image   : bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
    """
    Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.

    Parameters
    ----------
    image : PIL.Image
        Input image in RGB.
    text_prompts : str | List[str]
        Either a single comma-separated string (e.g. "dog, tail, leash")
        or a list of individual prompts.
    return_image : bool
        If True, also returns an annotated PIL image.

    Returns
    -------
    detections : sv.Detections
        Supervision object with xyxy, mask, class_id, etc.
    annotated  : PIL.Image | None
        Annotated image (None if return_image=False)
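
    Examples
    --------
    A minimal call sketch; the image path below is just a placeholder.

    >>> img = Image.open("dog.jpg").convert("RGB")
    >>> dets, annotated = detect_and_segment(img, "dog, leash")
    >>> annotated.save("dog_annotated.jpg")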
    """
    # Normalize prompt list
    if isinstance(text_prompts, str):
        prompts = [p.strip() for p in text_prompts.split(",") if p.strip()]
    else:
        prompts = [p.strip() for p in text_prompts]

    if len(prompts) == 0:
        raise ValueError("Empty prompt list given.")

    # Collect detections from each prompt
    det_list: list[sv.Detections] = []
    for p in prompts:
        _, result = run_florence_inference(
            model      = FLORENCE_MODEL,
            processor  = FLORENCE_PROC,
            device     = DEVICE,
            image      = image,
            task       = FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text       = p,
        )
        det = sv.Detections.from_lmm(
            lmm           = sv.LMM.FLORENCE_2,
            result        = result,
            resolution_wh = image.size,
        )
        det = run_sam_inference(SAM_IMAGE_MODEL, image, det)  # SAM2 refinement
        det_list.append(det)

    detections = sv.Detections.merge(det_list)

    annotated_img = None
    if return_image:
        annotated_img = image.copy()
        annotated_img = MASK_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = BOX_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = LABEL_ANNOTATOR.annotate(annotated_img, detections)

    return detections, annotated_img



def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
):
    """
    Detect objects matching `text`, inflate each bounding-box by `inflate_pct`,
    fill the area with `fill_color`, and return the resulting image.

    Parameters
    ----------
    image        : PIL.Image
        Input image (RGB).
    text         : str
        Comma-separated prompt(s) for open-vocabulary detection.
    inflate_pct  : float, default 0.10
        Extra margin added on each side, as a fraction of the box size
        (0.10 adds 10 % of the width on the left/right and 10 % of the
        height on the top/bottom).
    fill_color   : str | tuple, default "#00FF00"
        Solid color used to fill each inflated bbox (hex or RGB tuple).

    Returns
    -------
    filled_img   : PIL.Image
        Image with each detected (inflated) box filled.
    detections   : sv.Detections
        Original detection object from `detect_and_segment`.
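
    Examples
    --------
    Sketch only; the file names are placeholders.

    >>> img = Image.open("street.jpg").convert("RGB")
    >>> filled, dets = fill_detected_bboxes(img, "car", inflate_pct=0.2)
    >>> filled.save("street_filled.jpg")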
    """
    # run the Florence-2 + SAM2 pipeline defined above (detect_and_segment)
    detections, _ = detect_and_segment(image, text)

    w, h = image.size
    filled_img = image.copy()
    draw = ImageDraw.Draw(filled_img)
    fill_rgb = ImageColor.getrgb(fill_color) if isinstance(fill_color, str) else fill_color

    for box in detections.xyxy:
        # xyxy is a numpy array → cast to float for the math below
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        x1_i = max(0, x1 - dw)
        y1_i = max(0, y1 - dh)
        x2_i = min(w, x2 + dw)
        y2_i = min(h, y2 + dh)
        draw.rectangle([x1_i, y1_i, x2_i, y2_i], fill=fill_rgb)

    return filled_img, detections
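

if __name__ == "__main__":
    # Smoke test / usage sketch. The paths and prompts are placeholders; point
    # them at real files before running. Because of the relative imports above,
    # run this as a module, e.g.  python -m <your_package>.detect_and_segment
    src = Image.open("example.jpg").convert("RGB")

    dets, annotated = detect_and_segment(src, "dog, leash")
    print(f"detected {len(dets)} objects")
    if annotated is not None:
        annotated.save("example_annotated.jpg")

    filled, _ = fill_detected_bboxes(src, "dog", inflate_pct=0.15)
    filled.save("example_filled.jpg")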