Spaces:
Runtime error
Runtime error
Commit
·
f1ee247
1
Parent(s):
ee927fc
Update main.py
Browse files
main.py
CHANGED
@@ -85,7 +85,23 @@ def denormalize_image(image, source_range=(-1, 1)):
|
|
85 |
return (image * 255).clip(0, 255).astype(np.uint8)
|
86 |
else:
|
87 |
raise ValueError(f"Unsupported source range: {source_range}")
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
90 |
width, height = 256, 256
|
91 |
initial_images = load_initial_images(width, height)
|
@@ -103,7 +119,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
103 |
action_descriptions = []
|
104 |
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
105 |
initial_actions = ['0:0'] * 7
|
106 |
-
initial_actions = ['N N N N N : N N N N N'] * 7
|
107 |
def unnorm_coords(x, y):
|
108 |
return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
109 |
|
@@ -121,7 +137,8 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
|
|
121 |
if DEBUG:
|
122 |
norm_x = x
|
123 |
norm_y = y
|
124 |
-
action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
|
|
125 |
prev_x = norm_x
|
126 |
prev_y = norm_y
|
127 |
elif action_type == "left_click":
|
@@ -180,7 +197,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
180 |
#positions = positions[1:]
|
181 |
mouse_position = position.split('~')
|
182 |
mouse_position = [int(item) for item in mouse_position]
|
183 |
-
mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
184 |
|
185 |
#previous_actions.append((action_type, mouse_position))
|
186 |
previous_actions = [(action_type, mouse_position))]
|
|
|
85 |
return (image * 255).clip(0, 255).astype(np.uint8)
|
86 |
else:
|
87 |
raise ValueError(f"Unsupported source range: {source_range}")
|
88 |
+
|
89 |
+
def format_action(action_str: str, is_padding: bool = False) -> str:
    """Encode an ``"x~y"`` coordinate pair as a space-separated token string.

    Each axis is rendered as a sign token (``+`` or ``-``) followed by the
    digits of its absolute value, zero-padded to at least four digits, with
    a single space between every token.  The two axes are joined by
    ``" : "``.  For example ``"815~335"`` becomes
    ``"+ 0 8 1 5 : + 0 3 3 5"``.

    Args:
        action_str: Two integers separated by ``~`` (e.g. ``"-5~12"``).
            Ignored entirely when ``is_padding`` is true.
        is_padding: When true, return the fixed placeholder
            ``"N N N N N : N N N N N"`` without parsing ``action_str``.

    Returns:
        The encoded token string.

    Raises:
        ValueError: If ``action_str`` is not exactly two integer fields
            separated by ``~``.

    Note:
        Four digits is a minimum width, not a maximum: a magnitude of
        10000 or more is emitted with all of its digits, which yields more
        tokens than the padding placeholder has.
    """
    if is_padding:
        return "N N N N N : N N N N N"

    fields = action_str.split('~')
    if len(fields) != 2:
        raise ValueError(
            f"Expected action in 'x~y' form, got {action_str!r}"
        )
    try:
        x, y = (int(field) for field in fields)
    except ValueError as err:
        raise ValueError(
            f"Non-integer coordinate in action {action_str!r}"
        ) from err

    def _encode_axis(value: int) -> str:
        # Sign token first, then one token per zero-padded digit.
        sign = '+' if value >= 0 else '-'
        digits = ' '.join(f"{abs(value):04d}")
        return f"{sign} {digits}"

    return f"{_encode_axis(x)} : {_encode_axis(y)}"
|
104 |
+
|
105 |
def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List[Tuple[str, List[int]]]) -> np.ndarray:
|
106 |
width, height = 256, 256
|
107 |
initial_images = load_initial_images(width, height)
|
|
|
119 |
action_descriptions = []
|
120 |
initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
|
121 |
initial_actions = ['0:0'] * 7
|
122 |
+
#initial_actions = ['N N N N N : N N N N N'] * 7
|
123 |
def unnorm_coords(x, y):
|
124 |
return int(x), int(y) #int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
|
125 |
|
|
|
137 |
if DEBUG:
|
138 |
norm_x = x
|
139 |
norm_y = y
|
140 |
+
#action_descriptions.append(f"{(norm_x-prev_x):.0f}~{(norm_y-prev_y):.0f}")
|
141 |
+
action_descriptions.append(format_action(f'{norm_x-prev_x:.0f}~{norm_y-prev_y:.0f}'), pos=='0~0')
|
142 |
prev_x = norm_x
|
143 |
prev_y = norm_y
|
144 |
elif action_type == "left_click":
|
|
|
197 |
#positions = positions[1:]
|
198 |
mouse_position = position.split('~')
|
199 |
mouse_position = [int(item) for item in mouse_position]
|
200 |
+
#mouse_position = '+ 0 8 1 5 : + 0 3 3 5'
|
201 |
|
202 |
#previous_actions.append((action_type, mouse_position))
|
203 |
previous_actions = [(action_type, mouse_position))]
|