# detect_and_segment.py
import torch
import supervision as sv
from typing import List, Tuple, Optional

# ==== 1. One-time global model loading =====================================
from .utils.florence import (
    load_florence_model,
    run_florence_inference,
    FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
)
from .utils.sam import load_sam_image_model, run_sam_inference
from PIL import Image, ImageDraw, ImageColor

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load models once -- they stay in memory for repeated calls.
FLORENCE_MODEL, FLORENCE_PROC = load_florence_model(device=DEVICE)
SAM_IMAGE_MODEL = load_sam_image_model(device=DEVICE)

# Quick annotators
COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
LABEL_ANNOTATOR = sv.LabelAnnotator(
    color=COLOR_PALETTE,
    color_lookup=sv.ColorLookup.INDEX,
    text_position=sv.Position.CENTER_OF_MASS,
    text_color=sv.Color.from_hex("#000000"),
    border_radius=5,
)
MASK_ANNOTATOR = sv.MaskAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)


# ==== 2. Inference function ================================================
@torch.inference_mode()
# NOTE: autocast targets CUDA; on CPU-only machines PyTorch disables it with a warning.
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
def detect_and_segment(
    image: Image.Image,
    text_prompts: str | List[str],
    return_image: bool = True,
) -> Tuple[sv.Detections, Optional[Image.Image]]:
    """
    Run Florence-2 open-vocabulary detection + SAM2 mask refinement on a PIL image.

    Parameters
    ----------
    image : PIL.Image
        Input image in RGB.
    text_prompts : str | List[str]
        Single prompt or comma-separated list (e.g. "dog, tail, leash").
    return_image : bool
        If True, also return an annotated PIL image.

    Returns
    -------
    detections : sv.Detections
        Supervision object with xyxy, mask, class_id, etc.
    annotated : PIL.Image | None
        Annotated image (None if return_image=False).
    """
    # Normalize the prompt(s) into a clean list of strings.
    if isinstance(text_prompts, str):
        prompts = [p.strip() for p in text_prompts.split(",") if p.strip()]
    else:
        prompts = [p.strip() for p in text_prompts]
    if len(prompts) == 0:
        raise ValueError("Empty prompt list given.")

    # Collect detections from each prompt.
    det_list: list[sv.Detections] = []
    for p in prompts:
        _, result = run_florence_inference(
            model=FLORENCE_MODEL,
            processor=FLORENCE_PROC,
            device=DEVICE,
            image=image,
            task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
            text=p,
        )
        det = sv.Detections.from_lmm(
            lmm=sv.LMM.FLORENCE_2,
            result=result,
            resolution_wh=image.size,
        )
        det = run_sam_inference(SAM_IMAGE_MODEL, image, det)  # SAM2 mask refinement
        det_list.append(det)

    detections = sv.Detections.merge(det_list)

    annotated_img = None
    if return_image:
        annotated_img = image.copy()
        annotated_img = MASK_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = BOX_ANNOTATOR.annotate(annotated_img, detections)
        annotated_img = LABEL_ANNOTATOR.annotate(annotated_img, detections)

    return detections, annotated_img
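
# ---------------------------------------------------------------------------
# A minimal sketch of consuming the returned sv.Detections: crop each detected
# box out of the source image with plain PIL. `crop_detections` is a
# hypothetical helper added for illustration -- it is not part of the original
# pipeline and only assumes detect_and_segment() above.
def crop_detections(image: Image.Image, text: str) -> List[Image.Image]:
    """Return one PIL crop per detected bounding box for `text`."""
    detections, _ = detect_and_segment(image, text, return_image=False)
    crops: List[Image.Image] = []
    for x1, y1, x2, y2 in detections.xyxy:
        # PIL's crop takes an integer (left, upper, right, lower) box.
        crops.append(image.crop((int(x1), int(y1), int(x2), int(y2))))
    return crops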

# ==== 3. Bbox fill helper ==================================================
def fill_detected_bboxes(
    image: Image.Image,
    text: str,
    inflate_pct: float = 0.10,
    fill_color: str | tuple[int, int, int] = "#00FF00",
) -> Tuple[Image.Image, sv.Detections]:
    """
    Detect objects matching `text`, inflate each bounding box by `inflate_pct`,
    fill the area with `fill_color`, and return the resulting image.

    Parameters
    ----------
    image : PIL.Image
        Input image (RGB).
    text : str
        Comma-separated prompt(s) for open-vocabulary detection.
    inflate_pct : float, default 0.10
        Extra margin per side (0.10 = +10 % width & height).
    fill_color : str | tuple, default "#00FF00"
        Solid color used to fill each inflated bbox (hex or RGB tuple).

    Returns
    -------
    filled_img : PIL.Image
        Image with each detected (inflated) box filled.
    detections : sv.Detections
        Original detection object from `detect_and_segment`.
    """
    # Run the Florence-2 + SAM2 pipeline defined above; skip annotation work.
    detections, _ = detect_and_segment(image, text, return_image=False)

    w, h = image.size
    filled_img = image.copy()
    draw = ImageDraw.Draw(filled_img)
    fill_rgb = ImageColor.getrgb(fill_color) if isinstance(fill_color, str) else fill_color

    for box in detections.xyxy:
        # xyxy is a numpy array -- cast to float for the margin math.
        x1, y1, x2, y2 = box.astype(float)
        dw, dh = (x2 - x1) * inflate_pct, (y2 - y1) * inflate_pct
        # Inflate each side, clamped to the image bounds.
        x1_i = max(0.0, x1 - dw)
        y1_i = max(0.0, y1 - dh)
        x2_i = min(w, x2 + dw)
        y2_i = min(h, y2 + dh)
        draw.rectangle([x1_i, y1_i, x2_i, y2_i], fill=fill_rgb)

    return filled_img, detections
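
# ---------------------------------------------------------------------------
# Usage sketch (runs only when the module is executed directly). The file
# paths and prompts below are placeholder assumptions for illustration.
if __name__ == "__main__":
    img = Image.open("example.jpg").convert("RGB")  # hypothetical input path

    detections, annotated = detect_and_segment(img, "dog, leash")
    print(f"{len(detections)} detections")
    if annotated is not None:
        annotated.save("annotated.jpg")

    filled, _ = fill_detected_bboxes(img, "dog", inflate_pct=0.15)
    filled.save("filled.jpg")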