Aria-UI committed on
Commit
a4c53e8
·
verified ·
1 Parent(s): 07eef99

Upload 10 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/example_web.png filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,12 +1,18 @@
1
  import gradio as gr
2
- import cv2
3
  import numpy as np
4
- from PIL import Image
5
  import base64
6
  from io import BytesIO
7
  import re
8
  import os
9
 
 
 
 
 
 
 
 
10
  # Code from user
11
  openai_api_key = os.environ["aria_ui_api_key"]
12
  openai_api_base = os.environ["aria_ui_api_base"]
@@ -21,16 +27,15 @@ client = OpenAI(
21
  models = client.models.list()
22
  model = models.data[0].id
23
 
24
- def encode_numpy_image_to_base64(image: np.ndarray) -> str:
25
- success, buffer = cv2.imencode('.jpg', image)
26
- if not success:
27
- raise ValueError("Failed to encode image to jpg format")
28
- image_bytes = buffer.tobytes()
29
- base64_string = base64.b64encode(image_bytes).decode('utf-8')
30
- return base64_string
31
 
32
- def request_aria_ui(image: np.ndarray, prompt: str) -> str:
33
- image_base64 = encode_numpy_image_to_base64(image)
34
  chat_completion_from_url = client.chat.completions.create(
35
  messages=[{
36
  "role": "user",
@@ -50,7 +55,7 @@ def request_aria_ui(image: np.ndarray, prompt: str) -> str:
50
  model=model,
51
  max_tokens=512,
52
  stop=["<|im_end|>"],
53
- extra_body={"split_image": True, "image_max_size": 980}
54
  )
55
 
56
  result = chat_completion_from_url.choices[0].message.content
@@ -63,7 +68,7 @@ def _extract_coords_from_response(response: str) -> tuple[int, int]:
63
  raise ValueError(f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}")
64
  return int(numbers[0]), int(numbers[1])
65
 
66
- def process_image(image: np.ndarray, prompt: str) -> np.ndarray:
67
  try:
68
  # Request processing from API
69
  response = request_aria_ui(image, prompt)
@@ -72,37 +77,102 @@ def process_image(image: np.ndarray, prompt: str) -> np.ndarray:
72
  norm_coords = _extract_coords_from_response(response)
73
 
74
  # Convert normalized coordinates to absolute coordinates
75
- height, width, _ = image.shape
 
76
  abs_coords = (
77
  int(norm_coords[0] * width / 1000), # Scale x-coordinate
78
  int(norm_coords[1] * height / 1000) # Scale y-coordinate
79
  )
80
 
81
- # Draw circle on image
82
- output_image = image.copy()
83
- cv2.circle(output_image, abs_coords, radius=10, color=(0, 255, 0), thickness=-1)
 
 
 
 
 
 
 
 
 
 
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  return output_image
 
86
  except Exception as e:
87
  raise ValueError(f"An error occurred: {e}")
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  # Gradio app
90
  def gradio_interface(input_image, prompt):
91
- input_image = np.array(input_image) # Convert PIL image to numpy
92
- output_image = process_image(input_image, prompt)
93
- return Image.fromarray(output_image)
 
 
94
 
95
  with gr.Blocks() as demo:
96
- gr.Markdown("# GUI Image Processor")
97
- gr.Markdown("Upload a GUI image and enter a prompt. The app will process the image and mark a location based on the response.")
 
 
 
98
 
99
  with gr.Row():
100
- with gr.Column():
101
  image_input = gr.Image(type="pil", label="Upload GUI Image")
102
  prompt_input = gr.Textbox(label="Enter Prompt")
103
  submit_button = gr.Button("Process")
104
- with gr.Column():
105
- output_image = gr.Image(label="Processed Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  submit_button.click(
108
  fn=gradio_interface,
 
1
  import gradio as gr
 
2
  import numpy as np
3
+ from PIL import Image, ImageDraw
4
  import base64
5
  from io import BytesIO
6
  import re
7
  import os
8
 
9
+ examples = [
10
+ {"image": "assets/example_desktop.png", "prompt": "switch off the wired connection"},
11
+ {"image": "assets/example_web.png", "prompt": "view all branches"},
12
+ {"image": "assets/example_mobile.jpg", "prompt": "share the screenshot"},
13
+ ]
14
+
15
+
16
  # Code from user
17
  openai_api_key = os.environ["aria_ui_api_key"]
18
  openai_api_base = os.environ["aria_ui_api_base"]
 
27
  models = client.models.list()
28
  model = models.data[0].id
29
 
30
def encode_pil_image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded JPEG string.

    The image is converted to RGB first so that modes JPEG cannot
    store (e.g. RGBA, P) still encode cleanly.
    """
    rgb = image.convert("RGB")
    buffer = BytesIO()
    rgb.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
 
36
 
37
+ def request_aria_ui(image: Image.Image, prompt: str) -> str:
38
+ image_base64 = encode_pil_image_to_base64(image)
39
  chat_completion_from_url = client.chat.completions.create(
40
  messages=[{
41
  "role": "user",
 
55
  model=model,
56
  max_tokens=512,
57
  stop=["<|im_end|>"],
58
+ extra_body={"split_image": True, "image_max_size": 980, "temperature": 0, "top_k": 1}
59
  )
60
 
61
  result = chat_completion_from_url.choices[0].message.content
 
68
  raise ValueError(f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}")
69
  return int(numbers[0]), int(numbers[1])
70
 
71
+ def image_grounding(image: Image.Image, prompt: str) -> Image.Image:
72
  try:
73
  # Request processing from API
74
  response = request_aria_ui(image, prompt)
 
77
  norm_coords = _extract_coords_from_response(response)
78
 
79
  # Convert normalized coordinates to absolute coordinates
80
+ width, height = image.size
81
+ long_side = max(width, height)
82
  abs_coords = (
83
  int(norm_coords[0] * width / 1000), # Scale x-coordinate
84
  int(norm_coords[1] * height / 1000) # Scale y-coordinate
85
  )
86
 
87
+ # Load and prepare the click indicator image
88
+ click_image = Image.open("assets/click.png")
89
+ # Calculate adaptive size for click indicator
90
+ # Make it proportional to the image width (e.g., 3% of image width)
91
+ target_width = int(long_side * 0.03) # 3% of image width
92
+ aspect_ratio = click_image.width / click_image.height
93
+ target_height = int(target_width / aspect_ratio)
94
+ click_image = click_image.resize((target_width, target_height))
95
+
96
+ # Calculate position to center the click image on the coordinates
97
+ # Add a small offset downward (20% of click image height)
98
+ # Calculate position to align the 30% point of the click image with the coordinates
99
+ click_x = abs_coords[0] - int(click_image.width * 0.3) # Align 30% from left
100
+ click_y = abs_coords[1] - int(click_image.height * 0.3) # Align 30% from top
101
 
102
+ # Create output image and paste the click indicator
103
+ output_image = image.copy()
104
+ # Draw bounding box
105
+ draw = ImageDraw.Draw(output_image)
106
+ bbox = [
107
+ click_x, # left
108
+ click_y, # top
109
+ click_x + click_image.width, # right
110
+ click_y + click_image.height # bottom
111
+ ]
112
+ draw.rectangle(bbox, outline='red', width=int(click_image.width * 0.1))
113
+ output_image.paste(click_image, (click_x, click_y), click_image)
114
  return output_image
115
+
116
  except Exception as e:
117
  raise ValueError(f"An error occurred: {e}")
118
 
119
def resize_image_with_max_size(image: Image.Image, max_size: int = 1920) -> Image.Image:
    """Return *image* scaled down so neither side exceeds *max_size*.

    Aspect ratio is preserved. An image already within the limit is
    returned unchanged (same object, no copy).
    """
    width, height = image.size

    # Fast path: already small enough on both axes.
    if width <= max_size and height <= max_size:
        return image

    # Pin the longer side to exactly max_size; scale the other proportionally.
    if width > height:
        new_size = (max_size, int(height * (max_size / width)))
    else:
        new_size = (int(width * (max_size / height)), max_size)

    return image.resize(new_size, Image.Resampling.LANCZOS)
134
+
135
# Gradio app
def gradio_interface(input_image, prompt):
    """Gradio callback: downscale oversized uploads, then run grounding.

    Args:
        input_image: PIL image uploaded by the user.
        prompt: instruction text to ground in the image.

    Returns:
        The input image annotated with the grounded click location.
    """
    # Cap the longest side (default 1920 px) so very large screenshots
    # don't inflate request size or model input cost.
    # Removed leftover debug print(input_image.size) calls.
    input_image = resize_image_with_max_size(input_image)
    return image_grounding(input_image, prompt)
142
 
143
  with gr.Blocks() as demo:
144
+ with gr.Row(elem_classes="container"):
145
+ gr.Image("assets/logo_long.png", show_label=False, container=False, scale=1, elem_classes="logo", height=76)
146
+
147
+ gr.Markdown("# Aria-UI: Visual Grounding for GUI Instructions")
148
+ gr.Markdown("🚀🚀 Upload a GUI image and enter a instruction. Aria-UI will try its best to ground the instruction to specific element in the image. 🎯🎯")
149
 
150
  with gr.Row():
151
+ with gr.Column(scale=2): # Make this column smaller
152
  image_input = gr.Image(type="pil", label="Upload GUI Image")
153
  prompt_input = gr.Textbox(label="Enter Prompt")
154
  submit_button = gr.Button("Process")
155
+
156
+ with gr.Column(scale=3): # Make this column larger
157
+ output_image = gr.Image(label="Grounding Result", height=600) # Set specific height for larger display
158
+
159
+ with gr.Column(scale=2):
160
+ # Move examples here and make them vertical
161
+ gr.Examples(
162
+ examples=[
163
+ [
164
+ example["image"],
165
+ example["prompt"]
166
+ ]
167
+ for example in examples
168
+ ],
169
+ inputs=[image_input, prompt_input],
170
+ outputs=[output_image],
171
+ fn=gradio_interface,
172
+ cache_examples=False,
173
+ label="Example Tasks", # Add label for better organization
174
+ examples_per_page=5 # Control number of examples shown at once
175
+ )
176
 
177
  submit_button.click(
178
  fn=gradio_interface,
assets/aria_ui_logo.png ADDED
assets/click.png ADDED
assets/example_desktop.png ADDED
assets/example_mobile.jpg ADDED
assets/example_web.png ADDED

Git LFS Details

  • SHA256: 9f3add458689b6e3c11ca1ede9032aa65f623bd2c8a6a63976bf77cd4dc865b5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.58 MB
assets/logo_long.png ADDED