sohamnk's picture
Update app.py
619f2e4 verified
raw
history blame
12.1 kB
# --------------------------------------------------------------------------
# UNIFIED AI SERVICE V3.2 (Debug Uploads & Refactored)
# --------------------------------------------------------------------------
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - Filtering is handled by the Node.js backend.
# - For debugging, segmented images are uploaded to Uploadcare and the URL
# is printed to the console log.
# --------------------------------------------------------------------------
import sys
# Line-buffer stdout so each print() line is flushed as soon as it is complete.
sys.stdout.reconfigure(line_buffering=True)
import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image
# NOTE(review): datetime/timedelta are not referenced by the code visible here; confirm before removing.
from datetime import datetime, timedelta
# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
# Grounding DINO's processor is imported under an alias (presumably to keep it distinct from the DINOv2 image processor).
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection
# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================
app = Flask(__name__)
# Text attributes compared field-by-field in /compare via their "<field>_embedding"
# keys (which /process produces for these same four fields).
TEXT_FIELDS_TO_EMBED = ["brand", "material", "size", "colors"]
# Weights used by /compare to blend the averaged text similarity with the best image similarity.
SCORE_WEIGHTS = { "text_score": 0.4, "image_score": 0.6 }
# Minimum blended score for /compare to return a candidate as a match.
FINAL_SCORE_THRESHOLD = 0.5
# --- Load Uploadcare Credentials from Environment Variables ---
# Configure this as a Secret in the Hugging Face Space settings. When it is
# missing or empty, debug uploads of segmented images cannot succeed.
UPLOADCARE_PUBLIC_KEY = os.environ.get('UPLOADCARE_PUBLIC_KEY')
if UPLOADCARE_PUBLIC_KEY in (None, ""):
    print("⚠️ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")
print("="*50)
print("πŸš€ Initializing AI Service with DINOv2...")
# Every model below is loaded once, at import time, into module-level globals
# that the request handlers reuse.
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")
print("...Loading BGE text model...")
# BGE text encoder + tokenizer, used by get_text_embedding().
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
print("βœ… BGE model loaded.")
print("...Loading DINOv2 model...")
# DINOv2 vision backbone + preprocessing, used by get_image_embedding().
dinov2_model_id = "facebook/dinov2-base"
processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id)
model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
print("βœ… DINOv2 model loaded.")
print("...Loading Grounding DINO model for segmentation...")
# Grounding DINO: text-prompted box detector used by segment_guided_object()
# to locate the object before SAM produces a mask.
gnd_model_id = "IDEA-Research/grounding-dino-base"
processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
print("βœ… Grounding DINO model loaded.")
print("...Loading SAM model...")
# SAM (ViT-B variant) turns a detected box into a pixel mask. The checkpoint
# path is relative, so the .pth file must exist in the process's working directory.
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
# A single shared predictor instance is used by all requests.
sam_predictor = SamPredictor(sam_model)
print("βœ… SAM model loaded.")
print("="*50)
# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================
def get_text_embedding(text: "str | list | None") -> "list | None":
    """Embed one text attribute with the BGE model.

    Args:
        text: A string, a list/tuple of values (joined with ", "), or None.
            Non-string scalars (e.g. a numeric size coming from JSON) are
            converted with ``str()``; ``None`` and blank entries inside a
            list are dropped before joining.

    Returns:
        The L2-normalised first-token embedding as a list of floats, or
        ``None`` when there is no non-blank text to embed.
    """
    if text is None:
        return None
    if isinstance(text, (list, tuple)):
        # Skip missing/blank entries so ["red", ""] embeds as "red", not "red, ",
        # and so a None entry cannot make str.join raise TypeError.
        parts = (str(part).strip() for part in text if part is not None)
        text = ", ".join(part for part in parts if part)
    elif not isinstance(text, str):
        # str.strip below would raise AttributeError on e.g. an int.
        text = str(text)
    if not text.strip():
        return None
    # The same instruction prefix is applied to every attribute that gets embedded.
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(
        instruction + text, return_tensors='pt', padding=True, truncation=True, max_length=512
    ).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()
