sohamnk committed on
Commit
ccd38a7
·
verified ·
1 Parent(s): fdc1498
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -115,6 +115,7 @@ def segment_guided_object(image: Image.Image, object_label: str, colors: list =
115
  """
116
  Finds and segments ALL instances of an object based on a text label and colors,
117
  returning the original image with the detected objects segmented with transparency.
 
118
  """
119
  # Create a more descriptive prompt using colors, as per your new app's logic
120
  color_str = " ".join(c.lower() for c in colors if c)
@@ -122,7 +123,7 @@ def segment_guided_object(image: Image.Image, object_label: str, colors: list =
122
  prompt = f"a {color_str} {object_label}."
123
  else:
124
  prompt = f"a {object_label}."
125
-
126
  print(f" [Segment] Using prompt: '{prompt}' for segmentation.")
127
  image_rgb = image.convert("RGB")
128
  image_np = np.array(image_rgb)
@@ -132,7 +133,7 @@ def segment_guided_object(image: Image.Image, object_label: str, colors: list =
132
  inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
133
  with torch.no_grad():
134
  outputs = model_gnd(**inputs)
135
-
136
  # Process results with a threshold
137
  results = processor_gnd.post_process_grounded_object_detection(
138
  outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
@@ -140,16 +141,15 @@ def segment_guided_object(image: Image.Image, object_label: str, colors: list =
140
 
141
  if not results or len(results[0]['boxes']) == 0:
142
  print(f" [Segment] ⚠ Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
143
- # Return the original RGB image converted to RGBA with a full alpha channel
144
  return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')
145
-
146
  boxes = results[0]['boxes']
147
  scores = results[0]['scores']
148
  print(f" [Segment] ✅ Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")
149
 
150
  # Set image for SAM
151
  sam_predictor.set_image(image_np)
152
-
153
  # Initialize an empty mask to combine all detections
154
  combined_mask = np.zeros((height, width), dtype=np.uint8)
155
 
@@ -158,14 +158,36 @@ def segment_guided_object(image: Image.Image, object_label: str, colors: list =
158
  box = box.cpu().numpy().astype(int)
159
  masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
160
  combined_mask = np.bitwise_or(combined_mask, masks[0]) # Combine masks
161
-
162
  print(" [Segment] Combined masks for all detected objects.")
 
 
 
 
 
 
 
163
 
164
- # Create an RGBA image where the background is transparent outside the combined mask
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
166
  object_rgba[:, :, :3] = image_np # Copy original RGB
167
- object_rgba[:, :, 3] = combined_mask * 255 # Apply the combined mask as alpha channel
168
 
 
 
 
169
  return Image.fromarray(object_rgba, 'RGBA')
170
 
171
 
 
115
  """
116
  Finds and segments ALL instances of an object based on a text label and colors,
117
  returning the original image with the detected objects segmented with transparency.
118
+ This version includes a hole-filling step to create solid masks.
119
  """
120
  # Create a more descriptive prompt using colors, as per your new app's logic
121
  color_str = " ".join(c.lower() for c in colors if c)
 
123
  prompt = f"a {color_str} {object_label}."
124
  else:
125
  prompt = f"a {object_label}."
126
+
127
  print(f" [Segment] Using prompt: '{prompt}' for segmentation.")
128
  image_rgb = image.convert("RGB")
129
  image_np = np.array(image_rgb)
 
133
  inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
134
  with torch.no_grad():
135
  outputs = model_gnd(**inputs)
136
+
137
  # Process results with a threshold
138
  results = processor_gnd.post_process_grounded_object_detection(
139
  outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
 
141
 
142
  if not results or len(results[0]['boxes']) == 0:
143
  print(f" [Segment] ⚠ Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
 
144
  return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')
145
+
146
  boxes = results[0]['boxes']
147
  scores = results[0]['scores']
148
  print(f" [Segment] ✅ Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")
149
 
150
  # Set image for SAM
151
  sam_predictor.set_image(image_np)
152
+
153
  # Initialize an empty mask to combine all detections
154
  combined_mask = np.zeros((height, width), dtype=np.uint8)
155
 
 
158
  box = box.cpu().numpy().astype(int)
159
  masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
160
  combined_mask = np.bitwise_or(combined_mask, masks[0]) # Combine masks
161
+
162
  print(" [Segment] Combined masks for all detected objects.")
163
+
164
+ # --- START: HOLE FILLING LOGIC ---
165
+ # This new block will fill any holes within the combined mask.
166
+ print(" [Segment] Post-processing: Filling holes in the combined mask...")
167
+
168
+ # Find contours. RETR_EXTERNAL retrieves only the extreme outer contours.
169
+ contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
170
 
171
+ # Create a new blank mask to draw the filled contours on.
172
+ filled_mask = np.zeros_like(combined_mask)
173
+
174
+ if contours:
175
+ # Draw the detected contours onto the new mask and fill them.
176
+ # The -1 index means draw all contours, and cv2.FILLED fills them.
177
+ cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
178
+ else:
179
+ # If for some reason no contours were found, fall back to the original mask.
180
+ filled_mask = combined_mask
181
+ print(" [Segment] ✅ Hole filling complete.")
182
+ # --- END: HOLE FILLING LOGIC ---
183
+
184
+ # Create an RGBA image where the background is transparent
185
  object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
186
  object_rgba[:, :, :3] = image_np # Copy original RGB
 
187
 
188
+ # Apply the NEW filled mask as the alpha channel
189
+ object_rgba[:, :, 3] = filled_mask
190
+
191
  return Image.fromarray(object_rgba, 'RGBA')
192
 
193