from typing import Optional
import spaces
import gradio as gr
import numpy as np
import torch
from PIL import Image
import io
import base64, os
from huggingface_hub import snapshot_download
import traceback
import warnings
import sys

# Suppress warnings
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*_supports_sdpa.*")


# Simple monkey patch for transformers - avoid recursion
def simple_patch_transformers():
    """Simple patch to fix the _supports_sdpa issue."""
    try:
        import transformers.modeling_utils as modeling_utils

        # Store the original method
        original_check = modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation

        def patched_check(self, *args, **kwargs):
            # Simply set the attribute if it doesn't exist
            if not hasattr(self, '_supports_sdpa'):
                object.__setattr__(self, '_supports_sdpa', False)
            try:
                return original_check(self, *args, **kwargs)
            except AttributeError as e:
                if '_supports_sdpa' in str(e):
                    # Return the default attention implementation
                    return "eager"
                raise

        modeling_utils.PreTrainedModel._check_and_adjust_attn_implementation = patched_check
        print("Applied simple transformers patch")
    except Exception as e:
        print(f"Warning: Could not patch transformers: {e}")


# Apply the patch BEFORE importing utils
simple_patch_transformers()

# Now import the utils
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img

# Download repository
repo_id = "microsoft/OmniParser-v2.0"
local_dir = "weights"

if not os.path.exists(local_dir):
    snapshot_download(repo_id=repo_id, local_dir=local_dir)
    print(f"Repository downloaded to: {local_dir}")
else:
    print(f"Weights already exist at: {local_dir}")


# Custom function to load the caption model
def load_caption_model_safe(model_name="florence2", model_name_or_path="weights/icon_caption"):
    """Safely load the caption model, falling back through three loading strategies."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Method 1: Try the original function
    try:
        return get_caption_model_processor(model_name, model_name_or_path)
    except Exception as e:
        print(f"Original loading failed: {e}, trying alternative...")

    # Method 2: Load with specific configs
    try:
        from transformers import AutoProcessor, AutoModelForCausalLM

        print(f"Loading caption model from {model_name_or_path}...")
        processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        # Load the model with a safer config
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            trust_remote_code=True,
            attn_implementation="eager",  # Use eager attention
            low_cpu_mem_usage=True
        )

        # Ensure the attribute exists (using object.__setattr__ to avoid recursion)
        if not hasattr(model, '_supports_sdpa'):
            object.__setattr__(model, '_supports_sdpa', False)

        if device.type == 'cuda':
            model = model.to(device)

        print("Model loaded successfully with alternative method")
        return {'model': model, 'processor': processor}
    except Exception as e:
        print(f"Alternative loading also failed: {e}")

    # Method 3: Manual loading as a last resort
    try:
        print("Attempting manual model loading...")

        # Import required modules
        from transformers import AutoProcessor, AutoConfig
        import importlib.util

        # Load the processor
        processor = AutoProcessor.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        # Load the config
        config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )

        # Manually import and instantiate the model
        model_file = os.path.join(model_name_or_path, "modeling_florence2.py")
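
        # For reference, the dynamic import below follows the standard
        # importlib recipe: build a spec from a file path, create a module
        # object from it, then execute the file to populate the module's
        # namespace. A minimal sketch with illustrative names (not part of
        # the transformers API):
        #
        #   spec = importlib.util.spec_from_file_location("my_mod", "/path/to/file.py")
        #   module = importlib.util.module_from_spec(spec)
        #   spec.loader.exec_module(module)   # runs the file, filling `module`
        #   SomeClass = module.SomeClass      # classes are then plain attributes
        #
        # Loading this way bypasses transformers' auto-class machinery, which
        # is what lets it sidestep the _supports_sdpa AttributeError.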
        if not os.path.exists(model_file):
            raise FileNotFoundError(f"Model source not found: {model_file}")

        spec = importlib.util.spec_from_file_location("modeling_florence2_custom", model_file)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        # Get the model class
        if not hasattr(module, 'Florence2ForConditionalGeneration'):
            raise AttributeError("Florence2ForConditionalGeneration not found in model file")
        model_class = module.Florence2ForConditionalGeneration

        # Create the model instance
        model = model_class(config)

        # Set the attribute before loading weights
        object.__setattr__(model, '_supports_sdpa', False)

        # Load the weights
        weight_file = os.path.join(model_name_or_path, "model.safetensors")
        if os.path.exists(weight_file):
            from safetensors.torch import load_file
            state_dict = load_file(weight_file)
            model.load_state_dict(state_dict, strict=False)

        if device.type == 'cuda':
            model = model.to(device)
            model = model.half()  # Use half precision

        print("Model loaded successfully with manual method")
        return {'model': model, 'processor': processor}
    except Exception as e:
        print(f"Manual loading failed: {e}")
        raise RuntimeError(f"Could not load model with any method: {e}")


# Load models
try:
    print("Loading YOLO model...")
    yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    print("YOLO model loaded successfully")

    print("Loading caption model...")
    caption_model_processor = load_caption_model_safe()
    print("Caption model loaded successfully")
except Exception as e:
    print(f"Critical error loading models: {e}")
    print(traceback.format_exc())
    caption_model_processor = None
    yolo_model = None

# UI Configuration
MARKDOWN = """
# OmniParser V2 Pro🔥

🎯 AI-powered screen understanding tool that detects UI elements and extracts text with high accuracy.

📝 Supports both PaddleOCR and EasyOCR for flexible text extraction.

""" DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"Using device: {DEVICE}") custom_css = """ body { background-color: #f0f2f5; } .gradio-container { font-family: 'Segoe UI', sans-serif; max-width: 1400px; margin: auto; } h1, h2, h3, h4 { color: #283E51; } button { border-radius: 6px; transition: all 0.3s ease; } button:hover { transform: translateY(-2px); box-shadow: 0 4px 12px rgba(0,0,0,0.15); } .output-image { border: 2px solid #e1e4e8; border-radius: 8px; } #input_image { border: 2px dashed #4a90e2; border-radius: 8px; } #input_image:hover { border-color: #2c5aa0; } """ @spaces.GPU @torch.inference_mode() def process( image_input, box_threshold, iou_threshold, use_paddleocr, imgsz ) -> tuple: """Process image with error handling""" if image_input is None: return None, "⚠️ Please upload an image for processing." if caption_model_processor is None or yolo_model is None: return None, "⚠️ Models not loaded properly. Please restart the application." try: print(f"Processing: box_threshold={box_threshold}, iou_threshold={iou_threshold}, " f"use_paddleocr={use_paddleocr}, imgsz={imgsz}") # Calculate overlay ratio image_width = image_input.size[0] box_overlay_ratio = max(0.5, min(2.0, image_width / 3200)) draw_bbox_config = { 'text_scale': 0.8 * box_overlay_ratio, 'text_thickness': max(int(2 * box_overlay_ratio), 1), 'text_padding': max(int(3 * box_overlay_ratio), 1), 'thickness': max(int(3 * box_overlay_ratio), 1), } # OCR processing try: ocr_bbox_rslt, is_goal_filtered = check_ocr_box( image_input, display_img=False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold': 0.9}, use_paddleocr=use_paddleocr ) if ocr_bbox_rslt is None: text, ocr_bbox = [], [] else: text, ocr_bbox = ocr_bbox_rslt text = text if text is not None else [] ocr_bbox = ocr_bbox if ocr_bbox is not None else [] print(f"OCR found {len(text)} text regions") except Exception as e: print(f"OCR error: {e}") text, ocr_bbox = [], [] # Object detection and captioning try: # Ensure model has _supports_sdpa attribute if isinstance(caption_model_processor, dict) and 'model' in caption_model_processor: model = caption_model_processor['model'] if not hasattr(model, '_supports_sdpa'): object.__setattr__(model, '_supports_sdpa', False) dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img( image_input, yolo_model, BOX_TRESHOLD=box_threshold, output_coord_in_ratio=True, ocr_bbox=ocr_bbox, draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text, iou_threshold=iou_threshold, imgsz=imgsz ) if dino_labled_img is None: raise ValueError("Failed to generate labeled image") except Exception as e: print(f"Detection error: {e}") return image_input, f"⚠️ Error during detection: {str(e)}" # Decode image try: image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) except Exception as e: print(f"Image decode error: {e}") return image_input, f"⚠️ Error decoding image: {str(e)}" # Format results if parsed_content_list and len(parsed_content_list) > 0: parsed_text = "🎯 **Detected Elements:**\n\n" for i, v in enumerate(parsed_content_list): if v: parsed_text += f"**Element {i}:** {v}\n" else: parsed_text = "ℹ️ No UI elements detected. Try adjusting the thresholds." print(f'Processing complete. 
Found {len(parsed_content_list)} elements.') return image, parsed_text except Exception as e: print(f"Processing error: {e}") print(traceback.format_exc()) return None, f"⚠️ Error: {str(e)}" # Build UI with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.Markdown(MARKDOWN) if caption_model_processor is None or yolo_model is None: gr.Markdown("### ⚠️ Warning: Models failed to load. Please check logs.") with gr.Row(): with gr.Column(scale=1): with gr.Accordion("📤 Upload & Settings", open=True): image_input_component = gr.Image( type='pil', label='Upload Screenshot', elem_id="input_image" ) gr.Markdown("### 🎛️ Detection Settings") box_threshold_component = gr.Slider( label='Box Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.05, info="Lower = more detections" ) iou_threshold_component = gr.Slider( label='IOU Threshold', minimum=0.01, maximum=1.0, step=0.01, value=0.1, info="Overlap filtering" ) use_paddleocr_component = gr.Checkbox( label='Use PaddleOCR', value=True ) imgsz_component = gr.Slider( label='Image Size', minimum=640, maximum=1920, step=32, value=640 ) submit_button_component = gr.Button( value='🚀 Process', variant='primary' ) with gr.Column(scale=2): with gr.Tabs(): with gr.Tab("🖼️ Result"): image_output_component = gr.Image( type='pil', label='Annotated Image' ) with gr.Tab("📝 Elements"): text_output_component = gr.Markdown( value="*Results will appear here...*" ) submit_button_component.click( fn=process, inputs=[ image_input_component, box_threshold_component, iou_threshold_component, use_paddleocr_component, imgsz_component ], outputs=[image_output_component, text_output_component], show_progress=True ) # Launch if __name__ == "__main__": try: demo.queue(max_size=10) demo.launch( share=False, show_error=True, server_name="0.0.0.0", server_port=7860 ) except Exception as e: print(f"Launch failed: {e}")
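
# A minimal sketch of driving the pipeline headlessly (without the Gradio UI),
# e.g. as a quick smoke test. The file name "screenshot.png" is illustrative
# and not part of this repo; the parameter values mirror the UI defaults above.
# Note that process() returns (None, message) on failure, so a real script
# should check the first return value before saving.
#
#   from PIL import Image
#   img = Image.open("screenshot.png").convert("RGB")
#   annotated, elements = process(
#       img,
#       box_threshold=0.05,   # UI default: lower = more detections
#       iou_threshold=0.1,    # UI default: overlap filtering
#       use_paddleocr=True,
#       imgsz=640
#   )
#   if annotated is not None:
#       annotated.save("annotated.png")
#   print(elements)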