def get_image_embedding(image: Image.Image) -> list:
    """Embed an image with DINOv2.

    The image goes through the DINOv2 processor and model; the first token
    of the final hidden state is L2-normalised and returned as a plain list
    of floats.
    """
    batch = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        hidden_states = model_dinov2(**batch).last_hidden_state
    first_token = hidden_states[:, 0, :]
    unit_vector = torch.nn.functional.normalize(first_token, p=2, dim=1)
    return unit_vector.cpu().numpy()[0].tolist()
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors as a Python float.

    Returns 0.0 when either vector is ``None`` or has zero norm (including an
    empty vector) instead of the NaN that 0/0 would produce; a NaN would
    poison any weighted score built from it and make ``max()`` over a list of
    scores depend on element order. Vectors of different lengths make
    ``np.dot`` raise ``ValueError``.
    """
    if vec1 is None or vec2 is None:
        return 0.0
    a = np.asarray(vec1, dtype=np.float64)
    b = np.asarray(vec2, dtype=np.float64)
    dot = np.dot(a, b)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0
    return float(dot / denom)
def segment_guided_object(image: Image.Image, object_label: str) -> Image.Image:
    """Isolate the object named by ``object_label`` on a white background.

    Grounding DINO proposes boxes for the prompt "a <label>."; the
    highest-scoring box is handed to SAM, and every pixel outside SAM's mask
    is painted white.

    Args:
        image: Input image in any PIL mode; it is converted to RGB first.
        object_label: Free-text name of the object to look for.

    Returns:
        An RGB image. If no box passes the detection thresholds, the
        RGB-converted input is returned unchanged.
    """
    # Grounding DINO expects lowercase text queries that end with a dot.
    label = str(object_label or "").strip().lower()
    prompt = f"a {label}."
    print(f" [Segment] Using simple prompt: '{prompt}'")
    image_rgb = image.convert("RGB")
    image_np = np.array(image_rgb)
    h, w = image_np.shape[:2]
    inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)
    results = processor_gnd.post_process_grounded_object_detection(
        outputs, inputs.input_ids, threshold=0.5, text_threshold=0.5, target_sizes=[(h, w)]
    )
    if not results or len(results[0]['boxes']) == 0:
        print(f" [Segment] ⚠️ Warning: Could not detect object. Using full image.")
        return image_rgb
    print(f" [Segment] βœ… Object detected successfully.")
    detections = results[0]
    # Detections are not guaranteed to be sorted by confidence; use the
    # highest-scoring box rather than whichever happens to come first.
    best_idx = int(detections['scores'].argmax())
    box = detections['boxes'][best_idx].cpu().numpy()
    # Keep the (x0, y0, x1, y1) box inside the frame before handing it to SAM.
    box = np.clip(box, 0, [w, h, w, h]).astype(int)
    # NOTE(review): sam_predictor is shared module state (set_image, then predict);
    # concurrent requests on a threaded server could interleave here — confirm deployment.
    sam_predictor.set_image(image_np)
    masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
    mask = masks[0].astype(bool)
    # Keep pixels inside the mask and paint the rest white. The boolean mask is
    # used directly: inverting a 0/1 uint8 mask with ``~`` yields 255/254 —
    # both non-zero — which would mark the entire frame as background.
    segmented_np = np.where(mask[..., None], image_np, 255).astype(np.uint8)
    return Image.fromarray(segmented_np)
def upload_to_uploadcare(image: Image.Image, timeout: float = 30.0) -> str:
    """Upload a PIL image to Uploadcare and return its CDN URL.

    Best-effort debugging aid: it never raises. On any failure it returns a
    human-readable message instead of a URL.

    Args:
        image: Image to upload; it is encoded as PNG in memory.
        timeout: Seconds to wait for the Uploadcare upload API before giving up.

    Returns:
        ``https://ucarecdn.com/<uuid>/`` on success; otherwise a message
        describing why no URL is available.
    """
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        with BytesIO() as buffer:
            image.save(buffer, format='PNG')
            buffer.seek(0)
            files = {'file': ('segmented_image.png', buffer, 'image/png')}
            data = {'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1'}
            # requests has no default timeout; without one a stalled upload
            # would block the calling request indefinitely.
            response = requests.post(
                'https://upload.uploadcare.com/base/', files=files, data=data, timeout=timeout
            )
        response.raise_for_status()
        file_uuid = response.json().get('file')
    except Exception as e:
        # Deliberately broad: a debug upload must never break the caller.
        return f"Uploadcare upload failed: {e}"
    if not file_uuid:
        # Avoid building "https://ucarecdn.com/None/" from an unexpected response.
        return "Uploadcare upload failed: response did not include a file UUID."
    return f"https://ucarecdn.com/{file_uuid}/"
# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================
@app.route('/', methods=['GET'])
def health_check():
    """Liveness endpoint: always answers 200 with a fixed status message."""
    status_body = {"status": "Unified AI Service is running"}
    http_ok = 200
    return jsonify(status_body), http_ok
def _embed_image_url(image_url, object_label):
    """Download one image, isolate the object, log a debug upload, and embed it.

    Returns the embedding, or None (after logging the error) if any step fails.
    """
    try:
        img_response = requests.get(image_url, timeout=20)
        img_response.raise_for_status()
        image = Image.open(BytesIO(img_response.content))
        if object_label:
            segmented_image = segment_guided_object(image, object_label)
        else:
            # No label to ground the detector on: embed the whole (RGB) image.
            segmented_image = image.convert("RGB")
        # --- DEBUGGING STEP: Upload segmented image and log the URL ---
        debug_url = upload_to_uploadcare(segmented_image)
        print(f" - 🐞 DEBUG URL: {debug_url}")
        return get_image_embedding(segmented_image)
    except Exception as e:
        # Best-effort per image: one bad URL or undecodable file must not fail the item.
        print(f" - ⚠️ Could not process image {image_url}: {e}")
        return None


@app.route('/process', methods=['POST'])
def process_item():
    """Compute text and image embeddings for one item.

    Expects a JSON object with optional keys ``objectName``, ``brand``,
    ``material``, ``size``, ``colors`` and ``images`` (a list of image URLs,
    or a single URL string). Text attributes are embedded with
    ``get_text_embedding``; each image is downloaded, segmented around
    ``objectName`` (or used whole when no name is given), uploaded to
    Uploadcare for debugging, and embedded with ``get_image_embedding``.
    Images that fail at any step are logged and skipped.

    Returns:
        200 with ``canonicalLabel``, ``brand_embedding``, ``material_embedding``,
        ``size_embedding``, ``colors_embedding`` and ``image_embeddings``;
        400 if the body is not a JSON object; 500 on an unexpected error.
    """
    try:
        data = request.get_json(force=True, silent=True)
        if not isinstance(data, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")
        # objectName may be absent or null; treat both as an empty label.
        object_name = data.get('objectName') or ''
        response = {
            "canonicalLabel": str(object_name).lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "size_embedding": get_text_embedding(data.get('size')),
            "colors_embedding": get_text_embedding(data.get('colors')),
        }
        image_urls = data.get('images') or []
        if isinstance(image_urls, str):
            # A lone URL string would otherwise be iterated character by character.
            image_urls = [image_urls]
        image_embeddings = []
        if image_urls:
            print(f" [PROCESS] Processing {len(image_urls)} image(s)...")
            for image_url in image_urls:
                embedding = _embed_image_url(image_url, object_name)
                if embedding is not None:
                    image_embeddings.append(embedding)
        response["image_embeddings"] = image_embeddings
        failed = len(image_urls) - len(image_embeddings)
        if failed:
            print(f" [PROCESS] ⚠️ Processed features, but {failed} of {len(image_urls)} image(s) failed.")
        else:
            print(f" [PROCESS] βœ… Successfully processed all features.")
        return jsonify(response), 200
    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
def _text_score(query_item, item):
    """Average per-field text similarity over all TEXT_FIELDS_TO_EMBED.

    A field whose embedding is missing or empty on either side contributes 0;
    the divisor is always the full field count.
    """
    if not TEXT_FIELDS_TO_EMBED:
        return 0
    total_text_score = 0
    for field in TEXT_FIELDS_TO_EMBED:
        q_emb = query_item.get(f"{field}_embedding")
        i_emb = item.get(f"{field}_embedding")
        if q_emb and i_emb:
            total_text_score += cosine_similarity(q_emb, i_emb)
    return total_text_score / len(TEXT_FIELDS_TO_EMBED)


def _image_score(query_item, item):
    """Best cosine similarity over every query/candidate image pair (logged per pair).

    Returns 0.0 when either side has no image embeddings.
    """
    query_img_embs = query_item.get('image_embeddings') or []
    item_img_embs = item.get('image_embeddings') or []
    if not (query_img_embs and item_img_embs):
        return 0.0
    print(f" - Image Pair Scores:")
    all_img_scores = []
    for i, q_emb in enumerate(query_img_embs, start=1):
        for j, i_emb in enumerate(item_img_embs, start=1):
            pair_score = cosine_similarity(q_emb, i_emb)
            print(f" - Query Img {i} vs Item Img {j}: {pair_score:.4f}")
            all_img_scores.append(pair_score)
    image_score = max(all_img_scores)
    print(f" - Max Image Score: {image_score:.4f}")
    return image_score


@app.route('/compare', methods=['POST'])
def compare_items():
    """Score pre-filtered candidates against a query item.

    Expects JSON ``{"queryItem": {...}, "searchList": [{...}, ...]}``, where
    items carry the ``<field>_embedding`` / ``image_embeddings`` keys produced
    by /process and candidates carry an ``_id``.

    Per candidate: final = SCORE_WEIGHTS['text_score'] * text_score
    + SCORE_WEIGHTS['image_score'] * image_score. Candidates with
    final >= FINAL_SCORE_THRESHOLD are returned, best first. A candidate that
    fails to score (or is not a JSON object) is logged and skipped.

    Returns:
        200 ``{"matches": [{"_id": str, "score": float}, ...]}``; 400 for a
        malformed payload; 500 on an unexpected error.
    """
    try:
        payload = request.get_json(force=True, silent=True)
        if not isinstance(payload, dict):
            return jsonify({"error": "Request body must be a JSON object."}), 400
        query_item = payload.get('queryItem')
        search_list = payload.get('searchList')
        if not isinstance(query_item, dict) or not isinstance(search_list, list):
            return jsonify({"error": "Expected 'queryItem' (object) and 'searchList' (array)."}), 400
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")
        # NOTE(review): with the current weights (text 0.4) and threshold (0.5), a pair
        # where either side has no image embeddings cannot reach the threshold.
        results = []
        for item in search_list:
            if not isinstance(item, dict):
                # Previously item.get() here raised outside the per-item try and failed the whole request.
                print(f" - ⚠️ Skipping malformed candidate (not an object): {item!r}")
                continue
            item_id = item.get('_id')
            print(f"\n - Comparing with item: {item_id}")
            try:
                # 1. Text score
                text_score = _text_score(query_item, item)
                print(f" - Text Score: {text_score:.4f}")
                # 2. Image score (logs every pair)
                image_score = _image_score(query_item, item)
                # 3. Weighted final score
                final_score = (SCORE_WEIGHTS['text_score'] * text_score + SCORE_WEIGHTS['image_score'] * image_score)
                print(f" - Final Hybrid Score: {final_score:.4f}")
                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f" - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({ "_id": str(item_id), "score": round(final_score, 4) })
                else:
                    print(f" - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")
            except Exception as e:
                print(f" - ⚠️ Skipping item {item_id} due to scoring error: {e}")
                continue
        results.sort(key=lambda match: match["score"], reverse=True)
        print(f"\n[COMPARE] βœ… Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200
    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500
if __name__ == '__main__':
    # Direct execution: run Flask's built-in server on all interfaces, port 7860.
    # NOTE(review): 7860 presumably matches the port the hosting Space routes to — confirm.
    server_options = {"host": '0.0.0.0', "port": 7860}
    app.run(**server_options)