# -------------------------------------------------------------------------- #
# UNIFIED AI SERVICE V3.4 (Color-Enhanced Segmentation)
# -------------------------------------------------------------------------- #
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - The segmentation prompt now includes colors for better accuracy.
# - For debugging, segmented images are uploaded to Uploadcare.
# --------------------------------------------------------------------------
import sys
sys.stdout.reconfigure(line_buffering=True)
import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image
from datetime import datetime, timedelta
# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection
# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================
app = Flask(__name__)
TEXT_FIELDS_TO_EMBED = ["brand", "material", "markings"]
SCORE_WEIGHTS = {
    "text_score": 0.6,
    "image_score": 0.4
}
FINAL_SCORE_THRESHOLD = 0.75
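# Worked example (illustrative numbers only): with text_score = 0.80 and image_score = 0.70,
# the hybrid score is 0.6 * 0.80 + 0.4 * 0.70 = 0.76, which clears the 0.75 threshold.
# With image_score = 0.60 the result would be 0.72 and the candidate would be rejected.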
# --- Load Uploadcare Credentials from Environment Variables ---
UPLOADCARE_PUBLIC_KEY = os.getenv('UPLOADCARE_PUBLIC_KEY')
if not UPLOADCARE_PUBLIC_KEY:
    print("⚠️ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")
print("="*50)
print("π Initializing AI Service with DINOv2...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"π§ Using device: {device}")
print("...Loading BGE text model...")
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
print("β
BGE model loaded.")
print("...Loading DINOv2 model...")
dinov2_model_id = "facebook/dinov2-base"
processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id)
model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
print("β
DINOv2 model loaded.")
print("...Loading Grounding DINO model for segmentation...")
gnd_model_id = "IDEA-Research/grounding-dino-base"
processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
print("β
Grounding DINO model loaded.")
print("...Loading SAM model...")
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
sam_predictor = SamPredictor(sam_model)
print("β
SAM model loaded.")
print("="*50)
# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================
def get_text_embedding(text: str) -> list:
    if isinstance(text, list):
        if not text: return None
        text = ", ".join(text)
    if not text or not text.strip():
        return None
    # BGE v1.5 models expect this query instruction prefix for retrieval-style embeddings.
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(instruction + text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
def get_image_embedding(image: Image.Image) -> list:
    inputs = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_dinov2(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
def cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None: return 0.0
    vec1, vec2 = np.array(vec1), np.array(vec2)
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))
def jaccard_similarity(set1, set2):
    if not isinstance(set1, set) or not isinstance(set2, set):
        return 0.0
    intersection = set1.intersection(set2)
    union = set1.union(set2)
    if not union:
        return 1.0 if not intersection else 0.0
    return len(intersection) / len(union)
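# For example, jaccard_similarity({"red", "blue"}, {"blue", "green"}) returns 1/3:
# the intersection has one element and the union has three.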
def segment_guided_object(image: Image.Image, object_label: str, colors: list = []) -> Image.Image:
"""
Finds and segments ALL instances of an object based on a text label and colors,
returning the original image with the detected objects segmented with transparency.
This version includes a hole-filling step to create solid masks.
"""
# Create a more descriptive prompt using colors, as per your new app's logic
color_str = " ".join(c.lower() for c in colors if c)
if color_str:
prompt = f"a {color_str} {object_label}."
else:
prompt = f"a {object_label}."
print(f" [Segment] Using prompt: '{prompt}' for segmentation.")
image_rgb = image.convert("RGB")
image_np = np.array(image_rgb)
height, width = image_np.shape[:2]
# Grounding DINO detection
inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model_gnd(**inputs)
# Process results with a threshold
results = processor_gnd.post_process_grounded_object_detection(
outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
)
if not results or len(results[0]['boxes']) == 0:
print(f" [Segment] β Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')
boxes = results[0]['boxes']
scores = results[0]['scores']
print(f" [Segment] β
Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")
# Set image for SAM
sam_predictor.set_image(image_np)
# Initialize an empty mask to combine all detections
combined_mask = np.zeros((height, width), dtype=np.uint8)
# Predict masks for all detected boxes and combine them
for box in boxes:
box = box.cpu().numpy().astype(int)
masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
combined_mask = np.bitwise_or(combined_mask, masks[0]) # Combine masks
print(" [Segment] Combined masks for all detected objects.")
# --- START: HOLE FILLING LOGIC ---
# This new block will fill any holes within the combined mask.
print(" [Segment] Post-processing: Filling holes in the combined mask...")
# Find contours. RETR_EXTERNAL retrieves only the extreme outer contours.
contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
# Create a new blank mask to draw the filled contours on.
filled_mask = np.zeros_like(combined_mask)
if contours:
# Draw the detected contours onto the new mask and fill them.
# The -1 index means draw all contours, and cv2.FILLED fills them.
cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
else:
# If for some reason no contours were found, fall back to the original mask.
filled_mask = combined_mask
print(" [Segment] β
Hole filling complete.")
# --- END: HOLE FILLING LOGIC ---
# Create an RGBA image where the background is transparent
object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
object_rgba[:, :, :3] = image_np # Copy original RGB
# Apply the NEW filled mask as the alpha channel
object_rgba[:, :, 3] = filled_mask
return Image.fromarray(object_rgba, 'RGBA')
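# Minimal usage sketch (illustrative only; "example.jpg" is a hypothetical local file,
# not part of the service flow):
#
#   img = Image.open("example.jpg")
#   cutout = segment_guided_object(img, "backpack", colors=["red"])
#   cutout.save("backpack_cutout.png")  # RGBA PNG with a transparent background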
def upload_to_uploadcare(image: Image.Image) -> str:
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = { 'file': ('segmented_image.png', buffer, 'image/png') }
        data = { 'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1' }
        response = requests.post('https://upload.uploadcare.com/base/', files=files, data=data)
        response.raise_for_status()
        file_uuid = response.json().get('file')
        return f"https://ucarecdn.com/{file_uuid}/"
    except Exception as e:
        return f"Uploadcare upload failed: {e}"
# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================
@app.route('/', methods=['GET'])
def health_check():
return jsonify({"status": "Unified AI Service is running"}), 200
@app.route('/process', methods=['POST'])
def process_item():
    try:
        data = request.json
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")
        response = {
            "canonicalLabel": data.get('objectName', '').lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "markings_embedding": get_text_embedding(data.get('markings')),
        }
        image_embeddings = []
        if data.get('images'):
            print(f" [PROCESS] Processing {len(data['images'])} image(s)...")
            for image_url in data['images']:
                try:
                    img_response = requests.get(image_url, timeout=20)
                    img_response.raise_for_status()
                    image = Image.open(BytesIO(img_response.content))
                    # Pass colors to the segmentation function for a more specific prompt.
                    segmented_image = segment_guided_object(image, data['objectName'], data.get('colors', []))
                    debug_url = upload_to_uploadcare(segmented_image)
                    print(f" - 🔗 DEBUG URL: {debug_url}")
                    embedding = get_image_embedding(segmented_image)
                    image_embeddings.append(embedding)
                except Exception as e:
                    print(f" - ❌ Could not process image {image_url}: {e}")
                    continue
        response["image_embeddings"] = image_embeddings
        print(" [PROCESS] ✅ Successfully processed all features.")
        return jsonify(response), 200
    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/compare', methods=['POST'])
def compare_items():
    try:
        payload = request.json
        query_item = payload['queryItem']
        search_list = payload['searchList']
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")
        results = []
        for item in search_list:
            item_id = item.get('_id')
            print(f"\n - Comparing with item: {item_id}")
            try:
                text_score_components = []
                component_log = {}
                # 1. Calculate score for fields with text embeddings (now includes 'markings')
                for field in TEXT_FIELDS_TO_EMBED:
                    q_emb = query_item.get(f"{field}_embedding")
                    i_emb = item.get(f"{field}_embedding")
                    if q_emb and i_emb:
                        score = cosine_similarity(q_emb, i_emb)
                        text_score_components.append(score)
                        component_log[field] = f"{score:.4f}"
                # 2. Calculate Jaccard score for 'colors'
                q_colors = set(c.lower().strip() for c in query_item.get('colors', []) if c)
                i_colors = set(c.lower().strip() for c in item.get('colors', []) if c)
                if q_colors and i_colors:
                    score = jaccard_similarity(q_colors, i_colors)
                    text_score_components.append(score)
                    component_log['colors'] = f"{score:.4f}"
                # 3. Calculate direct match score for 'size'
                q_size = (query_item.get('size') or "").lower().strip()
                i_size = (item.get('size') or "").lower().strip()
                if q_size and i_size:
                    score = 1.0 if q_size == i_size else 0.0
                    text_score_components.append(score)
                    component_log['size'] = f"{score:.4f}"
                # 4. Average only the scores from the available components
                text_score = 0.0
                if text_score_components:
                    text_score = sum(text_score_components) / len(text_score_components)
                print(f" - Text Score Components: {component_log}")
                print(f" - Final Avg Text Score: {text_score:.4f} (from {len(text_score_components)} components)")
                # 5. Calculate Image Score
                image_score = 0.0
                query_img_embs = query_item.get('image_embeddings', [])
                item_img_embs = item.get('image_embeddings', [])
                if query_img_embs and item_img_embs:
                    all_img_scores = []
                    for q_emb in query_img_embs:
                        for i_emb in item_img_embs:
                            all_img_scores.append(cosine_similarity(q_emb, i_emb))
                    if all_img_scores:
                        image_score = max(all_img_scores)
                        print(f" - Max Image Score: {image_score:.4f}")
                # 6. Calculate Final Score (Dynamic)
                final_score = 0.0
                if query_img_embs and item_img_embs:
                    print(" - Calculating Hybrid Score (Text + Image)...")
                    final_score = (SCORE_WEIGHTS['text_score'] * text_score + SCORE_WEIGHTS['image_score'] * image_score)
                else:
                    print(" - One or both items missing images. Using Text Score only...")
                    final_score = text_score
                print(f" - Final Dynamic Score: {final_score:.4f}")
                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({ "_id": str(item_id), "score": round(final_score, 4) })
                else:
                    print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
            except Exception as e:
                print(f" - ❌ Skipping item {item_id} due to scoring error: {e}")
                continue
        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\n[COMPARE] ✅ Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200
    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
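# -------------------------------------------------------------------------- #
# Example requests (illustrative payloads only; field names match the handlers
# above, the values are made up):
#
#   curl -X POST http://localhost:7860/process \
#        -H "Content-Type: application/json" \
#        -d '{"objectName": "water bottle", "brand": "Acme", "material": "steel",
#             "markings": "engraved initials", "colors": ["blue"],
#             "images": ["https://example.com/bottle.jpg"]}'
#
#   curl -X POST http://localhost:7860/compare \
#        -H "Content-Type: application/json" \
#        -d '{"queryItem": { ...output of /process plus colors/size... },
#             "searchList": [ { "_id": "abc123", ...stored item features... } ]}'
# -------------------------------------------------------------------------- #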