import base64
import os
import re
from io import BytesIO

import gradio as gr
import numpy as np
from openai import OpenAI  # Assuming the OpenAI client library is installed
from PIL import Image, ImageDraw

examples = [
    {"image": "./assets/example_desktop.png", "prompt": "switch off the wired connection"},
    {"image": "./assets/example_web.png", "prompt": "view all branches"},
    {"image": "./assets/example_mobile.jpg", "prompt": "share the screenshot"},
]

# Icon overlaid on the grounding result to mark the predicted click point.
CLICK_ICON_PATH = "assets/click.png"

# Credentials / endpoint of the OpenAI-compatible server hosting Aria-UI.
# Both environment variables are required (a missing one raises KeyError at startup).
openai_api_key = os.environ["aria_ui_api_key"]
openai_api_base = os.environ["aria_ui_api_base"]

client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)

# Use the first model advertised by the endpoint.
models = client.models.list()
model = models.data[0].id


def encode_pil_image_to_base64(image: Image.Image) -> str:
    """Return *image* encoded as a base64 JPEG string.

    The image is converted to RGB first because JPEG cannot store alpha or
    palette modes.
    """
    image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def request_aria_ui(image: Image.Image, prompt: str) -> str:
    """Ask the Aria-UI model where the element described by *prompt* is in *image*.

    Returns:
        The raw text of the model reply. It is parsed downstream as two integers:
        the x and y of the target point on a relative 0-1000 scale.
    """
    image_base64 = encode_pil_image_to_base64(image)
    chat_completion_from_url = client.chat.completions.create(
        messages=[{
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: " + prompt,
                },
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_base64}"
                    },
                },
            ],
        }],
        model=model,
        max_tokens=512,
        stop=["<|im_end|>"],
        # Extra request-body fields for the serving backend; temperature 0 with
        # top_k 1 asks for deterministic (greedy) decoding.
        extra_body={"split_image": True, "image_max_size": 980, "temperature": 0, "top_k": 1},
    )
    return chat_completion_from_url.choices[0].message.content


def _extract_coords_from_response(response: str) -> tuple[int, int]:
    """Parse the (x, y) integer pair out of a model reply.

    Markdown code fences are stripped before searching for integers.

    Raises:
        ValueError: if the reply is empty/None or does not contain exactly two integers.
    """
    if not response:
        # message.content can be None; fail with a readable message instead of AttributeError.
        raise ValueError(f"Empty response from model: {response!r}")
    resp = response.replace("```", "").strip()
    numbers = re.findall(r"\d+", resp)
    if len(numbers) != 2:
        raise ValueError(f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}")
    return int(numbers[0]), int(numbers[1])


def _load_click_indicator(long_side: int) -> Image.Image:
    """Load the click icon as RGBA, scaled so its width is ~3% of *long_side*.

    Both dimensions are clamped to at least 1 px because ``Image.resize``
    rejects zero-sized targets (which happens for very small screenshots).
    """
    # The context manager closes the file handle; convert() loads the pixels first.
    with Image.open(CLICK_ICON_PATH) as raw_icon:
        # RGBA guarantees an alpha channel so the icon can be its own paste mask.
        icon = raw_icon.convert("RGBA")
    aspect_ratio = icon.width / icon.height
    target_width = max(1, int(long_side * 0.03))
    target_height = max(1, int(target_width / aspect_ratio))
    return icon.resize((target_width, target_height))


