# -------------------------------------------------------------------------- #
# UNIFIED AI SERVICE V3.4 (Color-Enhanced Segmentation)
# -------------------------------------------------------------------------- #
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - The segmentation prompt now includes colors for better accuracy.
# - For debugging, segmented images are uploaded to Uploadcare.
# --------------------------------------------------------------------------
import sys
sys.stdout.reconfigure(line_buffering=True)
import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image

# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection

# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================
app = Flask(__name__)

TEXT_FIELDS_TO_EMBED = ["brand", "material", "markings"]
SCORE_WEIGHTS = {
    "text_score": 0.6,
    "image_score": 0.4,
}
FINAL_SCORE_THRESHOLD = 0.75
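
# For intuition, with illustrative numbers: a candidate with text_score = 0.8
# and image_score = 0.7 yields 0.6 * 0.8 + 0.4 * 0.7 = 0.76, which clears the
# 0.75 threshold. Text-only comparisons (when either side has no images) must
# reach 0.75 on the averaged text score alone.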

# --- Load Uploadcare Credentials from Environment Variables ---
UPLOADCARE_PUBLIC_KEY = os.getenv('UPLOADCARE_PUBLIC_KEY')
if not UPLOADCARE_PUBLIC_KEY:
    print("⚠️ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")

print("=" * 50)
print("🚀 Initializing AI Service with DINOv2...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")

print("...Loading BGE text model...")
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
print("✅ BGE model loaded.")

print("...Loading DINOv2 model...")
dinov2_model_id = "facebook/dinov2-base"
processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id)
model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
print("✅ DINOv2 model loaded.")

print("...Loading Grounding DINO model for segmentation...")
gnd_model_id = "IDEA-Research/grounding-dino-base"
processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
print("✅ Grounding DINO model loaded.")

print("...Loading SAM model...")
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
sam_predictor = SamPredictor(sam_model)
print("✅ SAM model loaded.")
print("=" * 50)

# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================
def get_text_embedding(text):
    """Embed a string (or list of strings) with BGE; returns None for empty input."""
    if isinstance(text, list):
        if not text:
            return None
        text = ", ".join(text)
    if not text or not text.strip():
        return None
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(instruction + text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    # Use the [CLS] token embedding, L2-normalised (standard BGE pooling).
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
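
# bge-small-en-v1.5 produces 384-dimensional embeddings, so the lists returned
# above (and stored alongside each item) are length-384 unit vectors.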

def get_image_embedding(image: Image.Image) -> list:
    inputs = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_dinov2(**inputs)
    # Use the CLS token as a global image descriptor, L2-normalised.
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
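
# facebook/dinov2-base yields 768-dimensional vectors. Note that text and image
# embeddings live in different spaces and are never compared to each other:
# the scoring below keeps text_score and image_score entirely separate.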

def cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        return 0.0
    vec1, vec2 = np.array(vec1), np.array(vec2)
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))

def jaccard_similarity(set1, set2):
    if not isinstance(set1, set) or not isinstance(set2, set):
        return 0.0
    intersection = set1.intersection(set2)
    union = set1.union(set2)
    if not union:
        return 1.0  # Both sets are empty, so they match trivially.
    return len(intersection) / len(union)
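
# Example: jaccard_similarity({"red", "blue"}, {"blue", "green"}) = 1/3 ≈ 0.33,
# since the sets share one colour out of three distinct colours overall.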

def segment_guided_object(image: Image.Image, object_label: str, colors: list = None) -> Image.Image:
    """
    Finds and segments ALL instances of an object based on a text label and colors,
    returning the original image with the detected objects segmented with transparency.
    This version includes a hole-filling step to create solid masks.
    """
    # Build a more descriptive prompt by prefixing the detected colors.
    colors = colors or []
    color_str = " ".join(c.lower() for c in colors if c)
    prompt = f"a {color_str} {object_label}." if color_str else f"a {object_label}."
    print(f"  [Segment] Using prompt: '{prompt}' for segmentation.")

    image_rgb = image.convert("RGB")
    image_np = np.array(image_rgb)
    height, width = image_np.shape[:2]

    # Grounding DINO detection
    inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)

    # Process results with a threshold
    results = processor_gnd.post_process_grounded_object_detection(
        outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
    )
    if not results or len(results[0]['boxes']) == 0:
        print(f"  [Segment] ⚠️ Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
        # Fall back to the full image with a fully opaque alpha channel.
        return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')

    boxes = results[0]['boxes']
    scores = results[0]['scores']
    print(f"  [Segment] ✅ Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")

    # Set image for SAM
    sam_predictor.set_image(image_np)

    # Initialize an empty mask to combine all detections
    combined_mask = np.zeros((height, width), dtype=np.uint8)

    # Predict masks for all detected boxes and combine them
    for box in boxes:
        box = box.cpu().numpy().astype(int)
        masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
        # SAM returns boolean masks; convert to 0/255 before OR-ing them together.
        combined_mask = np.bitwise_or(combined_mask, masks[0].astype(np.uint8) * 255)
    print("  [Segment] Combined masks for all detected objects.")

    # --- START: HOLE FILLING LOGIC ---
    # Fill any holes within the combined mask.
    print("  [Segment] Post-processing: Filling holes in the combined mask...")
    # Find contours. RETR_EXTERNAL retrieves only the extreme outer contours.
    contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Create a new blank mask to draw the filled contours on.
    filled_mask = np.zeros_like(combined_mask)
    if contours:
        # Draw all contours (-1) onto the new mask and fill them (cv2.FILLED).
        cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
    else:
        # If no contours were found, fall back to the original mask.
        filled_mask = combined_mask
    print("  [Segment] ✅ Hole filling complete.")
    # --- END: HOLE FILLING LOGIC ---

    # Create an RGBA image where the background is transparent
    object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    object_rgba[:, :, :3] = image_np  # Copy original RGB
    # Apply the filled mask as the alpha channel
    object_rgba[:, :, 3] = filled_mask
    return Image.fromarray(object_rgba, 'RGBA')
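
# Illustrative usage (hypothetical filename):
#   segment_guided_object(Image.open("backpack.jpg"), "backpack", ["red", "black"])
# prompts Grounding DINO with "a red black backpack.", masks every detection
# with SAM, and returns an RGBA image that is transparent outside the masks.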

def upload_to_uploadcare(image: Image.Image) -> str:
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = {'file': ('segmented_image.png', buffer, 'image/png')}
        data = {'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1'}
        response = requests.post('https://upload.uploadcare.com/base/', files=files, data=data, timeout=30)
        response.raise_for_status()
        file_uuid = response.json().get('file')
        return f"https://ucarecdn.com/{file_uuid}/"
    except Exception as e:
        return f"Uploadcare upload failed: {e}"

# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================
@app.route('/', methods=['GET'])
def health_check():
    return jsonify({"status": "Unified AI Service is running"}), 200

@app.route('/process', methods=['POST'])
def process_item():
    try:
        data = request.json
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")
        response = {
            "canonicalLabel": data.get('objectName', '').lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "markings_embedding": get_text_embedding(data.get('markings')),
        }
        image_embeddings = []
        if data.get('images'):
            print(f"  [PROCESS] Processing {len(data['images'])} image(s)...")
            for image_url in data['images']:
                try:
                    img_response = requests.get(image_url, timeout=20)
                    img_response.raise_for_status()
                    image = Image.open(BytesIO(img_response.content))
                    # Pass colors to the segmentation function for a richer prompt.
                    segmented_image = segment_guided_object(image, data['objectName'], data.get('colors', []))
                    debug_url = upload_to_uploadcare(segmented_image)
                    print(f"    - 🔗 DEBUG URL: {debug_url}")
                    embedding = get_image_embedding(segmented_image)
                    image_embeddings.append(embedding)
                except Exception as e:
                    print(f"    - ❌ Could not process image {image_url}: {e}")
                    continue
        response["image_embeddings"] = image_embeddings
        print("  [PROCESS] ✅ Successfully processed all features.")
        return jsonify(response), 200
    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
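
# An illustrative /process request body (field names taken from the code above,
# values hypothetical):
# {
#   "objectName": "backpack",
#   "brand": "Jansport",
#   "material": "nylon",
#   "markings": ["initials on strap"],
#   "colors": ["red", "black"],
#   "images": ["https://example.com/item.jpg"]
# }
# The response echoes the canonical label plus the text and image embeddings.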

@app.route('/compare', methods=['POST'])
def compare_items():
    try:
        payload = request.json
        query_item = payload['queryItem']
        search_list = payload['searchList']
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")
        results = []
        for item in search_list:
            item_id = item.get('_id')
            print(f"\n  - Comparing with item: {item_id}")
            try:
                text_score_components = []
                component_log = {}
                # 1. Score the fields with text embeddings (includes 'markings')
                for field in TEXT_FIELDS_TO_EMBED:
                    q_emb = query_item.get(f"{field}_embedding")
                    i_emb = item.get(f"{field}_embedding")
                    if q_emb and i_emb:
                        score = cosine_similarity(q_emb, i_emb)
                        text_score_components.append(score)
                        component_log[field] = f"{score:.4f}"
                # 2. Calculate Jaccard score for 'colors'
                q_colors = set(c.lower().strip() for c in query_item.get('colors', []) if c)
                i_colors = set(c.lower().strip() for c in item.get('colors', []) if c)
                if q_colors and i_colors:
                    score = jaccard_similarity(q_colors, i_colors)
                    text_score_components.append(score)
                    component_log['colors'] = f"{score:.4f}"
                # 3. Calculate direct match score for 'size'
                q_size = (query_item.get('size') or "").lower().strip()
                i_size = (item.get('size') or "").lower().strip()
                if q_size and i_size:
                    score = 1.0 if q_size == i_size else 0.0
                    text_score_components.append(score)
                    component_log['size'] = f"{score:.4f}"
                # 4. Average only the scores from the available components
                text_score = 0.0
                if text_score_components:
                    text_score = sum(text_score_components) / len(text_score_components)
                print(f"    - Text Score Components: {component_log}")
                print(f"    - Final Avg Text Score: {text_score:.4f} (from {len(text_score_components)} components)")
                # 5. Calculate Image Score: best pairwise match across all images
                image_score = 0.0
                query_img_embs = query_item.get('image_embeddings', [])
                item_img_embs = item.get('image_embeddings', [])
                if query_img_embs and item_img_embs:
                    all_img_scores = [
                        cosine_similarity(q_emb, i_emb)
                        for q_emb in query_img_embs
                        for i_emb in item_img_embs
                    ]
                    if all_img_scores:
                        image_score = max(all_img_scores)
                print(f"    - Max Image Score: {image_score:.4f}")
                # 6. Calculate Final Score (hybrid when both sides have images)
                if query_img_embs and item_img_embs:
                    print("    - Calculating Hybrid Score (Text + Image)...")
                    final_score = SCORE_WEIGHTS['text_score'] * text_score + SCORE_WEIGHTS['image_score'] * image_score
                else:
                    print("    - One or both items missing images. Using Text Score only...")
                    final_score = text_score
                print(f"    - Final Dynamic Score: {final_score:.4f}")
                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f"    - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({"_id": str(item_id), "score": round(final_score, 4)})
                else:
                    print(f"    - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
            except Exception as e:
                print(f"    - ⚠️ Skipping item {item_id} due to scoring error: {e}")
                continue
        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\n[COMPARE] ✅ Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200
    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
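
# An illustrative /compare payload (shapes only; embeddings and IDs elided):
# {
#   "queryItem":  {"objectName": "backpack", "brand_embedding": [...],
#                  "colors": ["red"], "size": "medium", "image_embeddings": [[...]]},
#   "searchList": [{"_id": "<candidate-id>", "brand_embedding": [...],
#                   "colors": ["red"], "size": "medium", "image_embeddings": [[...]]}]
# }
# The response is {"matches": [{"_id": "...", "score": 0.81}, ...]}, sorted by
# score in descending order.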

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)