# detect_and_segment.py
import torch
import supervision as sv
from typing import List, Tuple, Optional
# ==== 1. One-time global model loading =====================================
from .utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from .utils.sam import load_sam_image_model, run_sam_inference
from PIL import Image, ImageDraw, ImageColor
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the models once; they stay in memory for repeated calls
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)
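# NOTE: loading at import time makes the first import slow, but every
# subsequent call reuses the warm models (handy for a long-lived Space).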
# quick annotators
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
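# ColorLookup.INDEX assigns palette colors by detection index rather than by
# class id, so detections from different prompts stay visually distinct.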
# ==== 2. Inference function ===============================================
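# NOTE: the autocast decorator below assumes a CUDA device; on a CPU-only
# machine consider removing it or switching device_type to "cpu".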
@torch.inference_mode()
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def detect_and_segment(
    image: Image.Image,
    text_prompts: str | List[str],
    return_image: bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
"""
Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.
Parameters
----------
image : PIL.Image
Input image in RGB.
text_prompts : str | List[str]
Single prompt or comma-separated list (e.g. "dog, tail, leash").
return_image : bool
If True, also returns an annotated PIL image.
Returns
-------
detections : sv.Detections
Supervision object with xyxy, mask, class_id, etc.
annotated : PIL.Image | None
Annotated image (None if return_image=False)
"""
    # Normalize the prompt(s) into a clean list
    if isinstance(text_prompts, str):
        prompts = [p.strip() for p in text_prompts.split(",") if p.strip()]
    else:
        prompts = [p.strip() for p in text_prompts]
    if not prompts:
        raise ValueError("Empty prompt list given.")
    # Collect detections from each prompt
    det_list: list[sv.Detections] = []
    for p in prompts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROC,
            device=DEVICE,
            image=image,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=p,
        )
        det = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image.size,
        )
        det = run_sam_inference(SAM_IMAGE_MODEL, image, det)  # SAM2 mask refinement
        det_list.append(det)
    detections = sv.Detections.merge(det_list)
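    # NOTE: merge() stacks the per-prompt detections in prompt order, so the
    # INDEX-based annotators below color them by position in that merged list.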
    annotated_img = None
    if return_image:
        annotated_img = image.copy()
        annotated_img = MASK_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = BOX_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = LABEL_ANNOTATOR.annotate(annotated_img, detections)
    return detections, annotated_img
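
# A quick usage sketch (hypothetical file names; any RGB image works):
#
#     img = Image.open("example.jpg").convert("RGB")
#     dets, annotated = detect_and_segment(img, "dog, leash")
#     print(f"{len(dets)} detection(s)")
#     if annotated is not None:
#         annotated.save("example_annotated.jpg")
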
def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
) -> Tuple[Image.Image, sv.Detections]:
"""
Detect objects matching `text`, inflate each bounding-box by `inflate_pct`,
fill the area with `fill_color`, and return the resulting image.
Parameters
----------
image : PIL.Image
Input image (RGB).
text : str
Comma-separated prompt(s) for open-vocabulary detection.
inflate_pct : float, default 0.10
Extra margin per side (0.10 = +10 % width & height).
fill_color : str | tuple, default "#00FF00"
Solid color used to fill each inflated bbox (hex or RGB tuple).
Returns
-------
filled_img : PIL.Image
Image with each detected (inflated) box filled.
detections : sv.Detections
Original detection object from `detect_and_segment`.
"""
    # Run the Florence-2 + SAM2 pipeline defined above; skip annotation work
    # since only the boxes are needed here.
    detections, _ = detect_and_segment(image, text, return_image=False)
    w, h = image.size
    filled_img = image.copy()
    draw = ImageDraw.Draw(filled_img)
    fill_rgb = ImageColor.getrgb(fill_color) if isinstance(fill_color, str) else fill_color
    for box in detections.xyxy:
        # xyxy is a numpy array; cast to float before the margin math
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        x1_i = max(0, x1 - dw)
        y1_i = max(0, y1 - dh)
        x2_i = min(w, x2 + dw)
        y2_i = min(h, y2 + dh)
        draw.rectangle([x1_i, y1_i, x2_i, y2_i], fill=fill_rgb)
    return filled_img, detections
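
if __name__ == "__main__":
    # Minimal smoke test. "street.jpg" and the "car" prompt are placeholder
    # values; point them at your own image and target object.
    src = Image.open("street.jpg").convert("RGB")
    filled, dets = fill_detected_bboxes(src, "car", inflate_pct=0.15)
    print(f"filled {len(dets)} box(es)")
    filled.save("street_filled.jpg")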