# --------------------------------------------------------------------------
# UNIFIED AI SERVICE V3.2 (Debug Uploads & Refactored)
# --------------------------------------------------------------------------
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - Filtering is handled by the Node.js backend.
# - For debugging, segmented images are uploaded to Uploadcare and the URL
#   is printed to the console log.
# --------------------------------------------------------------------------
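# Assumed runtime dependencies (inferred from the imports below, not pinned here):
#   flask, torch, transformers, segment_anything, opencv-python (cv2), pillow (PIL),
#   requests, numpy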
import sys
sys.stdout.reconfigure(line_buffering=True)
import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image
# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection
# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================
app = Flask(__name__)
TEXT_FIELDS_TO_EMBED = ["brand", "material", "size", "colors"]
SCORE_WEIGHTS = { "text_score": 0.4, "image_score": 0.6 }
FINAL_SCORE_THRESHOLD = 0.5
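# Worked example of the hybrid score computed in /compare below:
#   final = 0.4 * text_score + 0.6 * image_score
#   e.g. text_score = 0.70, image_score = 0.50  ->  0.28 + 0.30 = 0.58 >= 0.5  ->  accepted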
# --- Load Uploadcare Credentials from Environment Variables ---
# Make sure to set this as a Secret in your Hugging Face Space settings.
UPLOADCARE_PUBLIC_KEY = os.getenv('UPLOADCARE_PUBLIC_KEY')
if not UPLOADCARE_PUBLIC_KEY:
    print("⚠️ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")
print("="*50)
print("🚀 Initializing AI Service with DINOv2...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")
| print("...Loading BGE text model...") | |
| bge_model_id = "BAAI/bge-small-en-v1.5" | |
| tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id) | |
| model_text = AutoModel.from_pretrained(bge_model_id).to(device) | |
| print("β BGE model loaded.") | |
| print("...Loading DINOv2 model...") | |
| dinov2_model_id = "facebook/dinov2-base" | |
| processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id) | |
| model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device) | |
| print("β DINOv2 model loaded.") | |
| print("...Loading Grounding DINO model for segmentation...") | |
| gnd_model_id = "IDEA-Research/grounding-dino-base" | |
| processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id) | |
| model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device) | |
| print("β Grounding DINO model loaded.") | |
| print("...Loading SAM model...") | |
| sam_checkpoint = "sam_vit_b_01ec64.pth" | |
| sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device) | |
| sam_predictor = SamPredictor(sam_model) | |
| print("β SAM model loaded.") | |
| print("="*50) | |
# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================
def get_text_embedding(text: str) -> list:
    """Embed a string (or list of strings) with BGE; returns a unit-norm vector, or None for empty input."""
    if isinstance(text, list):
        if not text:
            return None
        text = ", ".join(text)
    if not text or not text.strip():
        return None
    # BGE's recommended query instruction prefix for retrieval-style embeddings.
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(instruction + text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]  # CLS token
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
def get_image_embedding(image: Image.Image) -> list:
    """Embed a PIL image with DINOv2; returns the L2-normalized CLS-token vector."""
    inputs = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_dinov2(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]  # CLS token
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
def cosine_similarity(vec1, vec2):
    """Cosine similarity of two vectors; returns 0.0 if either is missing or zero-length."""
    if vec1 is None or vec2 is None:
        return 0.0
    vec1, vec2 = np.array(vec1), np.array(vec2)
    denom = np.linalg.norm(vec1) * np.linalg.norm(vec2)
    if denom == 0:
        return 0.0
    return float(np.dot(vec1, vec2) / denom)
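# Because both embedding helpers L2-normalize their outputs (384-dim for bge-small-en-v1.5,
# 768-dim for dinov2-base), cosine similarity here reduces to a plain dot product,
# e.g. cosine_similarity([1.0, 0.0], [0.6, 0.8]) == 0.6.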
def segment_guided_object(image: Image.Image, object_label: str) -> Image.Image:
    """Detect the named object with Grounding DINO, mask it with SAM, and paste it onto white."""
    prompt = f"a {object_label}."
    print(f" [Segment] Using simple prompt: '{prompt}'")
    image_rgb = image.convert("RGB")
    image_np = np.array(image_rgb)
    h, w = image_np.shape[:2]
    inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)
    results = processor_gnd.post_process_grounded_object_detection(
        outputs, inputs.input_ids, threshold=0.5, text_threshold=0.5, target_sizes=[(h, w)]
    )
    if not results or len(results[0]['boxes']) == 0:
        print(" [Segment] ⚠️ Warning: Could not detect object. Using full image.")
        return image_rgb
    print(" [Segment] ✅ Object detected successfully.")
    sam_predictor.set_image(image_np)
    box = results[0]['boxes'][0].cpu().numpy().astype(int)  # first detection, xyxy pixel coords
    masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
    mask = masks[0]  # boolean (H, W) mask of the object
    background = np.ones_like(image_np, dtype=np.uint8) * 255
    foreground = cv2.bitwise_and(image_np, image_np, mask=mask.astype(np.uint8))
    # Invert the boolean mask *before* casting to uint8: `~mask.astype(np.uint8)` would
    # bitwise-NOT 0/1 into 255/254, leaving the background mask truthy everywhere and
    # washing the whole composite out to white.
    background = cv2.bitwise_and(background, background, mask=(~mask).astype(np.uint8))
    segmented_np = cv2.add(foreground, background)
    return Image.fromarray(segmented_np, 'RGB')
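# Pipeline summary: Grounding DINO turns the object name into a bounding box, SAM turns the
# box into a pixel mask, and the composite keeps only the object on a pure-white background,
# presumably so the DINOv2 embedding is dominated by the object rather than the scene.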
def upload_to_uploadcare(image: Image.Image) -> str:
    """Uploads a PIL Image to Uploadcare and returns the CDN URL."""
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        # Convert PIL Image to an in-memory bytes buffer
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = { 'file': ('segmented_image.png', buffer, 'image/png') }
        data = { 'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1' }
        response = requests.post('https://upload.uploadcare.com/base/', files=files, data=data, timeout=30)
        response.raise_for_status()
        file_uuid = response.json().get('file')
        cdn_url = f"https://ucarecdn.com/{file_uuid}/"
        return cdn_url
    except Exception as e:
        return f"Uploadcare upload failed: {e}"
# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================
@app.route('/', methods=['GET'])  # route decorator restored; the health-check path is assumed
def health_check():
    return jsonify({"status": "Unified AI Service is running"}), 200
@app.route('/process', methods=['POST'])
def process_item():
    try:
        data = request.json
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")
        response = {
            "canonicalLabel": data.get('objectName', '').lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "size_embedding": get_text_embedding(data.get('size')),
            "colors_embedding": get_text_embedding(data.get('colors')),
        }
        image_embeddings = []
        if data.get('images'):
            print(f" [PROCESS] Processing {len(data['images'])} image(s)...")
            for image_url in data['images']:
                try:
                    img_response = requests.get(image_url, timeout=20)
                    img_response.raise_for_status()
                    image = Image.open(BytesIO(img_response.content))
                    segmented_image = segment_guided_object(image, data['objectName'])
                    # --- DEBUGGING STEP: Upload segmented image and log the URL ---
                    debug_url = upload_to_uploadcare(segmented_image)
                    print(f" - 🔗 DEBUG URL: {debug_url}")
                    # -----------------------------------------------------------
                    embedding = get_image_embedding(segmented_image)
                    image_embeddings.append(embedding)
                except Exception as e:
                    print(f" - ⚠️ Could not process image {image_url}: {e}")
                    continue
        response["image_embeddings"] = image_embeddings
        print(" [PROCESS] ✅ Successfully processed all features.")
        return jsonify(response), 200
    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
@app.route('/compare', methods=['POST'])
def compare_items():
    try:
        payload = request.json
        query_item = payload['queryItem']
        search_list = payload['searchList']
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")
        results = []
        for item in search_list:
            item_id = item.get('_id')
            print(f"\n - Comparing with item: {item_id}")
            try:
                # 1. Calculate Text Score (fields with a missing embedding contribute 0 to the average)
                total_text_score = 0
                for field in TEXT_FIELDS_TO_EMBED:
                    q_emb = query_item.get(f"{field}_embedding")
                    i_emb = item.get(f"{field}_embedding")
                    if q_emb and i_emb:
                        total_text_score += cosine_similarity(q_emb, i_emb)
                text_score = total_text_score / len(TEXT_FIELDS_TO_EMBED) if TEXT_FIELDS_TO_EMBED else 0
                print(f" - Text Score: {text_score:.4f}")
                # 2. Calculate Image Score with detailed logging (max over all query/item image pairs)
                image_score = 0.0
                query_img_embs = query_item.get('image_embeddings', [])
                item_img_embs = item.get('image_embeddings', [])
                if query_img_embs and item_img_embs:
                    all_img_scores = []
                    print(" - Image Pair Scores:")
                    for i, q_emb in enumerate(query_img_embs):
                        for j, i_emb in enumerate(item_img_embs):
                            pair_score = cosine_similarity(q_emb, i_emb)
                            print(f" - Query Img {i+1} vs Item Img {j+1}: {pair_score:.4f}")
                            all_img_scores.append(pair_score)
                    if all_img_scores:
                        image_score = max(all_img_scores)
                        print(f" - Max Image Score: {image_score:.4f}")
                # 3. Calculate Final Score
                final_score = (SCORE_WEIGHTS['text_score'] * text_score + SCORE_WEIGHTS['image_score'] * image_score)
                print(f" - Final Hybrid Score: {final_score:.4f}")
                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f" - ✅ ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({ "_id": str(item_id), "score": round(final_score, 4) })
                else:
                    print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
            except Exception as e:
                print(f" - ⚠️ Skipping item {item_id} due to scoring error: {e}")
                continue
        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\n[COMPARE] ✅ Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200
    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Port 7860 is the default expected by Hugging Face Spaces. Flask's built-in server is
    # fine for a debug Space; a production deployment would typically sit behind gunicorn.
    app.run(host='0.0.0.0', port=7860)