Commit 3fa232c
Parent(s): 2f939c6

let's ditch streaming for now

Files changed:
- app.py                 +90 -128
- app_with_streaming.py  +526 -0
app.py  CHANGED

@@ -3,6 +3,7 @@ import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

 import os
+import base64
 from huggingface_hub import snapshot_download, hf_hub_download

 # Configuration for data paths
@@ -118,55 +119,71 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
     APP_STATE["torch_compile_applied"] = True
     print("✅ torch.compile applied to transformer")

-def frames_to_ts_file(frames, filepath, fps = 15):
+def frames_to_mp4_base64(frames, fps = 15):
     """
-    Convert frames directly to .ts file using PyAV.
+    Convert frames directly to base64 data URI using PyAV.

     Args:
         frames: List of numpy arrays (HWC, RGB, uint8)
-        filepath: Output file path
         fps: Frames per second

     Returns:
-        The filepath of the created file
+        Base64 data URI string for the MP4 video
     """
     if not frames:
-        return filepath
+        return "data:video/mp4;base64,"

     height, width = frames[0].shape[:2]

-    # Create container for MPEG-TS format
-    container = av.open(filepath, mode='w', format='mpegts')
-
-    # Add video stream with optimized settings for streaming
-    stream = container.add_stream('h264', rate=fps)
-    stream.width = width
-    stream.height = height
-    stream.pix_fmt = 'yuv420p'
-
-    # Optimize for low latency streaming
-    stream.options = {
-        'preset': 'ultrafast',
-        'tune': 'zerolatency',
-        'crf': '23',
-        'profile': 'baseline',
-        'level': '3.0'
-    }
+    # Create temporary file for MP4 encoding
+    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+    temp_filepath = temp_file.name
+    temp_file.close()

     try:
-        for frame_np in frames:
-            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
-            frame = frame.reformat(format=stream.pix_fmt)
-            for packet in stream.encode(frame):
-                container.mux(packet)
-
-        for packet in stream.encode():
-            container.mux(packet)
-
+        # Create container for MP4 format
+        container = av.open(temp_filepath, mode='w', format='mp4')
+
+        # Add video stream with fast settings
+        stream = container.add_stream('h264', rate=fps)
+        stream.width = width
+        stream.height = height
+        stream.pix_fmt = 'yuv420p'
+
+        # Optimize for low latency streaming
+        stream.options = {
+            'preset': 'ultrafast',
+            'tune': 'zerolatency',
+            'crf': '23',
+            'profile': 'baseline',
+            'level': '3.0'
+        }
+
+        try:
+            for frame_np in frames:
+                frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+                frame = frame.reformat(format=stream.pix_fmt)
+                for packet in stream.encode(frame):
+                    container.mux(packet)
+
+            for packet in stream.encode():
+                container.mux(packet)
+
+        finally:
+            container.close()
+
+        # Read the MP4 file and encode to base64
+        with open(temp_filepath, 'rb') as f:
+            video_data = f.read()
+        base64_data = base64.b64encode(video_data).decode('utf-8')
+        return f"data:video/mp4;base64,{base64_data}"

     finally:
-        container.close()
-
-    return filepath
+        # Clean up temporary file
+        if os.path.exists(temp_filepath):
+            os.unlink(temp_filepath)

+    return "data:video/mp4;base64,"

 def initialize_vae_decoder(use_taehv=False, use_trt=False):
     if use_trt:
@@ -233,9 +250,9 @@ pipeline = CausalInferencePipeline(
 pipeline.to(dtype=torch.float16).to(gpu)

 @torch.no_grad()
-def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
+def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
     """
-    Generator function that yields .ts video chunks using PyAV for streaming.
+    Generate video and return a single MP4 file.
     """
     # Add fallback values for None parameters
     if seed is None:
@@ -253,7 +270,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
         seed = random.randint(0, 2**32 - 1)


-    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
+    print(f"🎬 video_generation_handler called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")

     # Setup
     conditional_dict = text_encoder(text_prompts=[prompt])
@@ -279,7 +296,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
     current_start_frame = 0
     all_num_frames = [pipeline.num_frame_per_block] * num_blocks

-    total_frames_yielded = 0
+    all_frames = []
+    total_frames_generated = 0

     # Ensure temp directory exists
     os.makedirs("gradio_tmp", exist_ok=True)
@@ -335,8 +353,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI

         print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")

-        # Process all frames from this block at once
-        all_frames_from_block = []
+        # Process all frames from this block and add to main collection
         for frame_idx in range(pixels.shape[1]):
             frame_tensor = pixels[0, frame_idx]

@@ -345,79 +362,35 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
             frame_np = frame_np.to(torch.uint8).cpu().numpy()
             frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC

-            all_frames_from_block.append(frame_np)
-            total_frames_yielded += 1
-
-            # Yield status update for each frame (cute tracking!)
-            blocks_completed = idx
-            current_block_progress = (frame_idx + 1) / pixels.shape[1]
-            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
-
-            # Cap at 100% to avoid going over
-            total_progress = min(total_progress, 100.0)
-
-            frame_status_html = (
-                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
-                f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
-                f"  <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
-                f"    <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
-                f"  </div>"
-                f"  <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
-                f"    Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
-                f"  </p>"
-                f"</div>"
-            )
-
-            # Yield None for video but update status (frame-by-frame tracking)
-            yield None, frame_status_html
-
-        # Encode entire block as one chunk
-        if all_frames_from_block:
-            print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
-
-            try:
-                chunk_uuid = str(uuid.uuid4())[:8]
-                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
-                ts_path = os.path.join("gradio_tmp", ts_filename)
-
-                frames_to_ts_file(all_frames_from_block, ts_path, fps)
-
-                # Calculate final progress for this block
-                total_progress = (idx + 1) / num_blocks * 100
-
-                # Yield the actual video chunk
-                yield ts_path, gr.update()
-
-            except Exception as e:
-                print(f"⚠️ Error encoding block {idx}: {e}")
-                import traceback
-                traceback.print_exc()
-
+            all_frames.append(frame_np)
+            total_frames_generated += 1
+
+            print(f"📦 Block {idx+1}/{num_blocks}, Frame {frame_idx+1}/{pixels.shape[1]} - Total frames: {total_frames_generated}")
+
         current_start_frame += current_num_frames

-    # Final completion status
-    final_status_html = (
-        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
-        f"  <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
-        f"    <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
-        f"    <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
-        f"  </div>"
-        f"  <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
-        f"    <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
-        f"      📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
-        f"    </p>"
-        f"    <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
-        f"      🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
-        f"    </p>"
-        f"  </div>"
-        f"</div>"
-    )
-    yield None, final_status_html
-    print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+    # Generate final MP4 as base64 data URI
+    if all_frames:
+        print(f"📹 Encoding final MP4 with {len(all_frames)} frames")
+
+        try:
+            base64_data_uri = frames_to_mp4_base64(all_frames, fps)
+
+            print(f"✅ Video generation complete! {total_frames_generated} frames encoded to base64 data URI")
+            return base64_data_uri
+
+        except Exception as e:
+            print(f"⚠️ Error encoding final video: {e}")
+            import traceback
+            traceback.print_exc()
+            return "data:video/mp4;base64,"
+    else:
+        print("⚠️ No frames generated")
+        return "data:video/mp4;base64,"

 # --- Gradio UI Layout ---
-with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
-    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
+    gr.Markdown("Video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009))")

     with gr.Row():
         with gr.Column(scale=2):
@@ -428,7 +401,7 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
                     lines=4,
                     value=""
                 )
-                start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
+                start_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")

             gr.Markdown("### ⚙️ Settings")
             with gr.Row():
@@ -477,31 +450,20 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
                 )

         with gr.Column(scale=3):
-            gr.Markdown("### 📺 Video Stream")
-            streaming_video = gr.Video(
-                label="Live Stream",
-                streaming=True,
-                loop=True,
-                height=400,
-                autoplay=True,
-                show_label=False
-            )
-
-            status_display = gr.HTML(
-                value=(
-                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
-                    "🎬 Ready to start streaming...<br>"
-                    "<small>Configure your prompt and click 'Start Streaming'</small>"
-                    "</div>"
-                ),
-                label="Generation Status"
+            gr.Markdown("### 🎬 Generated Video (Base64)")
+            video_output = gr.Textbox(
+                label="Base64 Video Data URI",
+                lines=10,
+                max_lines=20,
+                show_copy_button=True,
+                placeholder="Generated video will appear here as base64 data URI..."
             )

-    # Connect the generator to the streaming video
+    # Connect the generator to the text output
     start_btn.click(
-        fn=video_generation_handler_streaming,
+        fn=video_generation_handler,
         inputs=[prompt, seed, fps, width, height, duration],
-        outputs=[streaming_video, status_display]
+        outputs=[video_output]
     )

 # --- Launch App ---
@@ -511,9 +473,9 @@ if __name__ == "__main__":
         shutil.rmtree("gradio_tmp")
     os.makedirs("gradio_tmp", exist_ok=True)

-    print("🚀 Clapper Rendering Node (default engine is Wan2.1 1.3B Self-Forcing)")
+    print("🚀 Video Generation Node (default engine is Wan2.1 1.3B Self-Forcing)")
     print(f"📁 Temporary files will be stored in: gradio_tmp/")
-    print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
+    print(f"🎯 Video encoding: PyAV (MP4/H.264)")
     print(f"⚡ GPU acceleration: {gpu}")

     demo.queue().launch(
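Note (not part of the commit): with this change app.py hands back the whole clip as one "data:video/mp4;base64,..." string in a textbox instead of streaming MPEG-TS chunks. A minimal sketch of how a caller might turn that data URI back into a playable file; the helper name save_data_uri_to_mp4 is hypothetical and not part of the repo:

# Sketch only: decode the data URI returned by video_generation_handler.
# Assumes the "data:video/mp4;base64," prefix that frames_to_mp4_base64 emits.
import base64

def save_data_uri_to_mp4(data_uri, out_path="generated.mp4"):
    prefix = "data:video/mp4;base64,"
    if not data_uri.startswith(prefix):
        raise ValueError("expected an MP4 base64 data URI")
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(data_uri[len(prefix):]))
    return out_path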
app_with_streaming.py  ADDED  (new file, 526 lines)
import subprocess
# not sure why it works in the original space but says "pip not found" in mine
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

import os
from huggingface_hub import snapshot_download, hf_hub_download

# Configuration for data paths
DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')

snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
    local_dir_use_symlinks=False,
    resume_download=True,
    repo_type="model"
)

hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=OTHER_MODELS_PATH,
    local_dir_use_symlinks=False
)
import re
import random
import argparse
import hashlib
import urllib.request
import time
from PIL import Image
import torch
import gradio as gr
from omegaconf import OmegaConf
from tqdm import tqdm
import imageio
import av
import uuid
import tempfile
import shutil
from pathlib import Path
from typing import Dict, Any, List, Optional, Tuple, Union

from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder

from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
import numpy as np

device = "cuda" if torch.cuda.is_available() else "cpu"

DEFAULT_WIDTH = 832
DEFAULT_HEIGHT = 480

# --- Argument Parsing ---
parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
args = parser.parse_args()

gpu = "cuda"

try:
    config = OmegaConf.load(args.config_path)
    default_config = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(default_config, config)
except FileNotFoundError as e:
    print(f"Error loading config file: {e}\n. Please ensure config files are in the correct path.")
    exit(1)

# Initialize Models
print("Initializing models...")
text_encoder = WanTextEncoder()
transformer = WanDiffusionWrapper(is_causal=True)

try:
    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
except FileNotFoundError as e:
    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
    exit(1)

text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
transformer.eval().to(dtype=torch.float16).requires_grad_(False)

text_encoder.to(gpu)
transformer.to(gpu)

APP_STATE = {
    "torch_compile_applied": False,
    "fp8_applied": False,
    "current_use_taehv": False,
    "current_vae_decoder": None,
}

# I've tried to enable it, but I didn't notice a significant performance improvement..
ENABLE_TORCH_COMPILATION = False

# "default": The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
# "reduce-overhead": Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
# "max-autotune": Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
# "max-autotune-no-cudagraphs": Similar to "max-autotune", but without CUDA graphs.
TORCH_COMPILATION_MODE = "default"

# Apply torch.compile for maximum performance
if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
    print("🚀 Applying torch.compile for speed optimization...")
    transformer.compile(mode=TORCH_COMPILATION_MODE)
    APP_STATE["torch_compile_applied"] = True
    print("✅ torch.compile applied to transformer")

def frames_to_ts_file(frames, filepath, fps = 15):
    """
    Convert frames directly to .ts file using PyAV.

    Args:
        frames: List of numpy arrays (HWC, RGB, uint8)
        filepath: Output file path
        fps: Frames per second

    Returns:
        The filepath of the created file
    """
    if not frames:
        return filepath

    height, width = frames[0].shape[:2]

    # Create container for MPEG-TS format
    container = av.open(filepath, mode='w', format='mpegts')

    # Add video stream with optimized settings for streaming
    stream = container.add_stream('h264', rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = 'yuv420p'

    # Optimize for low latency streaming
    stream.options = {
        'preset': 'ultrafast',
        'tune': 'zerolatency',
        'crf': '23',
        'profile': 'baseline',
        'level': '3.0'
    }

    try:
        for frame_np in frames:
            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
            frame = frame.reformat(format=stream.pix_fmt)
            for packet in stream.encode(frame):
                container.mux(packet)

        for packet in stream.encode():
            container.mux(packet)

    finally:
        container.close()

    return filepath

def initialize_vae_decoder(use_taehv=False, use_trt=False):
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        print("Initializing TensorRT VAE Decoder...")
        vae_decoder = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
    elif use_taehv:
        print("Initializing TAEHV VAE Decoder...")
        from demo_utils.taehv import TAEHV
        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
        if not os.path.exists(taehv_checkpoint_path):
            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
            os.makedirs("checkpoints", exist_ok=True)
            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
            try:
                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
            except Exception as e:
                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")

        class DotDict(dict): __getattr__ = dict.get

        class TAEHVDiffusersWrapper(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.dtype = torch.float16
                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
                self.config = DotDict(scaling_factor=1.0)
            def decode(self, latents, return_dict=None):
                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)

        vae_decoder = TAEHVDiffusersWrapper()
        APP_STATE["current_use_taehv"] = True
    else:
        print("Initializing Default VAE Decoder...")
        vae_decoder = VAEDecoderWrapper()
        try:
            vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
            vae_decoder.load_state_dict(decoder_state_dict)
        except FileNotFoundError:
            print("Warning: Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False

    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)

    # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
    if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
        print("🚀 Applying torch.compile to VAE decoder...")
        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
        print("✅ torch.compile applied to VAE decoder")

    APP_STATE["current_vae_decoder"] = vae_decoder
    print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")

# Initialize with default VAE
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)

pipeline = CausalInferencePipeline(
    config, device=gpu, generator=transformer, text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)

pipeline.to(dtype=torch.float16).to(gpu)

@torch.no_grad()
def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
    """
    Generator function that yields .ts video chunks using PyAV for streaming.
    """
    # Add fallback values for None parameters
    if seed is None:
        seed = -1
    if fps is None:
        fps = 15
    if width is None:
        width = DEFAULT_WIDTH
    if height is None:
        height = DEFAULT_HEIGHT
    if duration is None:
        duration = 5

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)


    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")

    # Setup
    conditional_dict = text_encoder(text_prompts=[prompt])
    for key, value in conditional_dict.items():
        conditional_dict[key] = value.to(dtype=torch.float16)

    rnd = torch.Generator(gpu).manual_seed(int(seed))
    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
    noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)

    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]

    # Calculate number of blocks based on duration
    # Current setup generates approximately 5 seconds with 7 blocks
    # So we scale proportionally
    base_duration = 5.0  # seconds
    base_blocks = 8
    num_blocks = max(1, int(base_blocks * duration / base_duration))

    current_start_frame = 0
    all_num_frames = [pipeline.num_frame_per_block] * num_blocks

    total_frames_yielded = 0

    # Ensure temp directory exists
    os.makedirs("gradio_tmp", exist_ok=True)

    # Generation loop
    for idx, current_num_frames in enumerate(all_num_frames):
        print(f"📦 Processing block {idx+1}/{num_blocks}")

        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]

        # Denoising steps
        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
            _, denoised_pred = pipeline.generator(
                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
                timestep=timestep, kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length
            )
            if step_idx < len(pipeline.denoising_step_list) - 1:
                next_timestep = pipeline.denoising_step_list[step_idx + 1]
                noisy_input = pipeline.scheduler.add_noise(
                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
                ).unflatten(0, denoised_pred.shape[:2])

        if idx < len(all_num_frames) - 1:
            pipeline.generator(
                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=current_start_frame * pipeline.frame_seq_length,
            )

        # Decode to pixels
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = denoised_pred
            else:
                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
                latents_cache = denoised_pred[:, -3:]
            pixels = pipeline.vae.decode(denoised_pred)
        else:
            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)

        # Handle frame skipping
        if idx == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and idx > 0:
            pixels = pixels[:, 12:]

        print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")

        # Process all frames from this block at once
        all_frames_from_block = []
        for frame_idx in range(pixels.shape[1]):
            frame_tensor = pixels[0, frame_idx]

            # Convert to numpy (HWC, RGB, uint8)
            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
            frame_np = frame_np.to(torch.uint8).cpu().numpy()
            frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC

            all_frames_from_block.append(frame_np)
            total_frames_yielded += 1

            # Yield status update for each frame (cute tracking!)
            blocks_completed = idx
            current_block_progress = (frame_idx + 1) / pixels.shape[1]
            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100

            # Cap at 100% to avoid going over
            total_progress = min(total_progress, 100.0)

            frame_status_html = (
                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
                f"  <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
                f"  <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
                f"    <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
                f"  </div>"
                f"  <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
                f"    Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
                f"  </p>"
                f"</div>"
            )

            # Yield None for video but update status (frame-by-frame tracking)
            yield None, frame_status_html

        # Encode entire block as one chunk
        if all_frames_from_block:
            print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")

            try:
                chunk_uuid = str(uuid.uuid4())[:8]
                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
                ts_path = os.path.join("gradio_tmp", ts_filename)

                frames_to_ts_file(all_frames_from_block, ts_path, fps)

                # Calculate final progress for this block
                total_progress = (idx + 1) / num_blocks * 100

                # Yield the actual video chunk
                yield ts_path, gr.update()

            except Exception as e:
                print(f"⚠️ Error encoding block {idx}: {e}")
                import traceback
                traceback.print_exc()

        current_start_frame += current_num_frames

    # Final completion status
    final_status_html = (
        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
        f"  <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
        f"    <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
        f"    <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
        f"  </div>"
        f"  <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
        f"    <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
        f"      📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
        f"    </p>"
        f"    <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
        f"      🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
        f"    </p>"
        f"  </div>"
        f"</div>"
    )
    yield None, final_status_html
    print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")

# --- Gradio UI Layout ---
with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")

    with gr.Row():
        with gr.Column(scale=2):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="A stylish woman walks down a Tokyo street...",
                    lines=4,
                    value=""
                )
                start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")

            gr.Markdown("### ⚙️ Settings")
            with gr.Row():
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    info="Use -1 for random seed",
                    precision=0
                )
                fps = gr.Slider(
                    label="Playback FPS",
                    minimum=1,
                    maximum=30,
                    value=args.fps,
                    step=1,
                    visible=False,
                    info="Frames per second for playback"
                )

            with gr.Row():
                duration = gr.Slider(
                    label="Duration (seconds)",
                    minimum=1,
                    maximum=5,
                    value=3,
                    step=1,
                    info="Video duration in seconds"
                )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_WIDTH,
                    step=8,
                    info="Video width in pixels (8px steps)"
                )
                height = gr.Slider(
                    label="Height",
                    minimum=224,
                    maximum=832,
                    value=DEFAULT_HEIGHT,
                    step=8,
                    info="Video height in pixels (8px steps)"
                )

        with gr.Column(scale=3):
            gr.Markdown("### 📺 Video Stream")
            streaming_video = gr.Video(
                label="Live Stream",
                streaming=True,
                loop=True,
                height=400,
                autoplay=True,
                show_label=False
            )

            status_display = gr.HTML(
                value=(
                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
                    "🎬 Ready to start streaming...<br>"
                    "<small>Configure your prompt and click 'Start Streaming'</small>"
                    "</div>"
                ),
                label="Generation Status"
            )

    # Connect the generator to the streaming video
    start_btn.click(
        fn=video_generation_handler_streaming,
        inputs=[prompt, seed, fps, width, height, duration],
        outputs=[streaming_video, status_display]
    )

# --- Launch App ---
if __name__ == "__main__":
    if os.path.exists("gradio_tmp"):
        import shutil
        shutil.rmtree("gradio_tmp")
    os.makedirs("gradio_tmp", exist_ok=True)

    print("🚀 Clapper Rendering Node (default engine is Wan2.1 1.3B Self-Forcing)")
    print(f"📁 Temporary files will be stored in: gradio_tmp/")
    print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
    print(f"⚡ GPU acceleration: {gpu}")

    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        show_error=True,
        max_threads=40,
        mcp_server=True
    )
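Note (not part of the commit): app_with_streaming.py keeps the previous streaming behaviour as a standalone script, so it can still be launched on its own with the flags its argparse section defines. A rough sketch, assuming it is run from the repository root with the same dependencies installed:

# Sketch only: launch the preserved streaming variant with its own CLI flags.
import subprocess

subprocess.run([
    "python", "app_with_streaming.py",
    "--port", "7860",      # Gradio port (default 7860)
    "--host", "0.0.0.0",   # bind address
    "--fps", "15",         # playback FPS for frame streaming
    # optional: "--share" for a public link, "--trt" for the TensorRT VAE decoder
])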