jbilcke-hf (HF Staff) committed
Commit 3fa232c · 1 Parent(s): 2f939c6

let's ditch streaming for now
Files changed (2):
  1. app.py (+90 -128)
  2. app_with_streaming.py (+526 -0)
app.py CHANGED
@@ -3,6 +3,7 @@ import subprocess
 #subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
 
 import os
+import base64
 from huggingface_hub import snapshot_download, hf_hub_download
 
 # Configuration for data paths
@@ -118,55 +119,71 @@ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
     APP_STATE["torch_compile_applied"] = True
     print("✅ torch.compile applied to transformer")
 
-def frames_to_ts_file(frames, filepath, fps = 15):
+def frames_to_mp4_base64(frames, fps = 15):
     """
-    Convert frames directly to .ts file using PyAV.
+    Convert frames directly to base64 data URI using PyAV.
 
     Args:
         frames: List of numpy arrays (HWC, RGB, uint8)
-        filepath: Output file path
        fps: Frames per second
 
     Returns:
-        The filepath of the created file
+        Base64 data URI string for the MP4 video
     """
     if not frames:
-        return filepath
+        return "data:video/mp4;base64,"
 
     height, width = frames[0].shape[:2]
 
-    # Create container for MPEG-TS format
-    container = av.open(filepath, mode='w', format='mpegts')
-
-    # Add video stream with optimized settings for streaming
-    stream = container.add_stream('h264', rate=fps)
-    stream.width = width
-    stream.height = height
-    stream.pix_fmt = 'yuv420p'
-
-    # Optimize for low latency streaming
-    stream.options = {
-        'preset': 'ultrafast',
-        'tune': 'zerolatency',
-        'crf': '23',
-        'profile': 'baseline',
-        'level': '3.0'
-    }
+    # Create temporary file for MP4 encoding
+    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+    temp_filepath = temp_file.name
+    temp_file.close()
 
     try:
-        for frame_np in frames:
-            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
-            frame = frame.reformat(format=stream.pix_fmt)
-            for packet in stream.encode(frame):
+        # Create container for MP4 format
+        container = av.open(temp_filepath, mode='w', format='mp4')
+
+        # Add video stream with fast settings
+        stream = container.add_stream('h264', rate=fps)
+        stream.width = width
+        stream.height = height
+        stream.pix_fmt = 'yuv420p'
+
+        # Optimize for low latency streaming
+        stream.options = {
+            'preset': 'ultrafast',
+            'tune': 'zerolatency',
+            'crf': '23',
+            'profile': 'baseline',
+            'level': '3.0'
+        }
+
+        try:
+            for frame_np in frames:
+                frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+                frame = frame.reformat(format=stream.pix_fmt)
+                for packet in stream.encode(frame):
+                    container.mux(packet)
+
+            for packet in stream.encode():
                 container.mux(packet)
+
+        finally:
+            container.close()
 
-        for packet in stream.encode():
-            container.mux(packet)
+        # Read the MP4 file and encode to base64
+        with open(temp_filepath, 'rb') as f:
+            video_data = f.read()
+        base64_data = base64.b64encode(video_data).decode('utf-8')
+        return f"data:video/mp4;base64,{base64_data}"
 
     finally:
-        container.close()
+        # Clean up temporary file
+        if os.path.exists(temp_filepath):
+            os.unlink(temp_filepath)
 
-    return filepath
+    return "data:video/mp4;base64,"
 
 def initialize_vae_decoder(use_taehv=False, use_trt=False):
     if use_trt:
@@ -233,9 +250,9 @@ pipeline = CausalInferencePipeline(
 pipeline.to(dtype=torch.float16).to(gpu)
 
 @torch.no_grad()
-def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
+def video_generation_handler(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
     """
-    Generator function that yields .ts video chunks using PyAV for streaming.
+    Generate video and return a single MP4 file.
     """
     # Add fallback values for None parameters
     if seed is None:
@@ -253,7 +270,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
         seed = random.randint(0, 2**32 - 1)
 
 
-    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
+    print(f"🎬 video_generation_handler called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
 
     # Setup
     conditional_dict = text_encoder(text_prompts=[prompt])
@@ -279,7 +296,8 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
     current_start_frame = 0
     all_num_frames = [pipeline.num_frame_per_block] * num_blocks
 
-    total_frames_yielded = 0
+    all_frames = []
+    total_frames_generated = 0
 
     # Ensure temp directory exists
     os.makedirs("gradio_tmp", exist_ok=True)
@@ -335,8 +353,7 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
 
         print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
 
-        # Process all frames from this block at once
-        all_frames_from_block = []
+        # Process all frames from this block and add to main collection
         for frame_idx in range(pixels.shape[1]):
             frame_tensor = pixels[0, frame_idx]
 
@@ -345,79 +362,35 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WI
             frame_np = frame_np.to(torch.uint8).cpu().numpy()
             frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
 
-            all_frames_from_block.append(frame_np)
-            total_frames_yielded += 1
+            all_frames.append(frame_np)
+            total_frames_generated += 1
 
-            # Yield status update for each frame (cute tracking!)
-            blocks_completed = idx
-            current_block_progress = (frame_idx + 1) / pixels.shape[1]
-            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
-
-            # Cap at 100% to avoid going over
-            total_progress = min(total_progress, 100.0)
-
-            frame_status_html = (
-                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
-                f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
-                f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
-                f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
-                f" </div>"
-                f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
-                f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
-                f" </p>"
-                f"</div>"
-            )
-
-            # Yield None for video but update status (frame-by-frame tracking)
-            yield None, frame_status_html
-
-        # Encode entire block as one chunk
-        if all_frames_from_block:
-            print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
-
-            try:
-                chunk_uuid = str(uuid.uuid4())[:8]
-                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
-                ts_path = os.path.join("gradio_tmp", ts_filename)
-
-                frames_to_ts_file(all_frames_from_block, ts_path, fps)
-
-                # Calculate final progress for this block
-                total_progress = (idx + 1) / num_blocks * 100
-
-                # Yield the actual video chunk
-                yield ts_path, gr.update()
-
-            except Exception as e:
-                print(f"⚠️ Error encoding block {idx}: {e}")
-                import traceback
-                traceback.print_exc()
-
+            print(f"📦 Block {idx+1}/{num_blocks}, Frame {frame_idx+1}/{pixels.shape[1]} - Total frames: {total_frames_generated}")
+
         current_start_frame += current_num_frames
 
-    # Final completion status
-    final_status_html = (
-        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
-        f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
-        f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
-        f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
-        f" </div>"
-        f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
-        f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
-        f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
-        f" </p>"
-        f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
-        f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
-        f" </p>"
-        f" </div>"
-        f"</div>"
-    )
-    yield None, final_status_html
-    print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+    # Generate final MP4 as base64 data URI
+    if all_frames:
+        print(f"📹 Encoding final MP4 with {len(all_frames)} frames")
+
+        try:
+            base64_data_uri = frames_to_mp4_base64(all_frames, fps)
+
+            print(f"✅ Video generation complete! {total_frames_generated} frames encoded to base64 data URI")
+            return base64_data_uri
+
+        except Exception as e:
+            print(f"⚠️ Error encoding final video: {e}")
+            import traceback
+            traceback.print_exc()
+            return "data:video/mp4;base64,"
+    else:
+        print("⚠️ No frames generated")
+        return "data:video/mp4;base64,"
 
 # --- Gradio UI Layout ---
-with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
-    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+with gr.Blocks(title="Wan2.1 1.3B Self-Forcing demo") as demo:
+    gr.Markdown("Video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
 
     with gr.Row():
         with gr.Column(scale=2):
@@ -428,7 +401,7 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
                     lines=4,
                     value=""
                 )
-                start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
+                start_btn = gr.Button("🎬 Generate Video", variant="primary", size="lg")
 
             gr.Markdown("### ⚙️ Settings")
             with gr.Row():
@@ -477,31 +450,20 @@ with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
                 )
 
         with gr.Column(scale=3):
-            gr.Markdown("### 📺 Video Stream")
-            streaming_video = gr.Video(
-                label="Live Stream",
-                streaming=True,
-                loop=True,
-                height=400,
-                autoplay=True,
-                show_label=False
-            )
-
-            status_display = gr.HTML(
-                value=(
-                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
-                    "🎬 Ready to start streaming...<br>"
-                    "<small>Configure your prompt and click 'Start Streaming'</small>"
-                    "</div>"
-                ),
-                label="Generation Status"
+            gr.Markdown("### 🎬 Generated Video (Base64)")
+            video_output = gr.Textbox(
+                label="Base64 Video Data URI",
+                lines=10,
+                max_lines=20,
+                show_copy_button=True,
+                placeholder="Generated video will appear here as base64 data URI..."
             )
 
-    # Connect the generator to the streaming video
+    # Connect the generator to the text output
    start_btn.click(
-        fn=video_generation_handler_streaming,
+        fn=video_generation_handler,
        inputs=[prompt, seed, fps, width, height, duration],
-        outputs=[streaming_video, status_display]
+        outputs=[video_output]
    )
 
 # --- Launch App ---
@@ -511,9 +473,9 @@ if __name__ == "__main__":
         shutil.rmtree("gradio_tmp")
     os.makedirs("gradio_tmp", exist_ok=True)
 
-    print("🚀 Clapper Rendering Node (default engine is Wan2.1 1.3B Self-Forcing)")
+    print("🚀 Video Generation Node (default engine is Wan2.1 1.3B Self-Forcing)")
     print(f"📁 Temporary files will be stored in: gradio_tmp/")
-    print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
+    print(f"🎯 Video encoding: PyAV (MP4/H.264)")
     print(f"⚡ GPU acceleration: {gpu}")
 
     demo.queue().launch(
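
Note on the new output format: start_btn.click now returns the finished clip as one base64 data URI string instead of streaming .ts chunks. Below is a minimal sketch (not part of the commit; the helper name and paths are illustrative) of how a caller could turn that string back into a playable .mp4 file:

import base64

def data_uri_to_mp4(data_uri, out_path="output.mp4"):
    # Strip the "data:video/mp4;base64," prefix and decode the payload
    _, _, payload = data_uri.partition(",")
    if not payload:
        raise ValueError("Empty or malformed data URI")
    with open(out_path, "wb") as f:
        f.write(base64.b64decode(payload))
    return out_path

# Example usage with the handler above (prompt text is made up):
#   data_uri = video_generation_handler("A cat walking on grass", seed=42)
#   data_uri_to_mp4(data_uri, "cat.mp4")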
app_with_streaming.py ADDED
@@ -0,0 +1,526 @@
+import subprocess
+# not sure why it works in the original space but says "pip not found" in mine
+#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+import os
+from huggingface_hub import snapshot_download, hf_hub_download
+
+# Configuration for data paths
+DATA_ROOT = os.path.normpath(os.getenv('DATA_ROOT', '.'))
+WAN_MODELS_PATH = os.path.join(DATA_ROOT, 'wan_models')
+OTHER_MODELS_PATH = os.path.join(DATA_ROOT, 'other_models')
+
+snapshot_download(
+    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
+    local_dir=os.path.join(WAN_MODELS_PATH, "Wan2.1-T2V-1.3B"),
+    local_dir_use_symlinks=False,
+    resume_download=True,
+    repo_type="model"
+)
+
+hf_hub_download(
+    repo_id="gdhe17/Self-Forcing",
+    filename="checkpoints/self_forcing_dmd.pt",
+    local_dir=OTHER_MODELS_PATH,
+    local_dir_use_symlinks=False
+)
+import re
+import random
+import argparse
+import hashlib
+import urllib.request
+import time
+from PIL import Image
+import torch
+import gradio as gr
+from omegaconf import OmegaConf
+from tqdm import tqdm
+import imageio
+import av
+import uuid
+import tempfile
+import shutil
+from pathlib import Path
+from typing import Dict, Any, List, Optional, Tuple, Union
+
+from pipeline import CausalInferencePipeline
+from demo_utils.constant import ZERO_VAE_CACHE
+from demo_utils.vae_block3 import VAEDecoderWrapper
+from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
+
+from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM #, BitsAndBytesConfig
+import numpy as np
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+DEFAULT_WIDTH = 832
+DEFAULT_HEIGHT = 480
+
+# --- Argument Parsing ---
+parser = argparse.ArgumentParser(description="Gradio Demo for Self-Forcing with Frame Streaming")
+parser.add_argument('--port', type=int, default=7860, help="Port to run the Gradio app on.")
+parser.add_argument('--host', type=str, default='0.0.0.0', help="Host to bind the Gradio app to.")
+parser.add_argument("--checkpoint_path", type=str, default=os.path.join(OTHER_MODELS_PATH, 'checkpoints', 'self_forcing_dmd.pt'), help="Path to the model checkpoint.")
+parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml', help="Path to the model config.")
+parser.add_argument('--share', action='store_true', help="Create a public Gradio link.")
+parser.add_argument('--trt', action='store_true', help="Use TensorRT optimized VAE decoder.")
+parser.add_argument('--fps', type=float, default=15.0, help="Playback FPS for frame streaming.")
+args = parser.parse_args()
+
+gpu = "cuda"
+
+try:
+    config = OmegaConf.load(args.config_path)
+    default_config = OmegaConf.load("configs/default_config.yaml")
+    config = OmegaConf.merge(default_config, config)
+except FileNotFoundError as e:
+    print(f"Error loading config file: {e}\n. Please ensure config files are in the correct path.")
+    exit(1)
+
+# Initialize Models
+print("Initializing models...")
+text_encoder = WanTextEncoder()
+transformer = WanDiffusionWrapper(is_causal=True)
+
+try:
+    state_dict = torch.load(args.checkpoint_path, map_location="cpu")
+    transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
+except FileNotFoundError as e:
+    print(f"Error loading checkpoint: {e}\nPlease ensure the checkpoint '{args.checkpoint_path}' exists.")
+    exit(1)
+
+text_encoder.eval().to(dtype=torch.float16).requires_grad_(False)
+transformer.eval().to(dtype=torch.float16).requires_grad_(False)
+
+text_encoder.to(gpu)
+transformer.to(gpu)
+
+APP_STATE = {
+    "torch_compile_applied": False,
+    "fp8_applied": False,
+    "current_use_taehv": False,
+    "current_vae_decoder": None,
+}
+
+# I've tried to enable it, but I didn't notice a significant performance improvement..
+ENABLE_TORCH_COMPILATION = False
+
+# “default”: The default mode, used when no mode parameter is specified. It provides a good balance between performance and overhead.
+# “reduce-overhead”: Minimizes Python-related overhead using CUDA graphs. However, it may increase memory usage.
+# “max-autotune”: Uses Triton or template-based matrix multiplications on supported devices. It takes longer to compile but optimizes for the fastest possible execution. On GPUs it enables CUDA graphs by default.
+# “max-autotune-no-cudagraphs”: Similar to “max-autotune”, but without CUDA graphs.
+TORCH_COMPILATION_MODE = "default"
+
+# Apply torch.compile for maximum performance
+if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
+    print("🚀 Applying torch.compile for speed optimization...")
+    transformer.compile(mode=TORCH_COMPILATION_MODE)
+    APP_STATE["torch_compile_applied"] = True
+    print("✅ torch.compile applied to transformer")
+
+def frames_to_ts_file(frames, filepath, fps = 15):
+    """
+    Convert frames directly to .ts file using PyAV.
+
+    Args:
+        frames: List of numpy arrays (HWC, RGB, uint8)
+        filepath: Output file path
+        fps: Frames per second
+
+    Returns:
+        The filepath of the created file
+    """
+    if not frames:
+        return filepath
+
+    height, width = frames[0].shape[:2]
+
+    # Create container for MPEG-TS format
+    container = av.open(filepath, mode='w', format='mpegts')
+
+    # Add video stream with optimized settings for streaming
+    stream = container.add_stream('h264', rate=fps)
+    stream.width = width
+    stream.height = height
+    stream.pix_fmt = 'yuv420p'
+
+    # Optimize for low latency streaming
+    stream.options = {
+        'preset': 'ultrafast',
+        'tune': 'zerolatency',
+        'crf': '23',
+        'profile': 'baseline',
+        'level': '3.0'
+    }
+
+    try:
+        for frame_np in frames:
+            frame = av.VideoFrame.from_ndarray(frame_np, format='rgb24')
+            frame = frame.reformat(format=stream.pix_fmt)
+            for packet in stream.encode(frame):
+                container.mux(packet)
+
+        for packet in stream.encode():
+            container.mux(packet)
+
+    finally:
+        container.close()
+
+    return filepath
+
+def initialize_vae_decoder(use_taehv=False, use_trt=False):
+    if use_trt:
+        from demo_utils.vae import VAETRTWrapper
+        print("Initializing TensorRT VAE Decoder...")
+        vae_decoder = VAETRTWrapper()
+        APP_STATE["current_use_taehv"] = False
+    elif use_taehv:
+        print("Initializing TAEHV VAE Decoder...")
+        from demo_utils.taehv import TAEHV
+        taehv_checkpoint_path = "checkpoints/taew2_1.pth"
+        if not os.path.exists(taehv_checkpoint_path):
+            print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
+            os.makedirs("checkpoints", exist_ok=True)
+            download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
+            try:
+                urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
+            except Exception as e:
+                raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
+
+        class DotDict(dict): __getattr__ = dict.get
+
+        class TAEHVDiffusersWrapper(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dtype = torch.float16
+                self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
+                self.config = DotDict(scaling_factor=1.0)
+            def decode(self, latents, return_dict=None):
+                return self.taehv.decode_video(latents, parallel=not LOW_MEMORY).mul_(2).sub_(1)
+
+        vae_decoder = TAEHVDiffusersWrapper()
+        APP_STATE["current_use_taehv"] = True
+    else:
+        print("Initializing Default VAE Decoder...")
+        vae_decoder = VAEDecoderWrapper()
+        try:
+            vae_state_dict = torch.load(os.path.join(WAN_MODELS_PATH, 'Wan2.1-T2V-1.3B', 'Wan2.1_VAE.pth'), map_location="cpu")
+            decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
+            vae_decoder.load_state_dict(decoder_state_dict)
+        except FileNotFoundError:
+            print("Warning: Default VAE weights not found.")
+        APP_STATE["current_use_taehv"] = False
+
+    vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(gpu)
+
+    # Apply torch.compile to VAE decoder if enabled (following demo.py pattern)
+    if APP_STATE["torch_compile_applied"] and not use_taehv and not use_trt:
+        print("🚀 Applying torch.compile to VAE decoder...")
+        vae_decoder.compile(mode=TORCH_COMPILATION_MODE)
+        print("✅ torch.compile applied to VAE decoder")
+
+    APP_STATE["current_vae_decoder"] = vae_decoder
+    print(f"✅ VAE decoder initialized: {'TAEHV' if use_taehv else 'Default VAE'}")
+
+# Initialize with default VAE
+initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
+
+pipeline = CausalInferencePipeline(
+    config, device=gpu, generator=transformer, text_encoder=text_encoder,
+    vae=APP_STATE["current_vae_decoder"]
+)
+
+pipeline.to(dtype=torch.float16).to(gpu)
+
+@torch.no_grad()
+def video_generation_handler_streaming(prompt, seed=42, fps=15, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT, duration=5):
+    """
+    Generator function that yields .ts video chunks using PyAV for streaming.
+    """
+    # Add fallback values for None parameters
+    if seed is None:
+        seed = -1
+    if fps is None:
+        fps = 15
+    if width is None:
+        width = DEFAULT_WIDTH
+    if height is None:
+        height = DEFAULT_HEIGHT
+    if duration is None:
+        duration = 5
+
+    if seed == -1:
+        seed = random.randint(0, 2**32 - 1)
+
+
+    print(f"🎬 video_generation_handler_streaming called, seed: {seed}, duration: {duration}s, fps: {fps}, width: {width}, height: {height}")
+
+    # Setup
+    conditional_dict = text_encoder(text_prompts=[prompt])
+    for key, value in conditional_dict.items():
+        conditional_dict[key] = value.to(dtype=torch.float16)
+
+    rnd = torch.Generator(gpu).manual_seed(int(seed))
+    pipeline._initialize_kv_cache(1, torch.float16, device=gpu)
+    pipeline._initialize_crossattn_cache(1, torch.float16, device=gpu)
+    noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
+
+    vae_cache, latents_cache = None, None
+    if not APP_STATE["current_use_taehv"] and not args.trt:
+        vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
+
+    # Calculate number of blocks based on duration
+    # Current setup generates approximately 5 seconds with 7 blocks
+    # So we scale proportionally
+    base_duration = 5.0  # seconds
+    base_blocks = 8
+    num_blocks = max(1, int(base_blocks * duration / base_duration))
+
+    current_start_frame = 0
+    all_num_frames = [pipeline.num_frame_per_block] * num_blocks
+
+    total_frames_yielded = 0
+
+    # Ensure temp directory exists
+    os.makedirs("gradio_tmp", exist_ok=True)
+
+    # Generation loop
+    for idx, current_num_frames in enumerate(all_num_frames):
+        print(f"📦 Processing block {idx+1}/{num_blocks}")
+
+        noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
+
+        # Denoising steps
+        for step_idx, current_timestep in enumerate(pipeline.denoising_step_list):
+            timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
+            _, denoised_pred = pipeline.generator(
+                noisy_image_or_video=noisy_input, conditional_dict=conditional_dict,
+                timestep=timestep, kv_cache=pipeline.kv_cache1,
+                crossattn_cache=pipeline.crossattn_cache,
+                current_start=current_start_frame * pipeline.frame_seq_length
+            )
+            if step_idx < len(pipeline.denoising_step_list) - 1:
+                next_timestep = pipeline.denoising_step_list[step_idx + 1]
+                noisy_input = pipeline.scheduler.add_noise(
+                    denoised_pred.flatten(0, 1), torch.randn_like(denoised_pred.flatten(0, 1)),
+                    next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
+                ).unflatten(0, denoised_pred.shape[:2])
+
+        if idx < len(all_num_frames) - 1:
+            pipeline.generator(
+                noisy_image_or_video=denoised_pred, conditional_dict=conditional_dict,
+                timestep=torch.zeros_like(timestep), kv_cache=pipeline.kv_cache1,
+                crossattn_cache=pipeline.crossattn_cache,
+                current_start=current_start_frame * pipeline.frame_seq_length,
+            )
+
+        # Decode to pixels
+        if args.trt:
+            pixels, vae_cache = pipeline.vae.forward(denoised_pred.half(), *vae_cache)
+        elif APP_STATE["current_use_taehv"]:
+            if latents_cache is None:
+                latents_cache = denoised_pred
+            else:
+                denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
+                latents_cache = denoised_pred[:, -3:]
+            pixels = pipeline.vae.decode(denoised_pred)
+        else:
+            pixels, vae_cache = pipeline.vae(denoised_pred.half(), *vae_cache)
+
+        # Handle frame skipping
+        if idx == 0 and not args.trt:
+            pixels = pixels[:, 3:]
+        elif APP_STATE["current_use_taehv"] and idx > 0:
+            pixels = pixels[:, 12:]
+
+        print(f"🔍 DEBUG Block {idx}: Pixels shape after skipping: {pixels.shape}")
+
+        # Process all frames from this block at once
+        all_frames_from_block = []
+        for frame_idx in range(pixels.shape[1]):
+            frame_tensor = pixels[0, frame_idx]
+
+            # Convert to numpy (HWC, RGB, uint8)
+            frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
+            frame_np = frame_np.to(torch.uint8).cpu().numpy()
+            frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
+
+            all_frames_from_block.append(frame_np)
+            total_frames_yielded += 1
+
+            # Yield status update for each frame (cute tracking!)
+            blocks_completed = idx
+            current_block_progress = (frame_idx + 1) / pixels.shape[1]
+            total_progress = (blocks_completed + current_block_progress) / num_blocks * 100
+
+            # Cap at 100% to avoid going over
+            total_progress = min(total_progress, 100.0)
+
+            frame_status_html = (
+                f"<div style='padding: 10px; border: 1px solid #ddd; border-radius: 8px; font-family: sans-serif;'>"
+                f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>Generating Video...</p>"
+                f" <div style='background: #e9ecef; border-radius: 4px; width: 100%; overflow: hidden;'>"
+                f" <div style='width: {total_progress:.1f}%; height: 20px; background-color: #0d6efd; transition: width 0.2s;'></div>"
+                f" </div>"
+                f" <p style='margin: 8px 0 0 0; color: #555; font-size: 14px; text-align: right;'>"
+                f" Block {idx+1}/{num_blocks} | Frame {total_frames_yielded} | {total_progress:.1f}%"
+                f" </p>"
+                f"</div>"
+            )
+
+            # Yield None for video but update status (frame-by-frame tracking)
+            yield None, frame_status_html
+
+        # Encode entire block as one chunk
+        if all_frames_from_block:
+            print(f"📹 Encoding block {idx} with {len(all_frames_from_block)} frames")
+
+            try:
+                chunk_uuid = str(uuid.uuid4())[:8]
+                ts_filename = f"block_{idx:04d}_{chunk_uuid}.ts"
+                ts_path = os.path.join("gradio_tmp", ts_filename)
+
+                frames_to_ts_file(all_frames_from_block, ts_path, fps)
+
+                # Calculate final progress for this block
+                total_progress = (idx + 1) / num_blocks * 100
+
+                # Yield the actual video chunk
+                yield ts_path, gr.update()
+
+            except Exception as e:
+                print(f"⚠️ Error encoding block {idx}: {e}")
+                import traceback
+                traceback.print_exc()
+
+        current_start_frame += current_num_frames
+
+    # Final completion status
+    final_status_html = (
+        f"<div style='padding: 16px; border: 1px solid #198754; background: linear-gradient(135deg, #d1e7dd, #f8f9fa); border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);'>"
+        f" <div style='display: flex; align-items: center; margin-bottom: 8px;'>"
+        f" <span style='font-size: 24px; margin-right: 12px;'>🎉</span>"
+        f" <h4 style='margin: 0; color: #0f5132; font-size: 18px;'>Stream Complete!</h4>"
+        f" </div>"
+        f" <div style='background: rgba(255,255,255,0.7); padding: 8px; border-radius: 4px;'>"
+        f" <p style='margin: 0; color: #0f5132; font-weight: 500;'>"
+        f" 📊 Generated {total_frames_yielded} frames across {num_blocks} blocks"
+        f" </p>"
+        f" <p style='margin: 4px 0 0 0; color: #0f5132; font-size: 14px;'>"
+        f" 🎬 Playback: {fps} FPS • 📁 Format: MPEG-TS/H.264"
+        f" </p>"
+        f" </div>"
+        f"</div>"
+    )
+    yield None, final_status_html
+    print(f"✅ PyAV streaming complete! {total_frames_yielded} frames across {num_blocks} blocks")
+
+# --- Gradio UI Layout ---
+with gr.Blocks(title="Wan2.1 1.3B Self-Forcing streaming demo") as demo:
+    gr.Markdown("Real-time video generation with distilled Wan2-1 1.3B [[Model]](https://huggingface.co/gdhe17/Self-Forcing), [[Project page]](https://self-forcing.github.io), [[Paper]](https://huggingface.co/papers/2506.08009)")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            with gr.Group():
+                prompt = gr.Textbox(
+                    label="Prompt",
+                    placeholder="A stylish woman walks down a Tokyo street...",
+                    lines=4,
+                    value=""
+                )
+                start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
+
+            gr.Markdown("### ⚙️ Settings")
+            with gr.Row():
+                seed = gr.Number(
+                    label="Seed",
+                    value=-1,
+                    info="Use -1 for random seed",
+                    precision=0
+                )
+                fps = gr.Slider(
+                    label="Playback FPS",
+                    minimum=1,
+                    maximum=30,
+                    value=args.fps,
+                    step=1,
+                    visible=False,
+                    info="Frames per second for playback"
+                )
+
+            with gr.Row():
+                duration = gr.Slider(
+                    label="Duration (seconds)",
+                    minimum=1,
+                    maximum=5,
+                    value=3,
+                    step=1,
+                    info="Video duration in seconds"
+                )
+
+            with gr.Row():
+                width = gr.Slider(
+                    label="Width",
+                    minimum=224,
+                    maximum=832,
+                    value=DEFAULT_WIDTH,
+                    step=8,
+                    info="Video width in pixels (8px steps)"
+                )
+                height = gr.Slider(
+                    label="Height",
+                    minimum=224,
+                    maximum=832,
+                    value=DEFAULT_HEIGHT,
+                    step=8,
+                    info="Video height in pixels (8px steps)"
+                )
+
+        with gr.Column(scale=3):
+            gr.Markdown("### 📺 Video Stream")
+            streaming_video = gr.Video(
+                label="Live Stream",
+                streaming=True,
+                loop=True,
+                height=400,
+                autoplay=True,
+                show_label=False
+            )
+
+            status_display = gr.HTML(
+                value=(
+                    "<div style='text-align: center; padding: 20px; color: #666; border: 1px dashed #ddd; border-radius: 8px;'>"
+                    "🎬 Ready to start streaming...<br>"
+                    "<small>Configure your prompt and click 'Start Streaming'</small>"
+                    "</div>"
+                ),
+                label="Generation Status"
+            )
+
+    # Connect the generator to the streaming video
+    start_btn.click(
+        fn=video_generation_handler_streaming,
+        inputs=[prompt, seed, fps, width, height, duration],
+        outputs=[streaming_video, status_display]
+    )
+
+# --- Launch App ---
+if __name__ == "__main__":
+    if os.path.exists("gradio_tmp"):
+        import shutil
+        shutil.rmtree("gradio_tmp")
+    os.makedirs("gradio_tmp", exist_ok=True)
+
+    print("🚀 Clapper Rendering Node (default engine is Wan2.1 1.3B Self-Forcing)")
+    print(f"📁 Temporary files will be stored in: gradio_tmp/")
+    print(f"🎯 Chunk encoding: PyAV (MPEG-TS/H.264)")
+    print(f"⚡ GPU acceleration: {gpu}")
+
+    demo.queue().launch(
+        server_name=args.host,
+        server_port=args.port,
+        share=args.share,
+        show_error=True,
+        max_threads=40,
+        mcp_server=True
+    )
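
For reference, the streaming path preserved in app_with_streaming.py can be smoke-tested in isolation by handing frames_to_ts_file a few synthetic frames. A rough sketch, assuming it is run from inside this module (numpy, os, and PyAV are already imported there; the frame values are made up):

frames = [np.full((DEFAULT_HEIGHT, DEFAULT_WIDTH, 3), i * 16, dtype=np.uint8) for i in range(15)]  # 15 flat gray frames
os.makedirs("gradio_tmp", exist_ok=True)
frames_to_ts_file(frames, "gradio_tmp/smoke_test.ts", fps=15)
# The resulting ~1 second .ts chunk should play in ffplay or VLC.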