import os
import io
import base64
import time
import logging
import threading
import uuid
from collections import deque
from typing import Optional, Tuple

import gradio as gr
from gradio_client import Client
from PIL import Image

# ───────── Logging ─────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ───────── Queue System Configuration ─────────
MAX_QUEUE_SIZE = 50
MAX_CONCURRENT_REQUESTS = 1   # GPU can only handle 1 request at a time
AVERAGE_PROCESSING_TIME = 15  # seconds

# ───────── Backend connection ─────────
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# ───────── Global Queue System ─────────
class QueueManager:
    def __init__(self):
        self.queue = deque()   # (request_id, user_data, timestamp)
        self.processing = {}   # request_id -> processing_start_time
        self.completed = {}    # request_id -> result
        self.failed = {}       # request_id -> error_message
        self.lock = threading.Lock()
        self.stats = {
            'total_processed': 0,
            'total_failed': 0,
            'avg_processing_time': AVERAGE_PROCESSING_TIME,
        }

    def add_request(self, request_id: str, user_data: dict) -> Tuple[int, float]:
        """Add a request to the queue. Returns (position, estimated_wait)."""
        with self.lock:
            if len(self.queue) >= MAX_QUEUE_SIZE:
                raise Exception("Queue is full. Please try again later.")

            self.queue.append((request_id, user_data, time.time()))
            position = len(self.queue)

            # Estimate wait time for a single-GPU backend: if nothing is
            # processing, the worker picks this request up immediately.
            processing_count = len(self.processing)
            queue_ahead = position - 1
            if processing_count == 0:
                estimated_wait = 0
            else:
                estimated_wait = (queue_ahead + 1) * self.stats['avg_processing_time']

            logger.info(
                f"Request {request_id} added to queue. "
                f"Position: {position}, Est. wait: {estimated_wait:.0f}s"
            )
            return position, estimated_wait

    def get_next_requests(self):
        """Get the next request to process (only 1 at a time for the GPU)."""
        with self.lock:
            if len(self.processing) >= MAX_CONCURRENT_REQUESTS or len(self.queue) == 0:
                return []
            request_id, user_data, timestamp = self.queue.popleft()
            self.processing[request_id] = time.time()
            return [(request_id, user_data)]

    def complete_request(self, request_id: str, result):
        """Mark a request as completed."""
        with self.lock:
            if request_id in self.processing:
                processing_time = time.time() - self.processing[request_id]
                del self.processing[request_id]
                self.completed[request_id] = result
                self.stats['total_processed'] += 1
                # Exponential moving average of processing time (alpha = 0.2)
                current_avg = self.stats['avg_processing_time']
                self.stats['avg_processing_time'] = (current_avg * 0.8) + (processing_time * 0.2)
                logger.info(f"Request {request_id} completed in {processing_time:.1f}s")

    def fail_request(self, request_id: str, error_msg: str):
        """Mark a request as failed."""
        with self.lock:
            if request_id in self.processing:
                del self.processing[request_id]
            self.failed[request_id] = error_msg
            self.stats['total_failed'] += 1
            logger.error(f"Request {request_id} failed: {error_msg}")

    def get_request_status(self, request_id: str) -> dict:
        """Get the status of a specific request."""
        with self.lock:
            if request_id in self.completed:
                return {'status': 'completed', 'result': self.completed[request_id]}
            elif request_id in self.failed:
                return {'status': 'failed', 'error': self.failed[request_id]}
            elif request_id in self.processing:
                processing_time = time.time() - self.processing[request_id]
                return {'status': 'processing', 'time': processing_time}
            else:
                for i, (rid, _, _) in enumerate(self.queue):
                    if rid == request_id:
                        return {'status': 'queued', 'position': i + 1}
            return {'status': 'not_found'}

# Global queue manager
queue_manager = QueueManager()
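# Illustrative use of the queue manager in isolation (comment sketch, never
# executed; the payload mirrors the schema built by submit_request below):
#
#   rid = str(uuid.uuid4())
#   pos, wait = queue_manager.add_request(rid, {"image_b64": "...", "category": "...", "gender": "..."})
#   queue_manager.get_next_requests()        # -> [(rid, {...})]
#   queue_manager.complete_request(rid, {"status": "done"})
#   queue_manager.get_request_status(rid)    # -> {'status': 'completed', ...}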
backend_status = {
    "client": None,
    "connected": False,
    "last_check": None,
    "error_message": "",
}

def check_backend_connection():
    """Ping the HF Space and cache the client object."""
    try:
        test_client = Client("milliyin/backend", hf_token=HF_TOKEN)
        backend_status.update({
            "client": test_client,
            "connected": True,
            "error_message": "",
            "last_check": time.time(),
        })
        logger.info("✅ Backend connection established")
        return True, "🟢 Model is ready"
    except Exception as e:
        backend_status.update({
            "client": None,
            "connected": False,
            "last_check": time.time(),
            "error_message": str(e),
        })
        err = str(e).lower()
        if "timeout" in err or "read operation timed out" in err:
            return False, "🟡 Model is starting up. Please wait 3-4 min."
        return False, f"🔴 Backend error: {e}"

# Initial probe
check_backend_connection()
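# The probe returns an (ok, message) pair that a caller can surface directly,
# e.g. (illustrative):
#
#   ok, msg = check_backend_connection()
#   # msg is one of "🟢 Model is ready", "🟡 Model is starting up. ...",
#   # or "🔴 Backend error: ..."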
# ───────── Queue Processing Worker ─────────
def queue_worker():
    """Background worker that drains the queue, one request at a time."""
    while True:
        try:
            requests = queue_manager.get_next_requests()
            if not requests:
                time.sleep(1)
                continue

            # Process a single request (GPU limitation)
            request_id, user_data = requests[0]
            logger.info(f"Starting processing request {request_id}")
            process_single_request(request_id, user_data)
            time.sleep(0.5)
        except Exception as e:
            logger.error(f"Queue worker error: {e}")
            time.sleep(5)

def process_single_request(request_id: str, user_data: dict):
    """Process a single request against the backend Space."""
    try:
        img_b64 = user_data['image_b64']
        category = user_data['category']
        gender = user_data['gender']

        if not backend_status["connected"]:
            check_backend_connection()
        if not backend_status["connected"]:
            raise Exception("Backend not available")

        client = backend_status["client"]

        start_time = time.time()
        result = client.predict(
            img_b64,
            category,
            gender,
            api_name="/predict",
        )
        processing_time = time.time() - start_time

        if not result or len(result) < 4:
            raise ValueError("Invalid response structure from backend")

        _, overlay_b64, bg_b64, status = result
        final_result = {
            'overlay_b64': overlay_b64,
            'bg_b64': bg_b64,
            'status': status,
            'processing_time': processing_time,
        }
        queue_manager.complete_request(request_id, final_result)

    except Exception as e:
        queue_manager.fail_request(request_id, str(e))

# Start the queue worker
worker_thread = threading.Thread(target=queue_worker, daemon=True)
worker_thread.start()

# ───────── Helpers ─────────
def image_to_base64(image: Image.Image) -> str:
    if image is None:
        return ""
    if image.mode != "RGB":
        image = image.convert("RGB")
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()

def base64_to_image(b64: str) -> Optional[Image.Image]:
    if not b64:
        return None
    try:
        return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
    except Exception as e:
        logger.error(f"Failed to decode base64 → image: {e}")
        return None

# ───────── Request Management ─────────
active_requests = {}  # session_id -> request_id

def submit_request(input_image: Image.Image, category: str, gender: str):
    """Submit a new request to the queue."""
    if input_image is None:
        return None, None, "❌ Please upload an image.", gr.update(interactive=True), ""

    try:
        request_id = str(uuid.uuid4())
        img_b64 = image_to_base64(input_image)
        user_data = {
            'image_b64': img_b64,
            'category': category,
            'gender': gender,
            'timestamp': time.time(),
        }

        position, estimated_wait = queue_manager.add_request(request_id, user_data)

        status_msg = f"🚀 Request submitted! Position in queue: #{position}"
        if position == 1 and len(queue_manager.processing) == 0:
            status_msg += " | Starting processing now..."
        elif estimated_wait > 0:
            status_msg += f" | Estimated wait: {estimated_wait:.0f}s"

        return None, None, status_msg, gr.update(interactive=False), request_id

    except Exception as e:
        return None, None, f"❌ {str(e)}", gr.update(interactive=True), ""
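# End-to-end lifecycle as the UI drives it (comment sketch; the category and
# gender values are illustrative, and check_request_status is defined below):
#
#   _, _, msg, _, rid = submit_request(img, "Necklace", "Female")
#   # ...poll until the status message reports completed or failed...
#   overlay, bg, msg, _ = check_request_status(rid)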
def check_request_status(request_id: str):
    """Check the status of a request."""
    if not request_id:
        return None, None, "No active request", gr.update(interactive=True)

    status_info = queue_manager.get_request_status(request_id)

    if status_info['status'] == 'completed':
        result = status_info['result']
        overlay_img = base64_to_image(result['overlay_b64'])
        bg_img = base64_to_image(result['bg_b64'])
        status_msg = f"✅ {result['status']} (⏱ {result['processing_time']:.1f}s)"
        return overlay_img, bg_img, status_msg, gr.update(interactive=True)

    elif status_info['status'] == 'failed':
        return None, None, f"❌ {status_info['error']}", gr.update(interactive=True)

    elif status_info['status'] == 'processing':
        processing_time = status_info['time']
        return None, None, f"⚡ Processing... ({processing_time:.1f}s)", gr.update(interactive=False)

    elif status_info['status'] == 'queued':
        position = status_info['position']
        avg_time = queue_manager.stats['avg_processing_time']
        estimated_wait = position * avg_time
        wait_msg = (
            f" | Est. wait: {int(estimated_wait / 60)}m {int(estimated_wait % 60)}s"
            if estimated_wait > 30 else ""
        )
        return None, None, f"⏳ In queue, position #{position}{wait_msg}", gr.update(interactive=False)

    else:
        return None, None, "❓ Request not found", gr.update(interactive=True)

def disable_button():
    return gr.update(interactive=False)

# ───────── CSS ─────────
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #3b4371 0%, #2d1b69 25%, #673ab7 50%, #8e24aa 75%, #6a1b9a 100%);
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
    min-height: 100vh;
}
.contain {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 15px;
    padding: 25px;
    margin: 15px;
    box-shadow: 0 10px 30px rgba(0, 0, 0, 0.2);
    backdrop-filter: blur(10px);
}
.title-container {
    text-align: center;
    margin-bottom: 25px;
    padding: 20px;
    background: linear-gradient(135deg, #673ab7, #8e24aa);
    border-radius: 12px;
    box-shadow: 0 5px 20px rgba(103, 58, 183, 0.4);
}
.title-container h1 {
    color: white;
    font-size: 2.2em;
    font-weight: bold;
    margin: 0;
    text-shadow: 1px 1px 3px rgba(0, 0, 0, 0.3);
}
.info-bar {
    background: linear-gradient(135deg, #7c4dff, #6a1b9a);
    padding: 12px;
    border-radius: 8px;
    margin-bottom: 20px;
    color: white;
    text-align: center;
    font-weight: 500;
    box-shadow: 0 3px 12px rgba(124, 77, 255, 0.3);
}
.section-header {
    background: linear-gradient(135deg, #e1bee7, #d1c4e9);
    padding: 12px;
    border-radius: 8px;
    margin-bottom: 15px;
    border-left: 4px solid #673ab7;
}
.section-header h3 {
    margin: 0;
    color: #333;
    font-weight: 600;
}
.input-group {
    background: rgba(255, 255, 255, 0.85);
    padding: 18px;
    border-radius: 12px;
    margin-bottom: 15px;
    border: 1px solid rgba(103, 58, 183, 0.2);
    box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.result-section {
    background: rgba(255, 255, 255, 0.9);
    padding: 18px;
    border-radius: 12px;
    border: 1px solid rgba(103, 58, 183, 0.2);
    box-shadow: 0 3px 12px rgba(103, 58, 183, 0.1);
}
.tip-box {
    background: linear-gradient(135deg, #f3e5f5, #e8eaf6);
    padding: 10px;
    border-radius: 6px;
    margin: 8px 0;
    border-left: 3px solid #673ab7;
    color: #4a148c;
    font-weight: 500;
}
button.primary {
    background: linear-gradient(135deg, #673ab7, #8e24aa) !important;
    border: none !important;
    border-radius: 20px !important;
    padding: 12px 25px !important;
    color: white !important;
    font-weight: bold !important;
    font-size: 15px !important;
    box-shadow: 0 5px 15px rgba(103, 58, 183, 0.4) !important;
}
button.primary:hover {
    box-shadow: 0 8px 25px rgba(103, 58, 183, 0.6) !important;
    opacity: 0.9 !important;
    transform: translateY(-2px) !important;
}
label {
    color: #4a148c !important;
    font-weight: 600 !important;
}
input, textarea, select {
    border: 1px solid rgba(103, 58, 183, 0.3) !important;
    border-radius: 6px !important;
}
input:focus, textarea:focus, select:focus {
    border-color: #673ab7 !important;
    box-shadow: 0 0 0 2px rgba(103, 58, 183, 0.2) !important;
}
.gr-slider input[type="range"] {
    accent-color: #673ab7 !important;
}
input[type="checkbox"] {
    accent-color: #673ab7 !important;
}
.preserve-aspect-ratio img {
    object-fit: contain !important;
    width: auto !important;
    max-height: 512px !important;
}
.social-links {
    text-align: center;
    margin: 20px 0;
}
.social-links a {
    margin: 0 10px;
    padding: 8px 16px;
    background: #667eea;
    color: white;
    text-decoration: none;
    border-radius: 8px;
    transition: all 0.3s ease;
}
.social-links a:hover {
    background: #764ba2;
    transform: translateY(-2px);
}
.feature-box {
    background: #f8fafc;
    border: 1px solid #e2e8f0;
    padding: 20px;
    border-radius: 12px;
    margin: 10px 0;
}
"""
# ───────── Gradio Blocks ─────────
with gr.Blocks(css=custom_css, title="Jewellery Photography Preview") as demo:
    # Hero
    gr.HTML(
        """
        <div class="title-container">
            <h1>💎 Jewellery Photography Preview</h1>
        </div>
        <div class="info-bar">
            Upload a jewellery image, select a model, and get professional photos instantly
        </div>
        """
    )

    request_id_state = gr.State("")

    # NOTE: the layout below is a minimal reconstruction; the dropdown and
    # radio choices are placeholders and must match what the backend expects.
    with gr.Row():
        with gr.Column(elem_classes="input-group"):
            gr.HTML('<div class="section-header"><h3>Upload</h3></div>')
            input_image = gr.Image(type="pil", label="Jewellery Image")
            gr.HTML('<div class="tip-box">💡 Select a clear jewellery image for best results</div>')
            category = gr.Dropdown(
                choices=["Necklace", "Earrings", "Ring", "Bracelet"],
                value="Necklace",
                label="Jewellery Category",
            )
            gender = gr.Radio(
                choices=["Female", "Male"],
                value="Female",
                label="Model",
            )
            submit_btn = gr.Button("✨ Generate", variant="primary")
            status_box = gr.Textbox(label="Status", interactive=False)

        with gr.Column(elem_classes="result-section"):
            gr.HTML('<div class="section-header"><h3>Results</h3></div>')
            gr.HTML('<div class="tip-box">Preview overlay detection and final professional background</div>')
            overlay_output = gr.Image(label="Overlay Detection", elem_classes="preserve-aspect-ratio")
            bg_output = gr.Image(label="Professional Background", elem_classes="preserve-aspect-ratio")

    # Footer
    gr.HTML(
        """
        <div class="feature-box" style="text-align: center;">
            <p>Experience the future of virtual fashion and garment visualization.</p>
            <p>© 2024 Snapwear AI. Professional AI tools for fashion and design.</p>
        </div>
        """
    )

    # Wiring: disable the button, then enqueue the request.
    submit_btn.click(
        fn=disable_button,
        outputs=submit_btn,
    ).then(
        fn=submit_request,
        inputs=[input_image, category, gender],
        outputs=[overlay_output, bg_output, status_box, submit_btn, request_id_state],
    )

    # Poll the queue every 2 seconds for the active request
    # (gr.Timer requires a recent Gradio release).
    status_timer = gr.Timer(2)
    status_timer.tick(
        fn=check_request_status,
        inputs=request_id_state,
        outputs=[overlay_output, bg_output, status_box, submit_btn],
    )

if __name__ == "__main__":
    demo.launch()
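# Local run (sketch; assumes this file is saved as app.py and dependencies
# are installed):
#
#   pip install gradio gradio_client pillow
#   HF_TOKEN=hf_... python app.py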