def image_grounding(image: Image.Image, prompt: str) -> Image.Image:
    """Ground *prompt* in *image* and return a copy annotated with the predicted point.

    The annotation is the click icon plus a red rectangle around it. The input
    image is not modified.

    Raises:
        ValueError: wrapping any failure (API request, response parsing, drawing),
            with the original exception chained as the cause.
    """
    try:
        response = request_aria_ui(image, prompt)
        norm_x, norm_y = _extract_coords_from_response(response)

        # The model answers on a 0-1000 scale relative to each axis.
        width, height = image.size
        abs_x = int(norm_x * width / 1000)
        abs_y = int(norm_y * height / 1000)

        # Icon size is proportional to the image's longer side.
        click_image = _load_click_indicator(max(width, height))

        # Position the icon so the point 30% in from its top-left corner lands on
        # the predicted coordinate (presumably the cursor tip in click.png -- confirm).
        click_x = abs_x - int(click_image.width * 0.3)
        click_y = abs_y - int(click_image.height * 0.3)

        output_image = image.copy()
        draw = ImageDraw.Draw(output_image)
        bbox = [
            click_x,                        # left
            click_y,                        # top
            click_x + click_image.width,    # right
            click_y + click_image.height,   # bottom
        ]
        # At least 1 px: ImageDraw draws no outline at all when width == 0.
        draw.rectangle(bbox, outline="red", width=max(1, int(click_image.width * 0.1)))
        # The icon's own alpha channel serves as the paste mask.
        output_image.paste(click_image, (click_x, click_y), click_image)
        return output_image
    except Exception as e:
        raise ValueError(f"An error occurred: {e}") from e


def resize_image_with_max_size(image: Image.Image, max_size: int = 1920) -> Image.Image:
    """Resize image to have a maximum dimension of max_size while maintaining aspect ratio.

    Images already within the limit are returned unchanged (same object).
    The shorter side is clamped to at least 1 px so extreme aspect ratios
    do not yield an invalid zero-sized target.
    """
    width, height = image.size
    if width <= max_size and height <= max_size:
        return image

    if width > height:
        new_width = max_size
        new_height = max(1, int(height * (max_size / width)))
    else:
        new_height = max_size
        new_width = max(1, int(width * (max_size / height)))

    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)


# Gradio app
def gradio_interface(input_image, prompt):
    """Gradio callback: validate inputs, downscale the upload, and return the annotated image.

    Raises:
        gr.Error: when no image is uploaded or the instruction is blank, so the
            user sees a readable message instead of a stack trace.
    """
    if input_image is None:
        raise gr.Error("Please upload a GUI image.")
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a GUI instruction.")
    print(input_image.size)
    input_image = resize_image_with_max_size(input_image)
    print(input_image.size)
    return image_grounding(input_image, prompt)


with gr.Blocks() as demo:
    gr.HTML(
        """
Aria-UI
"""
    )
    gr.Markdown("""| [🤗 Aria-UI Models](https://huggingface.co/Aria-UI/Aria-UI-base) • [🤗 Aria-UI Dataset](https://huggingface.co/datasets/Aria-UI/Aria-UI_Data) • [🌐 Project Page](https://ariaui.github.io) • [📝 Paper](https://arxiv.org/abs/2412.16256) |
|:---------------------------------------------------------------------------------------------------------:|""")
    gr.Markdown("# Aria-UI: Visual Grounding for GUI Instructions")
    gr.Markdown("🚀🚀 Upload a GUI image and enter an instruction. Aria-UI will try its best to ground the instruction to a specific element in the image. 🎯🎯")

    with gr.Row():
        with gr.Column(scale=2):  # Narrower input column
            image_input = gr.Image(type="pil", label="Upload GUI Image", height=600)
            prompt_input = gr.Textbox(label="Enter GUI Instruction")
            submit_button = gr.Button("Process")

        with gr.Column(scale=3):  # Wider result column
            output_image = gr.Image(label="Grounding Result", height=500)

        with gr.Column(scale=2):  # Examples listed vertically on the right
            gr.Examples(
                examples=[[example["image"], example["prompt"]] for example in examples],
                inputs=[image_input, prompt_input],
                outputs=[output_image],
                fn=gradio_interface,
                cache_examples=False,
                label="Example Tasks",
                examples_per_page=5,  # Number of examples shown per page
            )

    submit_button.click(
        fn=gradio_interface,
        inputs=[image_input, prompt_input],
        outputs=[output_image],
    )

demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    ssr_mode=False,
    debug=True,
)