yuntian-deng commited on
Commit
fff6c83
·
1 Parent(s): 9dfe662

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -8
main.py CHANGED
@@ -48,13 +48,14 @@ model = model.to(device)
48
  def load_initial_images(width, height):
49
  initial_images = []
50
  for i in range(7):
51
- image_path = f"image_{i}.png"
52
- if os.path.exists(image_path):
53
- img = Image.open(image_path).resize((width, height))
54
- initial_images.append(np.array(img))
55
- else:
56
- print(f"Warning: {image_path} not found. Using blank image instead.")
57
- initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
 
58
  return initial_images
59
 
60
  def normalize_images(images, target_range=(-1, 1)):
@@ -90,7 +91,7 @@ def predict_next_frame(previous_frames: List[np.ndarray], previous_actions: List
90
  # Prepare the prompt based on the previous actions
91
  action_descriptions = []
92
  initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
93
-
94
  def unnorm_coords(x, y):
95
  return int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
96
 
 
48
  def load_initial_images(width, height):
49
  initial_images = []
50
  for i in range(7):
51
+ initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
52
+ #image_path = f"image_{i}.png"
53
+ #if os.path.exists(image_path):
54
+ # img = Image.open(image_path).resize((width, height))
55
+ # initial_images.append(np.array(img))
56
+ #else:
57
+ # print(f"Warning: {image_path} not found. Using blank image instead.")
58
+ # initial_images.append(np.zeros((height, width, 3), dtype=np.uint8))
59
  return initial_images
60
 
61
  def normalize_images(images, target_range=(-1, 1)):
 
91
  # Prepare the prompt based on the previous actions
92
  action_descriptions = []
93
  initial_actions = ['901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '901:604', '921:604']
94
+ initial_actions = ['700:897'] * 7
95
  def unnorm_coords(x, y):
96
  return int(x - (1920 - 256) / 2), int(y - (1080 - 256) / 2)
97