# -------------------------------------------------------------------------- #
# UNIFIED AI SERVICE V3.4 (Color-Enhanced Segmentation)
# -------------------------------------------------------------------------- #
# This service uses DINOv2 for image embeddings and BGE for text embeddings.
# - The segmentation prompt now includes colors for better accuracy.
# - For debugging, segmented images are uploaded to Uploadcare.
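#
# Illustrative request shapes (field names match the handlers below; the
# values themselves are made up):
#   POST /process -> {"objectName": "Water Bottle", "brand": "Acme",
#                     "material": "steel", "markings": "dent on lid",
#                     "colors": ["Blue"], "images": ["https://..."]}
#   POST /compare -> {"queryItem": {...processed fields...},
#                     "searchList": [{...stored candidates...}]}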
# --------------------------------------------------------------------------
import sys
sys.stdout.reconfigure(line_buffering=True)
import os
import numpy as np
import requests
import cv2
import traceback
from io import BytesIO
from flask import Flask, request, jsonify
from PIL import Image

# --- Import Deep Learning Libraries ---
import torch
from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
from segment_anything import SamPredictor, sam_model_registry
from transformers import AutoProcessor as AutoGndProcessor, AutoModelForZeroShotObjectDetection

# ==========================================================================
# --- CONFIGURATION & INITIALIZATION ---
# ==========================================================================

app = Flask(__name__)

TEXT_FIELDS_TO_EMBED = ["brand", "material", "markings"]
SCORE_WEIGHTS = {
    "text_score": 0.6,
    "image_score": 0.4
}
FINAL_SCORE_THRESHOLD = 0.75
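
# Worked example of the hybrid score computed in /compare (numbers are
# illustrative): with text_score = 0.80 and image_score = 0.70,
#   final = 0.6 * 0.80 + 0.4 * 0.70 = 0.76 >= 0.75  -> accepted.
# When either item lacks image embeddings, the text score is used alone.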

# --- Load Uploadcare Credentials from Environment Variables ---
UPLOADCARE_PUBLIC_KEY = os.getenv('UPLOADCARE_PUBLIC_KEY')
if not UPLOADCARE_PUBLIC_KEY:
    print("⚠ WARNING: UPLOADCARE_PUBLIC_KEY environment variable not set. Debug uploads will fail.")

print("="*50)
print("πŸš€ Initializing AI Service with DINOv2...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🧠 Using device: {device}")

print("...Loading BGE text model...")
bge_model_id = "BAAI/bge-small-en-v1.5"
tokenizer_text = AutoTokenizer.from_pretrained(bge_model_id)
model_text = AutoModel.from_pretrained(bge_model_id).to(device)
print("βœ… BGE model loaded.")

print("...Loading DINOv2 model...")
dinov2_model_id = "facebook/dinov2-base"
processor_dinov2 = AutoImageProcessor.from_pretrained(dinov2_model_id)
model_dinov2 = AutoModel.from_pretrained(dinov2_model_id).to(device)
print("βœ… DINOv2 model loaded.")

print("...Loading Grounding DINO model for segmentation...")
gnd_model_id = "IDEA-Research/grounding-dino-base" # Kept base as you didn't specify changing this
processor_gnd = AutoGndProcessor.from_pretrained(gnd_model_id)
model_gnd = AutoModelForZeroShotObjectDetection.from_pretrained(gnd_model_id).to(device)
print("βœ… Grounding DINO model loaded.")

print("...Loading SAM model...")
sam_checkpoint = "sam_vit_b_01ec64.pth"
sam_model = sam_model_registry["vit_b"](checkpoint=sam_checkpoint).to(device)
sam_predictor = SamPredictor(sam_model)
print("βœ… SAM model loaded.")
print("="*50)

# ==========================================================================
# --- HELPER FUNCTIONS ---
# ==========================================================================

def get_text_embedding(text: str) -> list:
    if isinstance(text, list):
        if not text: return None
        text = ", ".join(text)
    if not text or not text.strip():
        return None
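    # BGE retrieval models expect queries to carry this instruction prefix;
    # the [CLS] token embedding is then L2-normalised so the dot products in
    # cosine_similarity() below behave as true cosine similarities.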
    instruction = "Represent this sentence for searching relevant passages: "
    inputs = tokenizer_text(instruction + text, return_tensors='pt', padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = model_text(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()

def get_image_embedding(image: Image.Image) -> list:
    # Flatten any alpha channel onto a white background first; a plain
    # convert("RGB") would drop the segmentation mask and bring the
    # original background back into the embedding.
    if image.mode == "RGBA":
        white = Image.new("RGBA", image.size, (255, 255, 255, 255))
        image = Image.alpha_composite(white, image)
    image = image.convert("RGB")
    inputs = processor_dinov2(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_dinov2(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
    return embedding.cpu().numpy()[0].tolist()

def cosine_similarity(vec1, vec2):
    if vec1 is None or vec2 is None: return 0.0
    vec1, vec2 = np.array(vec1), np.array(vec2)
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))

def jaccard_similarity(set1, set2):
    if not isinstance(set1, set) or not isinstance(set2, set):
        return 0.0
    intersection = set1.intersection(set2)
    union = set1.union(set2)
    if not union:
        return 1.0 if not intersection else 0.0
    return len(intersection) / len(union)
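
# Example: jaccard_similarity({"red", "blue"}, {"blue"}) == 0.5; two empty
# sets are treated as identical (1.0) instead of dividing by an empty union.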

def segment_guided_object(image: Image.Image, object_label: str, colors: list = None) -> Image.Image:
    """
    Finds and segments ALL instances of an object based on a text label and colors,
    returning the original image with the detected objects segmented with transparency.
    This version includes a hole-filling step to create solid masks.
    """
    # Build a more descriptive prompt by prefixing the object label with its colors
    colors = colors or []
    color_str = " ".join(c.lower() for c in colors if c)
    if color_str:
        prompt = f"a {color_str} {object_label}."
    else:
        prompt = f"a {object_label}."

    print(f"  [Segment] Using prompt: '{prompt}' for segmentation.")
    image_rgb = image.convert("RGB")
    image_np = np.array(image_rgb)
    height, width = image_np.shape[:2]

    # Grounding DINO detection
    inputs = processor_gnd(images=image_rgb, text=prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model_gnd(**inputs)

    # Process results with a threshold
    results = processor_gnd.post_process_grounded_object_detection(
        outputs, inputs.input_ids, threshold=0.35, text_threshold=0.5, target_sizes=[(height, width)]
    )

    if not results or len(results[0]['boxes']) == 0:
        print(f"  [Segment] ⚠ Warning: Could not detect '{object_label}' with GroundingDINO. Returning original image.")
        return Image.fromarray(np.concatenate([image_np, np.full((height, width, 1), 255, dtype=np.uint8)], axis=-1), 'RGBA')

    boxes = results[0]['boxes']
    scores = results[0]['scores']
    print(f"  [Segment] βœ… Found {len(boxes)} potential object(s) with confidence scores: {[round(s.item(), 2) for s in scores]}")

    # Set image for SAM
    sam_predictor.set_image(image_np)

    # Initialize an empty mask to combine all detections
    combined_mask = np.zeros((height, width), dtype=np.uint8)

    # Predict masks for all detected boxes and combine them.
    # SAM returns boolean masks, so scale to 0/255 before OR-ing into the
    # uint8 mask; otherwise the no-contours fallback below would yield an
    # almost fully transparent alpha channel.
    for box in boxes:
        box = box.cpu().numpy().astype(int)
        masks, _, _ = sam_predictor.predict(box=box, multimask_output=False)
        combined_mask = np.bitwise_or(combined_mask, masks[0].astype(np.uint8) * 255)
    
    print("  [Segment] Combined masks for all detected objects.")

    # --- START: HOLE FILLING LOGIC ---
    # This new block will fill any holes within the combined mask.
    print("  [Segment] Post-processing: Filling holes in the combined mask...")
    
    # Find contours. RETR_EXTERNAL retrieves only the extreme outer contours.
    contours, _ = cv2.findContours(combined_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    
    # Create a new blank mask to draw the filled contours on.
    filled_mask = np.zeros_like(combined_mask)
    
    if contours:
        # Draw the detected contours onto the new mask and fill them.
        # The -1 index means draw all contours, and cv2.FILLED fills them.
        cv2.drawContours(filled_mask, contours, -1, 255, thickness=cv2.FILLED)
    else:
        # If for some reason no contours were found, fall back to the original mask.
        filled_mask = combined_mask
    print("  [Segment] βœ… Hole filling complete.")
    # --- END: HOLE FILLING LOGIC ---

    # Create an RGBA image where the background is transparent
    object_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    object_rgba[:, :, :3] = image_np # Copy original RGB
    
    # Apply the NEW filled mask as the alpha channel
    object_rgba[:, :, 3] = filled_mask

    return Image.fromarray(object_rgba, 'RGBA')
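
# A minimal manual test of the segmentation helper (file names are
# hypothetical, for local experimentation only):
#   img = Image.open("lost_item.jpg")
#   cutout = segment_guided_object(img, "backpack", ["red", "black"])
#   cutout.save("backpack_cutout.png")  # RGBA with a transparent background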


def upload_to_uploadcare(image: Image.Image) -> str:
    if not UPLOADCARE_PUBLIC_KEY:
        return "UPLOADCARE_PUBLIC_KEY not configured."
    try:
        buffer = BytesIO()
        image.save(buffer, format='PNG')
        buffer.seek(0)
        files = { 'file': ('segmented_image.png', buffer, 'image/png') }
        data = { 'UPLOADCARE_PUB_KEY': UPLOADCARE_PUBLIC_KEY, 'UPLOADCARE_STORE': '1' }
        response = requests.post('https://upload.uploadcare.com/base/', files=files, data=data, timeout=30)
        response.raise_for_status()
        file_uuid = response.json().get('file')
        return f"https://ucarecdn.com/{file_uuid}/"
    except Exception as e:
        return f"Uploadcare upload failed: {e}"

# ==========================================================================
# --- FLASK ENDPOINTS ---
# ==========================================================================

@app.route('/', methods=['GET'])
def health_check():
    return jsonify({"status": "Unified AI Service is running"}), 200

@app.route('/process', methods=['POST'])
def process_item():
    try:
        data = request.json
        print(f"\n[PROCESS] Received request for: {data.get('objectName')}")

        response = {
            "canonicalLabel": data.get('objectName', '').lower().strip(),
            "brand_embedding": get_text_embedding(data.get('brand')),
            "material_embedding": get_text_embedding(data.get('material')),
            "markings_embedding": get_text_embedding(data.get('markings')),
        }

        image_embeddings = []
        if data.get('images'):
            print(f"  [PROCESS] Processing {len(data['images'])} image(s)...")
            for image_url in data['images']:
                try:
                    img_response = requests.get(image_url, timeout=20)
                    img_response.raise_for_status()
                    image = Image.open(BytesIO(img_response.content))

                    # --- UPDATED: Pass colors to the segmentation function ---
                    segmented_image = segment_guided_object(image, data['objectName'], data.get('colors', []))
                    debug_url = upload_to_uploadcare(segmented_image)
                    print(f"    - 🐞 DEBUG URL: {debug_url}")

                    embedding = get_image_embedding(segmented_image)
                    image_embeddings.append(embedding)
                except Exception as e:
                    print(f"    - ⚠ Could not process image {image_url}: {e}")
                    continue

        response["image_embeddings"] = image_embeddings
        print(f"  [PROCESS] βœ… Successfully processed all features.")
        return jsonify(response), 200

    except Exception as e:
        print(f"❌ Error in /process: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

@app.route('/compare', methods=['POST'])
def compare_items():
    try:
        payload = request.json
        query_item = payload['queryItem']
        search_list = payload['searchList']
        print(f"\n[COMPARE] Received {len(search_list)} pre-filtered candidates for '{query_item.get('objectName')}'.")

        results = []
        for item in search_list:
            item_id = item.get('_id')
            print(f"\n  - Comparing with item: {item_id}")
            try:
                text_score_components = []
                component_log = {}

                # 1. Calculate score for fields with text embeddings (now includes 'markings')
                for field in TEXT_FIELDS_TO_EMBED:
                    q_emb = query_item.get(f"{field}_embedding")
                    i_emb = item.get(f"{field}_embedding")
                    if q_emb and i_emb: 
                        score = cosine_similarity(q_emb, i_emb)
                        text_score_components.append(score)
                        component_log[field] = f"{score:.4f}"

                # 2. Calculate Jaccard score for 'colors'
                q_colors = set(c.lower().strip() for c in query_item.get('colors', []) if c)
                i_colors = set(c.lower().strip() for c in item.get('colors', []) if c)
                if q_colors and i_colors:
                    score = jaccard_similarity(q_colors, i_colors)
                    text_score_components.append(score)
                    component_log['colors'] = f"{score:.4f}"

                # 3. Calculate direct match score for 'size'
                q_size = (query_item.get('size') or "").lower().strip()
                i_size = (item.get('size') or "").lower().strip()
                if q_size and i_size:
                    score = 1.0 if q_size == i_size else 0.0
                    text_score_components.append(score)
                    component_log['size'] = f"{score:.4f}"

                # 4. Average only the scores from the available components
                text_score = 0.0
                if text_score_components:
                    text_score = sum(text_score_components) / len(text_score_components)
                
                print(f"    - Text Score Components: {component_log}")
                print(f"    - Final Avg Text Score: {text_score:.4f} (from {len(text_score_components)} components)")

                # 5. Calculate Image Score
                image_score = 0.0
                query_img_embs = query_item.get('image_embeddings', [])
                item_img_embs = item.get('image_embeddings', [])
                if query_img_embs and item_img_embs:
                    all_img_scores = []
                    for q_emb in query_img_embs:
                        for i_emb in item_img_embs:
                            all_img_scores.append(cosine_similarity(q_emb, i_emb))
                    if all_img_scores:
                        image_score = max(all_img_scores)
                print(f"    - Max Image Score: {image_score:.4f}")

                # 6. Calculate Final Score (Dynamic)
                final_score = 0.0
                if query_img_embs and item_img_embs:
                    print(f"    - Calculating Hybrid Score (Text + Image)...")
                    final_score = (SCORE_WEIGHTS['text_score'] * text_score + SCORE_WEIGHTS['image_score'] * image_score)
                else:
                    print(f"    - One or both items missing images. Using Text Score only...")
                    final_score = text_score

                print(f"    - Final Dynamic Score: {final_score:.4f}")

                if final_score >= FINAL_SCORE_THRESHOLD:
                    print(f"    - βœ… ACCEPTED (Score >= {FINAL_SCORE_THRESHOLD})")
                    results.append({ "_id": str(item_id), "score": round(final_score, 4) })
                else:
                    print(f"    - ❌ REJECTED (Score < {FINAL_SCORE_THRESHOLD})")

            except Exception as e:
                print(f"    - ⚠ Skipping item {item_id} due to scoring error: {e}")
                continue

        results.sort(key=lambda x: x["score"], reverse=True)
        print(f"\n[COMPARE] βœ… Search complete. Found {len(results)} potential matches.")
        return jsonify({"matches": results}), 200

    except Exception as e:
        print(f"❌ Error in /compare: {e}")
        traceback.print_exc()
        return jsonify({"error": str(e)}), 500

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=7860)
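
# Example local invocation (illustrative; the script name and payload values
# are placeholders, the port matches app.run above):
#   $ python service.py
#   $ curl -s -X POST http://localhost:7860/process \
#       -H "Content-Type: application/json" \
#       -d '{"objectName": "umbrella", "colors": ["black"], "images": []}'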