sohamnk committed on
Commit 10deabd · verified · 1 Parent(s): 708c63e

changing segment_guided_object

Files changed (1): app.py (+41 -25)
app.py CHANGED
@@ -59,7 +59,7 @@ model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
 print("✅ DINOv2 model loaded.")
 
 print("...Loading Grounding DINO model for segmentation...")
-gnd_model_id = "IDEA-Research/grounding-dino-base"
+gnd_model_id = "IDEA-Research/grounding-dino-base"  # Kept base as you didn't specify changing this
 processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
 model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
 print("✅ Grounding DINO model loaded.")
@@ -112,46 +112,62 @@ def jaccard_similarity(set1, set2):
     return len(intersection) / len(union)
 
 def segment_guided_object(image: Image.Image, object_label: str, colors: list = []) -> Image.Image:
-    # --- UPDATED: Create a more descriptive prompt using colors ---
+    """
+    Finds and segments ALL instances of an object based on a text label and colors,
+    returning the original image with the detected objects segmented with transparency.
+    """
+    # Create a more descriptive prompt using colors, as per your new app's logic
     color_str = " ".join(c.lower() for c in colors if c)
     if color_str:
         prompt = f"a {color_str} {object_label}."
     else:
         prompt = f"a {object_label}."
 
-    print(f" [Segment] Using prompt: '{prompt}'")
+    print(f" [Segment] Using prompt: '{prompt}' for segmentation.")
     image_rgb = image.convert("RGB")
     image_np = np.array(image_rgb)
-    h, w = image_np.shape[:2]
+    height, width = image_np.shape[:2]
 
+    # Grounding DINO detection
     inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
     with torch.no_grad():
         outputs = model_gnd(**inputs)
-    results = processor_gnd.post_process_grounded_object_detection(
-        outputs, inputs.input_ids, threshold=0.5, text_threshold=0.5, target_sizes=[(h, w)]
-    )
+
+    # Process results with a threshold
+    results = processor_gnd.post_process_grounded_object_detection(
+        outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
+    )
 
     if not results or len(results[0]['boxes']) == 0:
-        print(f" [Segment] ⚠ Warning: Could not detect object with Grounding DINO. Using full image.")
-        return image_rgb
+        print(f" [Segment] ⚠ Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
+        # Return the original RGB image converted to RGBA with a full alpha channel
+        return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')
+
+    boxes = results[0]['boxes']
+    scores = results[0]['scores']
+    print(f" [Segment] ✅ Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")
 
-    print(f" [Segment] ✅ Object detected successfully.")
-    box = results[0]['boxes'][0].cpu().numpy()
+    # Set image for SAM
     sam_predictor.set_image(image_np)
-    masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
-    mask = masks[0]
-    image_rgba = np.concatenate([image_np, np.full((h, w, 1), 255, dtype=np.uint8)], axis=-1)
-    image_rgba[:, :, 3] = mask * 255
-    segmented_image = Image.fromarray(image_rgba, 'RGBA')
-
-    true_points = np.argwhere(mask)
-    if true_points.size > 0:
-        top_left = true_points.min(axis=0)
-        bottom_right = true_points.max(axis=0)
-        bbox = (top_left[1], top_left[0], bottom_right[1], bottom_right[0])
-        segmented_image = segmented_image.crop(bbox)
-
-    return segmented_image
+
+    # Initialize an empty mask to combine all detections
+    combined_mask = np.zeros((height, width), dtype=np.uint8)
+
+    # Predict masks for all detected boxes and combine them
+    for box in boxes:
+        box = box.cpu().numpy().astype(int)
+        masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
+        combined_mask = np.bitwise_or(combined_mask, masks[0])  # Combine masks
+
+    print(" [Segment] Combined masks for all detected objects.")
+
+    # Create an RGBA image where the background is transparent outside the combined mask
+    object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
+    object_rgba[:, :, :3] = image_np  # Copy original RGB
+    object_rgba[:, :, 3] = combined_mask * 255  # Apply the combined mask as alpha channel
+
+    return Image.fromarray(object_rgba, 'RGBA')
 
 def upload_to_uploadcare(image: Image.Image) -> str:
     if not UPLOADCARE_PUBLIC_KEY:
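The core of this change is the loop over boxes: instead of cropping out the first detection, the per-box SAM masks are unioned into a single alpha channel. A minimal sketch of that compositing step in isolation, with synthetic boolean masks standing in for SAM output (numpy and Pillow only; the shapes and mask regions are made up for illustration):

import numpy as np
from PIL import Image

# Synthetic stand-ins for SAM output: two boolean masks over a 4x4 image.
height, width = 4, 4
mask_a = np.zeros((height, width), dtype=bool)
mask_b = np.zeros((height, width), dtype=bool)
mask_a[0:2, 0:2] = True   # "first detected object"
mask_b[2:4, 2:4] = True   # "second detected object"

# Union the per-detection masks, as the loop over `boxes` does above.
combined_mask = np.zeros((height, width), dtype=np.uint8)
for mask in (mask_a, mask_b):
    combined_mask = np.bitwise_or(combined_mask, mask)  # bool promotes to uint8 (0/1)

# Build an RGBA image: RGB copied from the source, alpha = mask * 255,
# so everything outside the union is fully transparent.
image_np = np.full((height, width, 3), 128, dtype=np.uint8)  # dummy gray image
object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
object_rgba[:, :, :3] = image_np
object_rgba[:, :, 3] = combined_mask * 255

cutout = Image.fromarray(object_rgba, "RGBA")
print(cutout.getextrema())  # per-band (min, max); the alpha band spans (0, 255)

Because the alpha channel is the mask union, downstream consumers get one full-size RGBA cutout covering every detected instance, rather than a crop of the first box as in the previous revision.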
 
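Finally, a hypothetical end-to-end call against this revision of app.py; the image path, label, and color below are placeholders, not part of the commit:

from PIL import Image

# Hypothetical inputs: "photo.jpg", "handbag", and "red" are illustrative only.
source = Image.open("photo.jpg")
cutout = segment_guided_object(source, object_label="handbag", colors=["red"])

# The function now returns an RGBA image the same size as the input,
# transparent everywhere except the detected (possibly multiple) objects.
cutout.save("handbag_cutout.png")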