jbilcke-hf (HF Staff) committed
Commit b2c19b1 · 1 Parent(s): d1a3122
Files changed (3):
  1. README.md +10 -105
  2. app.py +43 -15
  3. handler.py +0 -545
README.md CHANGED
@@ -1,116 +1,21 @@
1
  ---
2
  emoji: 🎥
3
- title: 'Self Forcing Wan 2.1 '
4
- short_description: Real-time video generation
5
  sdk: gradio
6
- sdk_version: 5.34.2
7
  ---
8
- <p align="center">
9
- <h1 align="center">Self Forcing</h1>
10
- <h3 align="center">Bridging the Train-Test Gap in Autoregressive Video Diffusion</h3>
11
- </p>
12
- <p align="center">
13
- <p align="center">
14
- <a href="https://www.xunhuang.me/">Xun Huang</a><sup>1</sup>
15
- ·
16
- <a href="https://zhengqili.github.io/">Zhengqi Li</a><sup>1</sup>
17
- ·
18
- <a href="https://guandehe.github.io/">Guande He</a><sup>2</sup>
19
- ·
20
- <a href="https://mingyuanzhou.github.io/">Mingyuan Zhou</a><sup>2</sup>
21
- ·
22
- <a href="https://research.adobe.com/person/eli-shechtman/">Eli Shechtman</a><sup>1</sup><br>
23
- <sup>1</sup>Adobe Research <sup>2</sup>UT Austin
24
- </p>
25
- <h3 align="center"><a href="https://arxiv.org/abs/2506.08009">Paper</a> | <a href="https://self-forcing.github.io">Website</a> | <a href="https://huggingface.co/gdhe17/Self-Forcing/tree/main">Models (HuggingFace)</a></h3>
26
- </p>
27
-
28
- ---
29
-
30
- Self Forcing trains autoregressive video diffusion models by **simulating the inference process during training**, performing autoregressive rollout with KV caching. It resolves the train-test distribution mismatch and enables **real-time, streaming video generation on a single RTX 4090** while matching the quality of state-of-the-art diffusion models.
31
-
32
- ---
33
-
34
-
35
- https://github.com/user-attachments/assets/7548c2db-fe03-4ba8-8dd3-52d2c6160739
36
 
 
37
 
38
- ## Requirements
39
- We tested this repo on the following setup:
40
- * Nvidia GPU with at least 24 GB memory (RTX 4090, A100, and H100 are tested).
41
- * Linux operating system.
42
- * 64 GB RAM.
43
 
44
- Other hardware setup could also work but hasn't been tested.
45
 
46
- ## Installation
47
- Create a conda environment and install dependencies:
48
- ```
49
- conda create -n self_forcing python=3.10 -y
50
- conda activate self_forcing
51
- pip install -r requirements.txt
52
- pip install flash-attn --no-build-isolation
53
- python setup.py develop
54
- ```
55
-
56
- ## Quick Start
57
- ### Download checkpoints
58
- ```
59
- huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B
60
- huggingface-cli download gdhe17/Self-Forcing checkpoints/self_forcing_dmd.pt --local-dir .
61
- ```
62
-
63
- ### GUI demo
64
- ```
65
- python demo.py
66
- ```
67
- Note:
68
- * **Our model works better with long, detailed prompts** since it's trained with such prompts. We will integrate prompt extension into the codebase (similar to [Wan2.1](https://github.com/Wan-Video/Wan2.1/tree/main?tab=readme-ov-file#2-using-prompt-extention)) in the future. For now, it is recommended to use third-party LLMs (such as GPT-4o) to extend your prompt before providing to the model.
69
- * You may want to adjust FPS so it plays smoothly on your device.
70
- * The speed can be improved by enabling `torch.compile`, [TAEHV-VAE](https://github.com/madebyollin/taehv/), or using FP8 Linear layers, although the latter two options may sacrifice quality. It is recommended to use `torch.compile` if possible and enable TAEHV-VAE if further speedup is needed.
71
-
72
- ### CLI Inference
73
- Example inference script using the chunk-wise autoregressive checkpoint trained with DMD:
74
- ```
75
- python inference.py \
76
- --config_path configs/self_forcing_dmd.yaml \
77
- --output_folder videos/self_forcing_dmd \
78
- --checkpoint_path checkpoints/self_forcing_dmd.pt \
79
- --data_path prompts/MovieGenVideoBench_extended.txt \
80
- --use_ema
81
- ```
82
- Other config files and corresponding checkpoints can be found in [configs](configs) folder and our [huggingface repo](https://huggingface.co/gdhe17/Self-Forcing/tree/main/checkpoints).
83
-
84
- ## Training
85
- ### Download text prompts and ODE initialized checkpoint
86
- ```
87
- huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt --local-dir .
88
- huggingface-cli download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
89
- ```
90
- Note: Our training algorithm (except for the GAN version) is data-free (**no video data is needed**). For now, we directly provide the ODE initialization checkpoint and will add more instructions on how to perform ODE initialization in the future (which is identical to the process described in the [CausVid](https://github.com/tianweiy/CausVid) repo).
91
 
92
- ### Self Forcing Training with DMD
93
- ```
94
- torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
95
- --rdzv_backend=c10d \
96
- --rdzv_endpoint $MASTER_ADDR \
97
- train.py \
98
- --config_path configs/self_forcing_dmd.yaml \
99
- --logdir logs/self_forcing_dmd \
100
- --disable-wandb
101
- ```
102
- Our training run uses 600 iterations and completes in under 2 hours using 64 H100 GPUs. By implementing gradient accumulation, it should be possible to reproduce the results in less than 16 hours using 8 H100 GPUs.
103
 
104
- ## Acknowledgements
105
  This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid) by [Tianwei Yin](https://tianweiy.github.io/) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
106
-
107
- ## Citation
108
- If you find this codebase useful for your research, please kindly cite our paper:
109
- ```
110
- @article{huang2025selfforcing,
111
- title={Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion},
112
- author={Huang, Xun and Li, Zhengqi and He, Guande and Zhou, Mingyuan and Shechtman, Eli},
113
- journal={arXiv preprint arXiv:2506.08009},
114
- year={2025}
115
- }
116
- ```
 
1
  ---
2
  emoji: 🎥
3
+ title: 'Self-Forcing Wan2.1-1.3B'
4
+ short_description: MCP server for real-time video generation
5
  sdk: gradio
6
+ sdk_version: 5.35.0
7
  ---
8
 
9
+ This Space is meant to be used as an API/MCP server by generative AI video apps.
10
 
11
+ It is best to run it on a large GPU such as an Nvidia A100, H100, or H200.
12
 
13
+ It can run on an Nvidia L40S, but not in real time.
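For illustration, a client app could call this Space with `gradio_client` along the lines of the sketch below. The Space id is a placeholder, and the endpoint name assumes Gradio's default naming for the `video_generation_handler_streaming` handler wired up in `app.py`; check the Space's "Use via API" page (or its MCP endpoint, if enabled) for the exact name and signature.

```python
# Rough sketch of calling the streaming endpoint from a client app.
# Assumptions: placeholder Space id, default Gradio api_name derived from the
# handler function, and the parameter order (prompt, seed, fps, width, height,
# duration, buffering) used by app.py in this commit.
import time
from gradio_client import Client

client = Client("your-username/your-space-name")  # placeholder Space id

job = client.submit(
    "A cinematic shot of a sailboat gliding across a calm sea at sunset",  # prompt
    -1,    # seed (-1 = random)
    15,    # fps
    400,   # width
    224,   # height
    5,     # duration in seconds
    2,     # buffering in seconds
    api_name="/video_generation_handler_streaming",  # assumed default name
)

# The endpoint is a generator, so (video chunk, status HTML) outputs arrive as a series.
while not job.done():
    time.sleep(0.5)
for output in job.outputs():
    print(output)
```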
14
 
15
+ ## Acknowledgements
16
 
17
+ If you want to use this codebase for your project, I invite you to use the [original repository](https://huggingface.co/spaces/multimodalart/self-forcing) instead; it will fit your needs better, since it still contains features that I removed here because I didn't need them.
18
+
19
+ A big thank you to [Multimodalart](https://huggingface.co/multimodalart/), who created the original repo!
20
 
 
21
  This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid) by [Tianwei Yin](https://tianweiy.github.io/) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
app.py CHANGED
@@ -91,8 +91,11 @@ APP_STATE = {
91
  "current_vae_decoder": None,
92
  }
93

94
  # Apply torch.compile for maximum performance
95
- if not APP_STATE["torch_compile_applied"]:
96
  print("🚀 Applying torch.compile for speed optimization...")
97
  transformer.compile(mode="max-autotune-no-cudagraphs")
98
  APP_STATE["torch_compile_applied"] = True
@@ -213,7 +216,7 @@ pipeline = CausalInferencePipeline(
213
  pipeline.to(dtype=torch.float16).to(gpu)
214
 
215
  @torch.no_grad()
216
- def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224):
217
  """
218
  Generator function that yields .ts video chunks using PyAV for streaming.
219
  Now optimized for block-based processing.
@@ -221,7 +224,18 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
221
  if seed == -1:
222
  seed = random.randint(0, 2**32 - 1)
223
 
224
- print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}")
225
 
226
  # Setup
227
  conditional_dict = text_encoder(text_prompts=[prompt])
@@ -237,7 +251,13 @@ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, heigh
237
  if not APP_STATE["current_use_taehv"] and not args.trt:
238
  vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
239
 
240
- num_blocks = 7
241
  current_start_frame = 0
242
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
243
 
@@ -394,16 +414,6 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
394
 
395
  start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
396
 
397
- gr.Markdown("### 🎯 Examples")
398
- gr.Examples(
399
- examples=[
400
- "A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
401
- "A playful cat is seen playing an electronic guitar, strumming the strings with its front paws. The cat has distinctive black facial markings and a bushy tail. It sits comfortably on a small stool, its body slightly tilted as it focuses intently on the instrument. The setting is a cozy, dimly lit room with vintage posters on the walls, adding a retro vibe. The cat's expressive eyes convey a sense of joy and concentration. Medium close-up shot, focusing on the cat's face and hands interacting with the guitar.",
402
- "A dynamic over-the-shoulder perspective of a chef meticulously plating a dish in a bustling kitchen. The chef, a middle-aged woman, deftly arranges ingredients on a pristine white plate. Her hands move with precision, each gesture deliberate and practiced. The background shows a crowded kitchen with steaming pots, whirring blenders, and the clatter of utensils. Bright lights highlight the scene, casting shadows across the busy workspace. The camera angle captures the chef's detailed work from behind, emphasizing his skill and dedication.",
403
- ],
404
- inputs=[prompt],
405
- )
406
-
407
  gr.Markdown("### ⚙️ Settings")
408
  with gr.Row():
409
  seed = gr.Number(
@@ -422,6 +432,24 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
422
  info="Frames per second for playback"
423
  )
424

425
  with gr.Row():
426
  width = gr.Slider(
427
  label="Width",
@@ -465,7 +493,7 @@ with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
465
  # Connect the generator to the streaming video
466
  start_btn.click(
467
  fn=video_generation_handler_streaming,
468
- inputs=[prompt, seed, fps, width, height],
469
  outputs=[streaming_video, status_display]
470
  )
471
 
 
91
  "current_vae_decoder": None,
92
  }
93
 
94
+ # I've tried enabling torch.compile, but didn't notice a significant performance improvement.
95
+ ENABLE_TORCH_COMPILATION = False
96
+
97
  # Apply torch.compile for maximum performance
98
+ if not APP_STATE["torch_compile_applied"] and ENABLE_TORCH_COMPILATION:
99
  print("🚀 Applying torch.compile for speed optimization...")
100
  transformer.compile(mode="max-autotune-no-cudagraphs")
101
  APP_STATE["torch_compile_applied"] = True
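If the toggle needs to change per deployment, a possible variation is to read it from an environment variable instead of a hard-coded constant; a minimal sketch (the variable name is hypothetical):

```python
# Minimal sketch: drive the toggle from an environment variable
# (the ENABLE_TORCH_COMPILATION variable name is hypothetical).
import os

ENABLE_TORCH_COMPILATION = os.getenv("ENABLE_TORCH_COMPILATION", "false").lower() in ("1", "true", "yes")
```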
 
216
  pipeline.to(dtype=torch.float16).to(gpu)
217
 
218
  @torch.no_grad()
219
+ def video_generation_handler_streaming(prompt, seed=42, fps=15, width=400, height=224, duration=5, buffering=2):
220
  """
221
  Generator function that yields .ts video chunks using PyAV for streaming.
222
  Now optimized for block-based processing.
 
224
  if seed == -1:
225
  seed = random.randint(0, 2**32 - 1)
226
 
227
+ print(f"🎬 Starting PyAV streaming: '{prompt}', seed: {seed}, duration: {duration}s, buffering: {buffering}s")
228
+
229
+ # Buffering delay
230
+ if buffering > 0:
231
+ buffering_status_html = (
232
+ f"<div style='padding: 10px; border: 1px solid #ffc107; background: #fff3cd; border-radius: 8px; font-family: sans-serif;'>"
233
+ f" <p style='margin: 0 0 8px 0; font-size: 16px; font-weight: bold;'>⏳ Buffering...</p>"
234
+ f" <p style='margin: 0; color: #856404; font-size: 14px;'>Waiting {buffering} seconds before starting stream</p>"
235
+ f"</div>"
236
+ )
237
+ yield None, buffering_status_html
238
+ time.sleep(buffering)
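While buffering, the generator yields `None` in the video slot and puts all information in the status HTML, so a consumer can tell the warm-up phase apart from real chunks; a small, hypothetical consumer sketch:

```python
# Hypothetical consumer of the (video_chunk, status_html) pairs yielded by the
# generator above: during buffering the video slot is None.
def consume(stream):
    for video_chunk, status_html in stream:
        if video_chunk is None:
            print("status update:", status_html)         # buffering / progress messages
        else:
            print("received video chunk:", video_chunk)  # a .ts segment ready to play
```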
239
 
240
  # Setup
241
  conditional_dict = text_encoder(text_prompts=[prompt])
 
251
  if not APP_STATE["current_use_taehv"] and not args.trt:
252
  vae_cache = [c.to(device=gpu, dtype=torch.float16) for c in ZERO_VAE_CACHE]
253
 
254
+ # Calculate number of blocks based on duration
255
+ # Current setup generates approximately 5 seconds with 7 blocks
256
+ # So we scale proportionally
257
+ base_duration = 5.0 # seconds
258
+ base_blocks = 7
259
+ num_blocks = max(1, int(base_blocks * duration / base_duration))
260
+
261
  current_start_frame = 0
262
  all_num_frames = [pipeline.num_frame_per_block] * num_blocks
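For reference, the mapping above scales linearly from the baseline of 7 blocks for roughly 5 seconds; because the result is truncated with `int()`, intermediate durations round down. A quick sanity check:

```python
# Sanity check of the duration -> num_blocks mapping used above
# (baseline from this commit: 7 blocks for roughly 5 seconds of video).
base_duration, base_blocks = 5.0, 7
for duration in (1, 3, 5, 10):
    num_blocks = max(1, int(base_blocks * duration / base_duration))
    print(f"{duration:>2}s -> {num_blocks} blocks")
# prints: 1s -> 1, 3s -> 4, 5s -> 7, 10s -> 14
```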
263
 
 
414
 
415
  start_btn = gr.Button("🎬 Start Streaming", variant="primary", size="lg")
416
 
 
 
 
 
 
 
 
 
 
 
417
  gr.Markdown("### ⚙️ Settings")
418
  with gr.Row():
419
  seed = gr.Number(
 
432
  info="Frames per second for playback"
433
  )
434
 
435
+ with gr.Row():
436
+ duration = gr.Slider(
437
+ label="Duration (seconds)",
438
+ minimum=1,
439
+ maximum=10,
440
+ value=5,
441
+ step=1,
442
+ info="Video duration in seconds"
443
+ )
444
+ buffering = gr.Slider(
445
+ label="Buffering (seconds)",
446
+ minimum=0,
447
+ maximum=5,
448
+ value=2,
449
+ step=0.5,
450
+ info="Wait time before starting stream"
451
+ )
452
+
453
  with gr.Row():
454
  width = gr.Slider(
455
  label="Width",
 
493
  # Connect the generator to the streaming video
494
  start_btn.click(
495
  fn=video_generation_handler_streaming,
496
+ inputs=[prompt, seed, fps, width, height, duration, buffering],
497
  outputs=[streaming_video, status_display]
498
  )
499
 
handler.py DELETED
@@ -1,545 +0,0 @@
1
- from dataclasses import dataclass
2
- from pathlib import Path
3
- import logging
4
- import base64
5
- import random
6
- import gc
7
- import os
8
- import numpy as np
9
- import torch
10
- from typing import Dict, Any, Optional, List, Union, Tuple
11
- import json
12
- from omegaconf import OmegaConf
13
- from PIL import Image
14
- import io
15
-
16
- from pipeline import CausalInferencePipeline
17
- from demo_utils.constant import ZERO_VAE_CACHE
18
- from demo_utils.vae_block3 import VAEDecoderWrapper
19
- from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
20
-
21
- # Configure logging
22
- logging.basicConfig(level=logging.INFO)
23
- logger = logging.getLogger(__name__)
24
-
25
- # Get token from environment
26
- hf_token = os.getenv("HF_API_TOKEN")
27
-
28
- # Constraints
29
- MAX_LARGE_SIDE = 1280
30
- MAX_SMALL_SIDE = 768
31
- MAX_FRAMES = 169 # Based on Wan model capabilities
32
-
33
- @dataclass
34
- class GenerationConfig:
35
- """Configuration for video generation using Wan model"""
36
-
37
- # general content settings
38
- prompt: str = ""
39
- negative_prompt: str = "worst quality, lowres, blurry, distorted, cropped, watermarked, watermark, logo, subtitle, subtitles"
40
-
41
- # video model settings
42
- width: int = 960 # Wan model default width
43
- height: int = 576 # Wan model default height
44
-
45
- # number of frames (based on Wan model block structure)
46
- num_frames: int = 105 # 7 blocks * 15 frames per block
47
-
48
- # guidance and sampling settings
49
- guidance_scale: float = 7.5
50
- num_inference_steps: int = 4 # Distilled model uses fewer steps
51
-
52
- # reproducible generation settings
53
- seed: int = -1 # -1 means random seed
54
-
55
- # output settings
56
- fps: int = 15 # FPS of the final video
57
- quality: int = 18 # Video quality (CRF)
58
-
59
- # advanced settings
60
- mixed_precision: bool = True
61
- use_taehv: bool = False # Whether to use TAEHV decoder
62
- use_trt: bool = False # Whether to use TensorRT optimized decoder
63
-
64
- def validate_and_adjust(self) -> 'GenerationConfig':
65
- """Validate and adjust parameters to meet constraints"""
66
- # Ensure dimensions are multiples of 32 and within limits
67
- self.width = max(128, min(MAX_LARGE_SIDE, round(self.width / 32) * 32))
68
- self.height = max(128, min(MAX_LARGE_SIDE, round(self.height / 32) * 32))
69
-
70
- # Ensure frame count is reasonable
71
- self.num_frames = min(self.num_frames, MAX_FRAMES)
72
-
73
- # Set random seed if not specified
74
- if self.seed == -1:
75
- self.seed = random.randint(0, 2**32 - 1)
76
-
77
- return self
78
-
79
- def load_image_to_tensor_with_resize_and_crop(
80
- image_input: Union[str, bytes],
81
- target_height: int = 576,
82
- target_width: int = 960,
83
- quality: int = 100
84
- ) -> torch.Tensor:
85
- """Load and process an image into a tensor for Wan model.
86
-
87
- Args:
88
- image_input: Either a file path (str) or image data (bytes)
89
- target_height: Desired height of output tensor
90
- target_width: Desired width of output tensor
91
- quality: JPEG quality to use when re-encoding
92
- """
93
- # Handle base64 data URI
94
- if isinstance(image_input, str) and image_input.startswith('data:'):
95
- header, encoded = image_input.split(",", 1)
96
- image_data = base64.b64decode(encoded)
97
- image = Image.open(io.BytesIO(image_data)).convert("RGB")
98
- # Handle raw bytes
99
- elif isinstance(image_input, bytes):
100
- image = Image.open(io.BytesIO(image_input)).convert("RGB")
101
- # Handle file path
102
- elif isinstance(image_input, str):
103
- image = Image.open(image_input).convert("RGB")
104
- else:
105
- raise ValueError("image_input must be either a file path, bytes, or base64 data URI")
106
-
107
- # Apply JPEG compression if quality < 100
108
- if quality < 100:
109
- buffer = io.BytesIO()
110
- image.save(buffer, format="JPEG", quality=quality)
111
- buffer.seek(0)
112
- image = Image.open(buffer).convert("RGB")
113
-
114
- # Resize and crop to target dimensions
115
- input_width, input_height = image.size
116
- aspect_ratio_target = target_width / target_height
117
- aspect_ratio_frame = input_width / input_height
118
-
119
- if aspect_ratio_frame > aspect_ratio_target:
120
- new_width = int(input_height * aspect_ratio_target)
121
- new_height = input_height
122
- x_start = (input_width - new_width) // 2
123
- y_start = 0
124
- else:
125
- new_width = input_width
126
- new_height = int(input_width / aspect_ratio_target)
127
- x_start = 0
128
- y_start = (input_height - new_height) // 2
129
-
130
- image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height))
131
- image = image.resize((target_width, target_height))
132
-
133
- # Convert to tensor format expected by Wan model
134
- frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float()
135
- frame_tensor = (frame_tensor / 127.5) - 1.0
136
-
137
- return frame_tensor.unsqueeze(0)
138
-
139
- def initialize_vae_decoder(use_taehv=False, use_trt=False, device="cuda"):
140
- """Initialize VAE decoder based on configuration"""
141
- if use_trt:
142
- from demo_utils.vae import VAETRTWrapper
143
- print("Initializing TensorRT VAE Decoder...")
144
- vae_decoder = VAETRTWrapper()
145
- elif use_taehv:
146
- print("Initializing TAEHV VAE Decoder...")
147
- from demo_utils.taehv import TAEHV
148
- taehv_checkpoint_path = "/repository/taehv/taew2_1.pth"
149
-
150
- if not os.path.exists(taehv_checkpoint_path):
151
- print(f"Downloading TAEHV checkpoint to {taehv_checkpoint_path}...")
152
- os.makedirs("checkpoints", exist_ok=True)
153
- import urllib.request
154
- download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
155
- try:
156
- urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
157
- except Exception as e:
158
- raise RuntimeError(f"Failed to download taew2_1.pth: {e}")
159
-
160
- class DotDict(dict):
161
- __getattr__ = dict.get
162
-
163
- class TAEHVDiffusersWrapper(torch.nn.Module):
164
- def __init__(self):
165
- super().__init__()
166
- self.dtype = torch.float16
167
- self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
168
- self.config = DotDict(scaling_factor=1.0)
169
-
170
- def decode(self, latents, return_dict=None):
171
- return self.taehv.decode_video(latents, parallel=True).mul_(2).sub_(1)
172
-
173
- vae_decoder = TAEHVDiffusersWrapper()
174
- else:
175
- print("Initializing Default VAE Decoder...")
176
- vae_decoder = VAEDecoderWrapper()
177
- try:
178
- # I should have called the folder "Wan2.1-T2V-1.3B" instead of "wan2.1"
179
- #vae_state_dict = torch.load('/repository/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
180
- vae_state_dict = torch.load('/repository/wan2.1/Wan2.1_VAE.pth', map_location="cpu")
181
- decoder_state_dict = {k: v for k, v in vae_state_dict.items() if 'decoder.' in k or 'conv2' in k}
182
- vae_decoder.load_state_dict(decoder_state_dict)
183
- except FileNotFoundError:
184
- print("Warning: Default VAE weights not found.")
185
-
186
- vae_decoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device)
187
- return vae_decoder
188
-
189
- def create_wan_pipeline(
190
- config: GenerationConfig,
191
- device: str = "cuda"
192
- ) -> CausalInferencePipeline:
193
- """Create and configure the Wan video pipeline"""
194
-
195
- # Load configuration
196
- try:
197
- wan_config = OmegaConf.load("/repository/configs/self_forcing_dmd.yaml")
198
- default_config = OmegaConf.load("/repository/configs/default_config.yaml")
199
- wan_config = OmegaConf.merge(default_config, wan_config)
200
- except FileNotFoundError as e:
201
- logger.error(f"Error loading config file: {e}")
202
- raise RuntimeError(f"Config files not found: {e}")
203
-
204
- # Initialize model components
205
- text_encoder = WanTextEncoder()
206
- transformer = WanDiffusionWrapper(is_causal=True)
207
-
208
- # Load checkpoint
209
- checkpoint_path = "/repository/self-forcing/checkpoints/self_forcing_dmd.pt"
210
- try:
211
- state_dict = torch.load(checkpoint_path, map_location="cpu")
212
- transformer.load_state_dict(state_dict.get('generator_ema', state_dict.get('generator')))
213
- except FileNotFoundError as e:
214
- logger.error(f"Error loading checkpoint: {e}")
215
- raise RuntimeError(f"Checkpoint not found: {checkpoint_path}")
216
-
217
- # Move to device and set precision
218
- text_encoder.eval().to(dtype=torch.float16).requires_grad_(False).to(device)
219
- transformer.eval().to(dtype=torch.float16).requires_grad_(False).to(device)
220
-
221
- # Initialize VAE decoder
222
- vae_decoder = initialize_vae_decoder(
223
- use_taehv=config.use_taehv,
224
- use_trt=config.use_trt,
225
- device=device
226
- )
227
-
228
- # Create pipeline
229
- pipeline = CausalInferencePipeline(
230
- wan_config,
231
- device=device,
232
- generator=transformer,
233
- text_encoder=text_encoder,
234
- vae=vae_decoder
235
- )
236
-
237
- pipeline.to(dtype=torch.float16).to(device)
238
-
239
- return pipeline
240
-
241
- def frames_to_video_bytes(frames: List[np.ndarray], fps: int = 15, quality: int = 18) -> bytes:
242
- """Convert frames to MP4 video bytes"""
243
- import tempfile
244
- import subprocess
245
-
246
- with tempfile.TemporaryDirectory() as temp_dir:
247
- # Save frames as images
248
- frame_paths = []
249
- for i, frame in enumerate(frames):
250
- frame_path = os.path.join(temp_dir, f"frame_{i:06d}.png")
251
- Image.fromarray(frame).save(frame_path)
252
- frame_paths.append(frame_path)
253
-
254
- # Create video using ffmpeg
255
- output_path = os.path.join(temp_dir, "output.mp4")
256
- cmd = [
257
- "ffmpeg", "-y", "-framerate", str(fps),
258
- "-i", os.path.join(temp_dir, "frame_%06d.png"),
259
- "-c:v", "libx264", "-crf", str(quality),
260
- "-pix_fmt", "yuv420p", "-movflags", "faststart",
261
- output_path
262
- ]
263
-
264
- try:
265
- subprocess.run(cmd, check=True, capture_output=True)
266
- with open(output_path, "rb") as f:
267
- return f.read()
268
- except subprocess.CalledProcessError as e:
269
- logger.error(f"FFmpeg error: {e}")
270
- raise RuntimeError(f"Video encoding failed: {e}")
271
-
272
- class EndpointHandler:
273
- """Handler for the Wan Video endpoint"""
274
-
275
- def __init__(self, model_path: str = "./"):
276
- """Initialize the endpoint handler
277
-
278
- Args:
279
- model_path: Path to model weights
280
- """
281
- # Enable TF32 for potential speedup on Ampere GPUs
282
- torch.backends.cuda.matmul.allow_tf32 = True
283
-
284
- # The pipeline will be loaded during inference to save memory
285
- self.pipeline = None
286
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
287
-
288
- # Perform warm-up inference if GPU is available
289
- if self.device == "cuda":
290
- logger.info("Performing warm-up inference...")
291
- self._warmup()
292
- logger.info("Warm-up completed!")
293
- else:
294
- logger.info("CPU device detected, skipping warm-up")
295
-
296
- def _warmup(self):
297
- """Perform a warm-up inference to prepare the model for future requests"""
298
- try:
299
- # Create a simple test configuration
300
- test_config = GenerationConfig(
301
- prompt="a cat walking",
302
- negative_prompt="worst quality, lowres",
303
- width=480, # Smaller resolution for faster warm-up
304
- height=320,
305
- num_frames=33, # Fewer frames for faster warm-up
306
- guidance_scale=7.5,
307
- num_inference_steps=2, # Fewer steps for faster warm-up
308
- seed=42, # Fixed seed for consistent warm-up
309
- fps=15,
310
- mixed_precision=True,
311
- ).validate_and_adjust()
312
-
313
- # Create the pipeline if it doesn't exist
314
- if self.pipeline is None:
315
- self.pipeline = create_wan_pipeline(test_config, self.device)
316
-
317
- # Run a quick inference
318
- with torch.no_grad():
319
- # Set seeds for reproducibility
320
- random.seed(test_config.seed)
321
- np.random.seed(test_config.seed)
322
- torch.manual_seed(test_config.seed)
323
-
324
- # Generate video frames (simplified version)
325
- conditional_dict = self.pipeline.text_encoder(text_prompts=[test_config.prompt])
326
- for key, value in conditional_dict.items():
327
- conditional_dict[key] = value.to(dtype=torch.float16)
328
-
329
- rnd = torch.Generator(self.device).manual_seed(int(test_config.seed))
330
- self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device)
331
- self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device)
332
-
333
- # Generate a small noise tensor for testing
334
- noise = torch.randn([1, 3, 8, 20, 32], device=self.device, dtype=torch.float16, generator=rnd)
335
-
336
- # Clean up
337
- del noise, conditional_dict
338
- torch.cuda.empty_cache()
339
- gc.collect()
340
-
341
- logger.info("Warm-up successful!")
342
-
343
- except Exception as e:
344
- # Log the error but don't fail initialization
345
- import traceback
346
- error_message = f"Warm-up failed (but this is non-critical): {str(e)}\n{traceback.format_exc()}"
347
- logger.warning(error_message)
348
-
349
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
350
- """Process inference requests
351
-
352
- Args:
353
- data: Request data containing inputs and parameters
354
-
355
- Returns:
356
- Dictionary with generated video and metadata
357
- """
358
- # Extract inputs and parameters
359
- inputs = data.get("inputs", {})
360
-
361
- # Support both formats:
362
- # 1. {"inputs": {"prompt": "...", "image": "..."}}
363
- # 2. {"inputs": "..."} (prompt only)
364
- if isinstance(inputs, str):
365
- input_prompt = inputs
366
- input_image = None
367
- else:
368
- input_prompt = inputs.get("prompt", "")
369
- input_image = inputs.get("image")
370
-
371
- params = data.get("parameters", {})
372
-
373
- if not input_prompt:
374
- raise ValueError("Prompt must be provided")
375
-
376
- # Create and validate configuration
377
- config = GenerationConfig(
378
- # general content settings
379
- prompt=input_prompt,
380
- negative_prompt=params.get("negative_prompt", GenerationConfig.negative_prompt),
381
-
382
- # video model settings
383
- width=params.get("width", GenerationConfig.width),
384
- height=params.get("height", GenerationConfig.height),
385
- num_frames=params.get("num_frames", GenerationConfig.num_frames),
386
- guidance_scale=params.get("guidance_scale", GenerationConfig.guidance_scale),
387
- num_inference_steps=params.get("num_inference_steps", GenerationConfig.num_inference_steps),
388
-
389
- # reproducible generation settings
390
- seed=params.get("seed", GenerationConfig.seed),
391
-
392
- # output settings
393
- fps=params.get("fps", GenerationConfig.fps),
394
- quality=params.get("quality", GenerationConfig.quality),
395
-
396
- # advanced settings
397
- mixed_precision=params.get("mixed_precision", GenerationConfig.mixed_precision),
398
- use_taehv=params.get("use_taehv", GenerationConfig.use_taehv),
399
- use_trt=params.get("use_trt", GenerationConfig.use_trt),
400
- ).validate_and_adjust()
401
-
402
- try:
403
- with torch.no_grad():
404
- # Set random seeds for reproducibility
405
- random.seed(config.seed)
406
- np.random.seed(config.seed)
407
- torch.manual_seed(config.seed)
408
-
409
- # Create pipeline if not already created
410
- if self.pipeline is None:
411
- self.pipeline = create_wan_pipeline(config, self.device)
412
-
413
- # Prepare text conditioning
414
- conditional_dict = self.pipeline.text_encoder(text_prompts=[config.prompt])
415
- for key, value in conditional_dict.items():
416
- conditional_dict[key] = value.to(dtype=torch.float16)
417
-
418
- # Initialize caches
419
- rnd = torch.Generator(self.device).manual_seed(int(config.seed))
420
- self.pipeline._initialize_kv_cache(1, torch.float16, device=self.device)
421
- self.pipeline._initialize_crossattn_cache(1, torch.float16, device=self.device)
422
-
423
- # Generate noise tensor
424
- noise = torch.randn(
425
- [1, 21, 16, config.height // 16, config.width // 16],
426
- device=self.device,
427
- dtype=torch.float16,
428
- generator=rnd
429
- )
430
-
431
- # Initialize VAE cache
432
- vae_cache = None
433
- latents_cache = None
434
- if not config.use_taehv and not config.use_trt:
435
- vae_cache = [c.to(device=self.device, dtype=torch.float16) for c in ZERO_VAE_CACHE]
436
-
437
- # Generation parameters
438
- num_blocks = 7
439
- current_start_frame = 0
440
- all_num_frames = [self.pipeline.num_frame_per_block] * num_blocks
441
-
442
- all_frames = []
443
-
444
- # Generate video blocks
445
- for idx, current_num_frames in enumerate(all_num_frames):
446
- logger.info(f"Processing block {idx+1}/{num_blocks}")
447
-
448
- noisy_input = noise[:, current_start_frame : current_start_frame + current_num_frames]
449
-
450
- # Denoising steps
451
- for step_idx, current_timestep in enumerate(self.pipeline.denoising_step_list):
452
- timestep = torch.ones([1, current_num_frames], device=noise.device, dtype=torch.int64) * current_timestep
453
- _, denoised_pred = self.pipeline.generator(
454
- noisy_image_or_video=noisy_input,
455
- conditional_dict=conditional_dict,
456
- timestep=timestep,
457
- kv_cache=self.pipeline.kv_cache1,
458
- crossattn_cache=self.pipeline.crossattn_cache,
459
- current_start=current_start_frame * self.pipeline.frame_seq_length
460
- )
461
-
462
- if step_idx < len(self.pipeline.denoising_step_list) - 1:
463
- next_timestep = self.pipeline.denoising_step_list[step_idx + 1]
464
- noisy_input = self.pipeline.scheduler.add_noise(
465
- denoised_pred.flatten(0, 1),
466
- torch.randn_like(denoised_pred.flatten(0, 1)),
467
- next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
468
- ).unflatten(0, denoised_pred.shape[:2])
469
-
470
- # Update cache for next block
471
- if idx < len(all_num_frames) - 1:
472
- self.pipeline.generator(
473
- noisy_image_or_video=denoised_pred,
474
- conditional_dict=conditional_dict,
475
- timestep=torch.zeros_like(timestep),
476
- kv_cache=self.pipeline.kv_cache1,
477
- crossattn_cache=self.pipeline.crossattn_cache,
478
- current_start=current_start_frame * self.pipeline.frame_seq_length,
479
- )
480
-
481
- # Decode to pixels
482
- if config.use_trt:
483
- pixels, vae_cache = self.pipeline.vae.forward(denoised_pred.half(), *vae_cache)
484
- elif config.use_taehv:
485
- if latents_cache is None:
486
- latents_cache = denoised_pred
487
- else:
488
- denoised_pred = torch.cat([latents_cache, denoised_pred], dim=1)
489
- latents_cache = denoised_pred[:, -3:]
490
- pixels = self.pipeline.vae.decode(denoised_pred)
491
- else:
492
- pixels, vae_cache = self.pipeline.vae(denoised_pred.half(), *vae_cache)
493
-
494
- # Handle frame skipping
495
- if idx == 0 and not config.use_trt:
496
- pixels = pixels[:, 3:]
497
- elif config.use_taehv and idx > 0:
498
- pixels = pixels[:, 12:]
499
-
500
- # Convert frames to numpy
501
- for frame_idx in range(pixels.shape[1]):
502
- frame_tensor = pixels[0, frame_idx]
503
- frame_np = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
504
- frame_np = frame_np.to(torch.uint8).cpu().numpy()
505
- frame_np = np.transpose(frame_np, (1, 2, 0)) # CHW -> HWC
506
- all_frames.append(frame_np)
507
-
508
- current_start_frame += current_num_frames
509
-
510
- # Convert frames to video
511
- video_bytes = frames_to_video_bytes(all_frames, fps=config.fps, quality=config.quality)
512
-
513
- # Convert to base64 data URI
514
- video_b64 = base64.b64encode(video_bytes).decode('utf-8')
515
- video_uri = f"data:video/mp4;base64,{video_b64}"
516
-
517
- # Prepare metadata
518
- metadata = {
519
- "width": config.width,
520
- "height": config.height,
521
- "num_frames": len(all_frames),
522
- "fps": config.fps,
523
- "duration": len(all_frames) / config.fps,
524
- "seed": config.seed,
525
- "prompt": config.prompt,
526
- }
527
-
528
- # Clean up to prevent CUDA OOM errors
529
- del noise, conditional_dict, pixels
530
- if self.device == "cuda":
531
- torch.cuda.empty_cache()
532
- gc.collect()
533
-
534
- return {
535
- "video": video_uri,
536
- "content-type": "video/mp4",
537
- "metadata": metadata
538
- }
539
-
540
- except Exception as e:
541
- # Log the error and reraise
542
- import traceback
543
- error_message = f"Error generating video: {str(e)}\n{traceback.format_exc()}"
544
- logger.error(error_message)
545
- raise RuntimeError(error_message)