# -------------------------------------------------------------------------- #
# UNIFIED AI SERVICE V3.3 (Added Markings Comparison)
# -------------------------------------------------------------------------- #
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - Filtering is handled by the Node.js backend.
# - For debugging, segmented images are uploaded to Uploadcare and the URL
#   is printed to the console log.
# --------------------------------------------------------------------------
import sys
sys.stdout.reconfigure(line_buffering=True)

import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image
from datetime import datetime, timedelta

# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection

# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================
app = Flask(__name__)

# --- UPDATED: Added "markings" to the list of fields to compare ---
TEXT_FIELDS_TO_EMBED = ["brand", "material", "markings"]
SCORE_WEIGHTS = {
    "text_score": 0.4,
    "image_score": 0.6
}
FINAL_SCORE_THRESHOLD = 0.5

# --- Load Uploadcare Credentials from Environment Variables ---
UPLOADCARE_PUBLIC_KEY = os.getenv('UPLOADCARE_PUBLIC_KEY')
if not UPLOADCARE_PUBLIC_KEY:
    print("⚠ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")

print("="*50)
print("🚀 Initializing AI Service with DINOv2...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")

print("...Loading BGE text model...")
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
print("✅ BGE model loaded.")

print("...Loading DINOv2 model...")
dinov2_model_id = "facebook/dinov2-base"
processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id)
model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
print("✅ DINOv2 model loaded.")

print("...Loading Grounding DINO model for segmentation...")
gnd_model_id = "IDEA-Research/grounding-dino-base"
processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
print("✅ Grounding DINO model loaded.")

print("...Loading SAM model...")
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
sam_predictor = SamPredictor(sam_model)
print("✅ SAM model loaded.")
print("="*50)

# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================
def get_text_embedding(text: str) -> list:
    if isinstance(text, list):
        if not text:
            return None
        text = ", ".join(text)
    if not text or not text.strip():
        return None
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(instruction + text, return_tensors='pt', padding=True,
                            truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
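
# Illustrative behaviour of get_text_embedding (hand-worked sketch, not output
# from a real run):
#   get_text_embedding("Nike")                -> normalized vector (384 floats for bge-small-en-v1.5)
#   get_text_embedding(["leather", "nylon"])  -> embeds the joined string "leather, nylon"
#   get_text_embedding("") / get_text_embedding(None) / get_text_embedding([]) -> None
# The prepended instruction is the standard BGE v1.5 query prefix for retrieval-style comparison.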

def get_image_embedding(image: Image.Image) -> list:
    inputs = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_dinov2(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()


def cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None:
        return 0.0
    vec1, vec2 = np.array(vec1), np.array(vec2)
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))


def jaccard_similarity(set1, set2):
    if not isinstance(set1, set) or not isinstance(set2, set):
        return 0.0
    intersection = set1.intersection(set2)
    union = set1.union(set2)
    if not union:
        return 1.0 if not intersection else 0.0
    return len(intersection) / len(union)
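
# Hand-worked examples for the two scoring helpers above (illustrative only):
#   cosine_similarity([1, 0], [1, 0])             -> 1.0
#   cosine_similarity([1, 0], [0, 1])             -> 0.0
#   cosine_similarity(None, [0, 1])               -> 0.0   (missing embeddings score zero)
#   jaccard_similarity({"red", "blue"}, {"red"})  -> 0.5   (1 shared value out of 2 distinct)
#   jaccard_similarity(set(), set())              -> 1.0   (two empty sets count as identical)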

def segment_guided_object(image: Image.Image, object_label: str) -> Image.Image:
    prompt = f"a {object_label}."
    print(f" [Segment] Using simple prompt: '{prompt}'")
    image_rgb = image.convert("RGB")
    image_np = np.array(image_rgb)
    h, w = image_np.shape[:2]

    inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)
    results = processor_gnd.post_process_grounded_object_detection(
        outputs, inputs.input_ids, threshold=0.5, text_threshold=0.5, target_sizes=[(h, w)]
    )
    if not results or len(results[0]['boxes']) == 0:
        print(f" [Segment] ⚠ Warning: Could not detect object with Grounding DINO. Using full image.")
        return image_rgb
    print(f" [Segment] ✅ Object detected successfully.")

    box = results[0]['boxes'][0].cpu().numpy()
    sam_predictor.set_image(image_np)
    masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
    mask = masks[0]

    image_rgba = np.concatenate([image_np, np.full((h, w, 1), 255, dtype=np.uint8)], axis=-1)
    image_rgba[:, :, 3] = mask * 255
    segmented_image = Image.fromarray(image_rgba, 'RGBA')

    true_points = np.argwhere(mask)
    if true_points.size > 0:
        top_left = true_points.min(axis=0)
        bottom_right = true_points.max(axis=0)
        # PIL crop boxes are exclusive on the right/lower edge, so add 1 to keep the last row/column of the mask.
        bbox = (top_left[1], top_left[0], bottom_right[1] + 1, bottom_right[0] + 1)
        segmented_image = segmented_image.crop(bbox)
    return segmented_image


def upload_to_uploadcare(image: Image.Image) -> str:
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = {'file': ('segmented_image.png', buffer, 'image/png')}
        data = {'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1'}
        response = requests.post('https://upload.uploadcare.com/base/', files=files, data=data)
        response.raise_for_status()
        file_uuid = response.json().get('file')
        return f"https://ucarecdn.com/{file_uuid}/"
    except Exception as e:
        return f"Uploadcare upload failed: {e}"

# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================
@app.route('/', methods=['GET'])
def health_check():
    return jsonify({"status": "Unified AI Service is running"}), 200


@app.route('/process', methods=['POST'])
def process_item():
    try:
        data = request.json
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")

        # --- UPDATED: Added markings_embedding ---
        response = {
            "canonicalLabel": data.get('objectName', '').lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "markings_embedding": get_text_embedding(data.get('markings')),
        }

        image_embeddings = []
        if data.get('images'):
            print(f" [PROCESS] Processing {len(data['images'])} image(s)...")
            for image_url in data['images']:
                try:
                    img_response = requests.get(image_url, timeout=20)
                    img_response.raise_for_status()
                    image = Image.open(BytesIO(img_response.content))
                    segmented_image = segment_guided_object(image, data['objectName'])
                    debug_url = upload_to_uploadcare(segmented_image)
                    print(f" - 🐞 DEBUG URL: {debug_url}")
                    embedding = get_image_embedding(segmented_image)
                    image_embeddings.append(embedding)
                except Exception as e:
                    print(f" - ⚠ Could not process image {image_url}: {e}")
                    continue
        response["image_embeddings"] = image_embeddings

        print(f" [PROCESS] ✅ Successfully processed all features.")
        return jsonify(response), 200

    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
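
# Worked example of the dynamic scoring used in /compare below (numbers are made up):
#   text components: brand 0.90, material 0.70, colors 0.50 -> text_score = (0.90 + 0.70 + 0.50) / 3 = 0.70
#   best image pair: 0.80                                    -> image_score = 0.80
#   final_score = 0.4 * 0.70 + 0.6 * 0.80 = 0.28 + 0.48 = 0.76 >= FINAL_SCORE_THRESHOLD (0.5) -> ACCEPTED
# If either item has no image embeddings, final_score falls back to the text score alone.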

@app.route('/compare', methods=['POST'])
def compare_items():
    try:
        payload = request.json
        query_item = payload['queryItem']
        search_list = payload['searchList']
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")

        results = []
        for item in search_list:
            item_id = item.get('_id')
            print(f"\n - Comparing with item: {item_id}")
            try:
                text_score_components = []
                component_log = {}

                # 1. Calculate score for fields with text embeddings (now includes 'markings')
                for field in TEXT_FIELDS_TO_EMBED:
                    q_emb = query_item.get(f"{field}_embedding")
                    i_emb = item.get(f"{field}_embedding")
                    if q_emb and i_emb:
                        score = cosine_similarity(q_emb, i_emb)
                        text_score_components.append(score)
                        component_log[field] = f"{score:.4f}"

                # 2. Calculate Jaccard score for 'colors'
                q_colors = set(c.lower().strip() for c in query_item.get('colors', []) if c)
                i_colors = set(c.lower().strip() for c in item.get('colors', []) if c)
                if q_colors and i_colors:
                    score = jaccard_similarity(q_colors, i_colors)
                    text_score_components.append(score)
                    component_log['colors'] = f"{score:.4f}"

                # 3. Calculate direct match score for 'size'
                q_size = (query_item.get('size') or "").lower().strip()
                i_size = (item.get('size') or "").lower().strip()
                if q_size and i_size:
                    score = 1.0 if q_size == i_size else 0.0
                    text_score_components.append(score)
                    component_log['size'] = f"{score:.4f}"

                # 4. Average only the scores from the available components
                text_score = 0.0
                if text_score_components:
                    text_score = sum(text_score_components) / len(text_score_components)
                print(f" - Text Score Components: {component_log}")
                print(f" - Final Avg Text Score: {text_score:.4f} (from {len(text_score_components)} components)")

                # 5. Calculate Image Score
                image_score = 0.0
                query_img_embs = query_item.get('image_embeddings', [])
                item_img_embs = item.get('image_embeddings', [])
                if query_img_embs and item_img_embs:
                    all_img_scores = []
                    for q_emb in query_img_embs:
                        for i_emb in item_img_embs:
                            all_img_scores.append(cosine_similarity(q_emb, i_emb))
                    if all_img_scores:
                        image_score = max(all_img_scores)
                    print(f" - Max Image Score: {image_score:.4f}")

                # 6. Calculate Final Score (Dynamic)
                final_score = 0.0
                if query_img_embs and item_img_embs:
                    print(f" - Calculating Hybrid Score (Text + Image)...")
                    final_score = (SCORE_WEIGHTS['text_score'] * text_score +
                                   SCORE_WEIGHTS['image_score'] * image_score)
                else:
                    print(f" - One or both items missing images. Using Text Score only...")
                    final_score = text_score
                print(f" - Final Dynamic Score: {final_score:.4f}")

                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({
                        "_id": str(item_id),
                        "score": round(final_score, 4)
                    })
                else:
                    print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

            except Exception as e:
                print(f" - ⚠ Skipping item {item_id} due to scoring error: {e}")
                continue

        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\n[COMPARE] ✅ Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200

    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
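
# --------------------------------------------------------------------------
# Example client calls (illustrative sketch only; the host/port and all field
# values below are placeholders, and the service must already be running):
#
#   import requests
#
#   processed = requests.post("http://localhost:7860/process", json={
#       "objectName": "backpack",
#       "brand": "Nike",
#       "material": ["nylon"],
#       "markings": "initials J.D. on the strap",
#       "images": ["https://example.com/photo.jpg"],
#   }).json()
#   # -> canonicalLabel plus brand/material/markings_embedding and image_embeddings
#
#   matches = requests.post("http://localhost:7860/compare", json={
#       "queryItem": {**processed, "objectName": "backpack", "colors": ["black"], "size": "medium"},
#       "searchList": stored_items,  # items previously enriched by /process, each with an "_id"
#   }).json()["matches"]
# --------------------------------------------------------------------------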