Spaces:
Running
Running
# detect_and_segment.py | |
import torch | |
import supervision as sv | |
from typing import List, Tuple, Optional | |
# ==== 1. One-time global model loading ===================================== | |
from .utils.florence import ( | |
load_florence_model, | |
run_florence_inference, | |
FLORENCE_OPEN_VOCABULARY_DETECTION_TASK | |
) | |
from .utils.sam import load_sam_image_model, run_sam_inference | |
from PIL import Image, ImageDraw, ImageColor | |
import numpy as np | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load both models once, at import time -- they stay in memory and are reused by
# every call to the functions below instead of being reloaded per request.
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Shared annotators. All three use the same palette with ColorLookup.INDEX, so a
# detection's mask, box and label are drawn with the same palette entry.
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),  # black text on the coloured label box
    border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
# ==== 2. Inference function =============================================== | |
def _normalize_prompts(text_prompts: str | List[str]) -> List[str]:
    """Return the non-empty, whitespace-stripped phrases in `text_prompts`.

    A string is split on commas; a list is taken item by item (its items are
    not split further). Empty or whitespace-only phrases are dropped in both
    cases, so they are never sent to the detector.
    """
    pieces = text_prompts.split(",") if isinstance(text_prompts, str) else text_prompts
    return [phrase.strip() for phrase in pieces if phrase.strip()]


def detect_and_segment(
    image: Image.Image,
    text_prompts: str | List[str],
    return_image: bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
    """
    Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.

    Parameters
    ----------
    image : PIL.Image
        Input image. Images in a mode other than RGB (RGBA, L, P, ...) are
        converted to RGB before inference.
    text_prompts : str | List[str]
        Either a single comma-separated string (e.g. "dog, tail, leash") or a
        list of phrases. Phrases are whitespace-stripped; empty ones are ignored.
    return_image : bool
        If True, also return an annotated copy of the (RGB) image.

    Returns
    -------
    detections : sv.Detections
        Detections for all prompts merged into a single object (xyxy boxes,
        plus masks from the SAM2 refinement step when anything was detected).
    annotated : PIL.Image | None
        Annotated image with masks, boxes and labels (None if return_image=False).

    Raises
    ------
    ValueError
        If no non-empty prompt remains after normalisation.
    """
    prompts = _normalize_prompts(text_prompts)
    if not prompts:
        raise ValueError("Empty prompt list given.")

    # The contract is RGB input; convert other modes so both models and the
    # annotators receive 3-channel data.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # The open-vocabulary detection task is given one phrase per call, so
    # detection runs once per prompt and the results are merged afterwards.
    per_prompt: list[sv.Detections] = []
    for prompt in prompts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROC,
            device=DEVICE,
            image=image,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=prompt,
        )
        per_prompt.append(
            sv.Detections.from_lmm(
                lmm=sv.LMM.FLORENCE_2,
                result=result,
                resolution_wh=image.size,
            )
        )
    detections = sv.Detections.merge(per_prompt)

    # Refine all boxes with a single SAM2 call instead of one call per prompt:
    # the image is identical for every prompt, so per-prompt calls presumably
    # re-encoded it each time (NOTE(review): confirm against utils.sam).
    # Skipped when nothing was detected so SAM2 never receives an empty box list.
    if len(detections) > 0:
        detections = run_sam_inference(SAM_IMAGE_MODEL, image, detections)

    if not return_image:
        return detections, None

    # Draw masks first, then boxes, then labels so labels stay on top.
    annotated_img = image.copy()
    for annotator in (MASK_ANNOTATOR, BOX_ANNOTATOR, LABEL_ANNOTATOR):
        annotated_img = annotator.annotate(annotated_img, detections)
    return detections, annotated_img
def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
) -> tuple[Image.Image, sv.Detections]:
    """
    Detect objects matching `text`, grow each bounding box, and paint it solid.

    Parameters
    ----------
    image : PIL.Image
        Input image. It is not modified; a copy is painted.
    text : str
        Comma-separated prompt(s) for open-vocabulary detection.
    inflate_pct : float, default 0.10
        Margin added to EACH side of a box, as a fraction of that box's width
        (left/right) or height (top/bottom). The default 0.10 therefore grows
        every box by 20 % in each dimension. Inflated boxes are clipped to the
        image bounds.
    fill_color : str | tuple, default "#00FF00"
        Solid colour used to fill each inflated box. A string (any colour spec
        PIL understands, e.g. hex) is converted to the image's mode; a tuple is
        passed to PIL unchanged.

    Returns
    -------
    filled_img : PIL.Image
        Copy of `image` with every detected (inflated) box filled.
    detections : sv.Detections
        Detection object returned by `detect_and_segment`.
    """
    # Only the boxes are needed, so skip rendering the annotated preview image.
    detections, _ = detect_and_segment(image, text, return_image=False)

    width, height = image.size
    filled_img = image.copy()
    draw = ImageDraw.Draw(filled_img)

    # getcolor (unlike getrgb) yields a value valid for the image's mode, e.g. a
    # single grey level for "L" images, where an RGB tuple would be rejected.
    fill_value = (
        ImageColor.getcolor(fill_color, filled_img.mode)
        if isinstance(fill_color, str)
        else fill_color
    )

    # xyxy holds one (x1, y1, x2, y2) row per detection; cast to float for the math.
    for x1, y1, x2, y2 in detections.xyxy.astype(float):
        pad_x = (x2 - x1) * inflate_pct
        pad_y = (y2 - y1) * inflate_pct
        draw.rectangle(
            [
                max(0, x1 - pad_x),
                max(0, y1 - pad_y),
                min(width, x2 + pad_x),
                min(height, y2 + pad_y),
            ],
            fill=fill_value,
        )
    return filled_img, detections