linoyts HF Staff commited on
Commit
600d72c
·
verified ·
1 Parent(s): 8b5bd21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -2
app.py CHANGED
@@ -53,9 +53,52 @@ pipeline = wan.WanTI2V(
53
  )
54
  print("Pipeline initialized and ready.")
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  # --- 2. Gradio Inference Function ---
58
- @spaces.GPU(duration=80)
59
  def generate_video(
60
  image,
61
  prompt,
@@ -71,7 +114,12 @@ def generate_video(
71
  if seed == -1:
72
  seed = random.randint(0, sys.maxsize)
73
 
74
- input_image = Image.fromarray(image).convert("RGB") if image is not None else None
 
 
 
 
 
75
 
76
  # Calculate number of frames based on duration
77
  num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
@@ -134,6 +182,18 @@ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
134
 
135
  run_button = gr.Button("Generate Video", variant="primary")
136
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
139
  gr.Examples(
 
53
  )
54
  print("Pipeline initialized and ready.")
55
 
56
+ # --- Helper Functions ---
57
+ def select_best_size_for_image(image, available_sizes):
58
+ """Select the size option with aspect ratio closest to the input image."""
59
+ if image is None:
60
+ return available_sizes[0] # Return first option if no image
61
+
62
+ img_width, img_height = image.size
63
+ img_aspect_ratio = img_height / img_width
64
+
65
+ best_size = available_sizes[0]
66
+ best_diff = float('inf')
67
+
68
+ for size_str in available_sizes:
69
+ # Parse size string like "704*1280"
70
+ height, width = map(int, size_str.split('*'))
71
+ size_aspect_ratio = height / width
72
+ diff = abs(img_aspect_ratio - size_aspect_ratio)
73
+
74
+ if diff < best_diff:
75
+ best_diff = diff
76
+ best_size = size_str
77
+
78
+ return best_size
79
+
80
+ def handle_image_upload(image):
81
+ """Handle image upload and return the best matching size."""
82
+ if image is None:
83
+ return gr.update()
84
+
85
+ pil_image = Image.fromarray(image).convert("RGB")
86
+ available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
87
+ best_size = select_best_size_for_image(pil_image, available_sizes)
88
+
89
+ return gr.update(value=best_size)
90
+
91
+ def get_duration(image, prompt, size, duration_seconds, sampling_steps, guide_scale, shift, seed):
92
+ """Calculate dynamic GPU duration based on parameters."""
93
+ if sampling_steps > 35 and duration_seconds > 2:
94
+ return 90
95
+ elif sampling_steps > 35 or duration_seconds > 2:
96
+ return 80
97
+ else:
98
+ return 60
99
 
100
  # --- 2. Gradio Inference Function ---
101
+ @spaces.GPU(duration=get_duration)
102
  def generate_video(
103
  image,
104
  prompt,
 
114
  if seed == -1:
115
  seed = random.randint(0, sys.maxsize)
116
 
117
+ input_image = None
118
+ if image is not None:
119
+ input_image = Image.fromarray(image).convert("RGB")
120
+ # Resize image to match selected size
121
+ target_height, target_width = map(int, size.split('*'))
122
+ input_image = input_image.resize((target_width, target_height))
123
 
124
  # Calculate number of frames based on duration
125
  num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
 
182
 
183
  run_button = gr.Button("Generate Video", variant="primary")
184
 
185
+ # Add image upload handler
186
+ image_input.upload(
187
+ fn=handle_image_upload,
188
+ inputs=[image_input],
189
+ outputs=[size_input]
190
+ )
191
+
192
+ image_input.clear(
193
+ fn=handle_image_upload,
194
+ inputs=[image_input],
195
+ outputs=[size_input]
196
+ )
197
 
198
  example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
199
  gr.Examples(