linoyts (HF Staff) committed
Commit ba7cb71 · verified · 1 Parent(s): 3733bbd

Upload 57 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. examples/i2v_input.JPG +3 -0
  3. generate.py +411 -0
  4. gradio_ti2v.py +154 -0
  5. requirements.txt +15 -0
  6. wan/__init__.py +5 -0
  7. wan/__pycache__/__init__.cpython-310.pyc +0 -0
  8. wan/__pycache__/image2video.cpython-310.pyc +0 -0
  9. wan/__pycache__/text2video.cpython-310.pyc +0 -0
  10. wan/__pycache__/textimage2video.cpython-310.pyc +0 -0
  11. wan/configs/__init__.py +39 -0
  12. wan/configs/__pycache__/__init__.cpython-310.pyc +0 -0
  13. wan/configs/__pycache__/shared_config.cpython-310.pyc +0 -0
  14. wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc +0 -0
  15. wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc +0 -0
  16. wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc +0 -0
  17. wan/configs/shared_config.py +20 -0
  18. wan/configs/wan_i2v_A14B.py +37 -0
  19. wan/configs/wan_t2v_A14B.py +37 -0
  20. wan/configs/wan_ti2v_5B.py +36 -0
  21. wan/distributed/__init__.py +1 -0
  22. wan/distributed/__pycache__/__init__.cpython-310.pyc +0 -0
  23. wan/distributed/__pycache__/fsdp.cpython-310.pyc +0 -0
  24. wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc +0 -0
  25. wan/distributed/__pycache__/ulysses.cpython-310.pyc +0 -0
  26. wan/distributed/__pycache__/util.cpython-310.pyc +0 -0
  27. wan/distributed/fsdp.py +43 -0
  28. wan/distributed/sequence_parallel.py +176 -0
  29. wan/distributed/ulysses.py +47 -0
  30. wan/distributed/util.py +51 -0
  31. wan/image2video.py +431 -0
  32. wan/modules/__init__.py +19 -0
  33. wan/modules/__pycache__/__init__.cpython-310.pyc +0 -0
  34. wan/modules/__pycache__/attention.cpython-310.pyc +0 -0
  35. wan/modules/__pycache__/model.cpython-310.pyc +0 -0
  36. wan/modules/__pycache__/t5.cpython-310.pyc +0 -0
  37. wan/modules/__pycache__/tokenizers.cpython-310.pyc +0 -0
  38. wan/modules/__pycache__/vae2_1.cpython-310.pyc +0 -0
  39. wan/modules/__pycache__/vae2_2.cpython-310.pyc +0 -0
  40. wan/modules/attention.py +179 -0
  41. wan/modules/model.py +546 -0
  42. wan/modules/t5.py +513 -0
  43. wan/modules/tokenizers.py +82 -0
  44. wan/modules/vae2_1.py +663 -0
  45. wan/modules/vae2_2.py +1051 -0
  46. wan/text2video.py +378 -0
  47. wan/textimage2video.py +619 -0
  48. wan/utils/__init__.py +12 -0
  49. wan/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  50. wan/utils/__pycache__/fm_solvers.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
examples/i2v_input.JPG ADDED

Git LFS Details

  • SHA256: 077e3d965090c9028c69c00931675f42e1acc815c6eb450ab291b3b72d211a8e
  • Pointer size: 131 Bytes
  • Size of remote file: 251 kB
generate.py ADDED
@@ -0,0 +1,411 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import argparse
3
+ import logging
4
+ import os
5
+ import sys
6
+ import warnings
7
+ from datetime import datetime
8
+
9
+ warnings.filterwarnings('ignore')
10
+
11
+ import random
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+
17
+ import wan
18
+ from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
19
+ from wan.distributed.util import init_distributed_group
20
+ from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
21
+ from wan.utils.utils import cache_video, str2bool
22
+
23
+ EXAMPLE_PROMPT = {
24
+ "t2v-A14B": {
25
+ "prompt":
26
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
27
+ },
28
+ "i2v-A14B": {
29
+ "prompt":
30
+ "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
31
+ "image":
32
+ "examples/i2v_input.JPG",
33
+ },
34
+ "ti2v-5B": {
35
+ "prompt":
36
+ "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
37
+ },
38
+ }
39
+
40
+
41
+ def _validate_args(args):
42
+ # Basic check
43
+ assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
44
+ assert args.task in WAN_CONFIGS, f"Unsupported task: {args.task}"
45
+ assert args.task in EXAMPLE_PROMPT, f"Unsupported task: {args.task}"
46
+
47
+ if args.prompt is None:
48
+ args.prompt = EXAMPLE_PROMPT[args.task]["prompt"]
49
+ if args.image is None and "image" in EXAMPLE_PROMPT[args.task]:
50
+ args.image = EXAMPLE_PROMPT[args.task]["image"]
51
+
52
+ if args.task == "i2v-A14B":
53
+ assert args.image is not None, "Please specify the image path for i2v."
54
+
55
+ cfg = WAN_CONFIGS[args.task]
56
+
57
+ if args.sample_steps is None:
58
+ args.sample_steps = cfg.sample_steps
59
+
60
+ if args.sample_shift is None:
61
+ args.sample_shift = cfg.sample_shift
62
+
63
+ if args.sample_guide_scale is None:
64
+ args.sample_guide_scale = cfg.sample_guide_scale
65
+
66
+ if args.frame_num is None:
67
+ args.frame_num = cfg.frame_num
68
+
69
+ args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
70
+ 0, sys.maxsize)
71
+ # Size check
72
+ assert args.size in SUPPORTED_SIZES[
73
+ args.
74
+ task], f"Unsupport size {args.size} for task {args.task}, supported sizes are: {', '.join(SUPPORTED_SIZES[args.task])}"
75
+
76
+
77
+ def _parse_args():
78
+ parser = argparse.ArgumentParser(
79
+ description="Generate a image or video from a text prompt or image using Wan"
80
+ )
81
+ parser.add_argument(
82
+ "--task",
83
+ type=str,
84
+ default="t2v-A14B",
85
+ choices=list(WAN_CONFIGS.keys()),
86
+ help="The task to run.")
87
+ parser.add_argument(
88
+ "--size",
89
+ type=str,
90
+ default="1280*720",
91
+ choices=list(SIZE_CONFIGS.keys()),
92
+ help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
93
+ )
94
+ parser.add_argument(
95
+ "--frame_num",
96
+ type=int,
97
+ default=None,
98
+ help="How many frames of video are generated. The number should be 4n+1"
99
+ )
100
+ parser.add_argument(
101
+ "--ckpt_dir",
102
+ type=str,
103
+ default=None,
104
+ help="The path to the checkpoint directory.")
105
+ parser.add_argument(
106
+ "--offload_model",
107
+ type=str2bool,
108
+ default=None,
109
+ help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
110
+ )
111
+ parser.add_argument(
112
+ "--ulysses_size",
113
+ type=int,
114
+ default=1,
115
+ help="The size of the ulysses parallelism in DiT.")
116
+ parser.add_argument(
117
+ "--t5_fsdp",
118
+ action="store_true",
119
+ default=False,
120
+ help="Whether to use FSDP for T5.")
121
+ parser.add_argument(
122
+ "--t5_cpu",
123
+ action="store_true",
124
+ default=False,
125
+ help="Whether to place T5 model on CPU.")
126
+ parser.add_argument(
127
+ "--dit_fsdp",
128
+ action="store_true",
129
+ default=False,
130
+ help="Whether to use FSDP for DiT.")
131
+ parser.add_argument(
132
+ "--save_file",
133
+ type=str,
134
+ default=None,
135
+ help="The file to save the generated video to.")
136
+ parser.add_argument(
137
+ "--prompt",
138
+ type=str,
139
+ default=None,
140
+ help="The prompt to generate the video from.")
141
+ parser.add_argument(
142
+ "--use_prompt_extend",
143
+ action="store_true",
144
+ default=False,
145
+ help="Whether to use prompt extend.")
146
+ parser.add_argument(
147
+ "--prompt_extend_method",
148
+ type=str,
149
+ default="local_qwen",
150
+ choices=["dashscope", "local_qwen"],
151
+ help="The prompt extend method to use.")
152
+ parser.add_argument(
153
+ "--prompt_extend_model",
154
+ type=str,
155
+ default=None,
156
+ help="The prompt extend model to use.")
157
+ parser.add_argument(
158
+ "--prompt_extend_target_lang",
159
+ type=str,
160
+ default="zh",
161
+ choices=["zh", "en"],
162
+ help="The target language of prompt extend.")
163
+ parser.add_argument(
164
+ "--base_seed",
165
+ type=int,
166
+ default=-1,
167
+ help="The seed to use for generating the video.")
168
+ parser.add_argument(
169
+ "--image",
170
+ type=str,
171
+ default=None,
172
+ help="The image to generate the video from.")
173
+ parser.add_argument(
174
+ "--sample_solver",
175
+ type=str,
176
+ default='unipc',
177
+ choices=['unipc', 'dpm++'],
178
+ help="The solver used to sample.")
179
+ parser.add_argument(
180
+ "--sample_steps", type=int, default=None, help="The sampling steps.")
181
+ parser.add_argument(
182
+ "--sample_shift",
183
+ type=float,
184
+ default=None,
185
+ help="Sampling shift factor for flow matching schedulers.")
186
+ parser.add_argument(
187
+ "--sample_guide_scale",
188
+ type=float,
189
+ default=None,
190
+ help="Classifier free guidance scale.")
191
+ parser.add_argument(
192
+ "--convert_model_dtype",
193
+ action="store_true",
194
+ default=False,
195
+ help="Whether to convert model paramerters dtype.")
196
+
197
+ args = parser.parse_args()
198
+
199
+ _validate_args(args)
200
+
201
+ return args
202
+
203
+
204
+ def _init_logging(rank):
205
+ # logging
206
+ if rank == 0:
207
+ # set format
208
+ logging.basicConfig(
209
+ level=logging.INFO,
210
+ format="[%(asctime)s] %(levelname)s: %(message)s",
211
+ handlers=[logging.StreamHandler(stream=sys.stdout)])
212
+ else:
213
+ logging.basicConfig(level=logging.ERROR)
214
+
215
+
216
+ def generate(args):
217
+ rank = int(os.getenv("RANK", 0))
218
+ world_size = int(os.getenv("WORLD_SIZE", 1))
219
+ local_rank = int(os.getenv("LOCAL_RANK", 0))
220
+ device = local_rank
221
+ _init_logging(rank)
222
+
223
+ if args.offload_model is None:
224
+ args.offload_model = False if world_size > 1 else True
225
+ logging.info(
226
+ f"offload_model is not specified, set to {args.offload_model}.")
227
+ if world_size > 1:
228
+ torch.cuda.set_device(local_rank)
229
+ dist.init_process_group(
230
+ backend="nccl",
231
+ init_method="env://",
232
+ rank=rank,
233
+ world_size=world_size)
234
+ else:
235
+ assert not (
236
+ args.t5_fsdp or args.dit_fsdp
237
+ ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
238
+ assert not (
239
+ args.ulysses_size > 1
240
+ ), f"sequence parallel are not supported in non-distributed environments."
241
+
242
+ if args.ulysses_size > 1:
243
+ assert args.ulysses_size == world_size, f"ulysses_size should be equal to the world size."
244
+ init_distributed_group()
245
+
246
+ if args.use_prompt_extend:
247
+ if args.prompt_extend_method == "dashscope":
248
+ prompt_expander = DashScopePromptExpander(
249
+ model_name=args.prompt_extend_model,
250
+ task=args.task,
251
+ is_vl=args.image is not None)
252
+ elif args.prompt_extend_method == "local_qwen":
253
+ prompt_expander = QwenPromptExpander(
254
+ model_name=args.prompt_extend_model,
255
+ task=args.task,
256
+ is_vl=args.image is not None,
257
+ device=rank)
258
+ else:
259
+ raise NotImplementedError(
260
+ f"Unsupport prompt_extend_method: {args.prompt_extend_method}")
261
+
262
+ cfg = WAN_CONFIGS[args.task]
263
+ if args.ulysses_size > 1:
264
+ assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."
265
+
266
+ logging.info(f"Generation job args: {args}")
267
+ logging.info(f"Generation model config: {cfg}")
268
+
269
+ if dist.is_initialized():
270
+ base_seed = [args.base_seed] if rank == 0 else [None]
271
+ dist.broadcast_object_list(base_seed, src=0)
272
+ args.base_seed = base_seed[0]
273
+
274
+ logging.info(f"Input prompt: {args.prompt}")
275
+ img = None
276
+ if args.image is not None:
277
+ img = Image.open(args.image).convert("RGB")
278
+ logging.info(f"Input image: {args.image}")
279
+
280
+ # prompt extend
281
+ if args.use_prompt_extend:
282
+ logging.info("Extending prompt ...")
283
+ if rank == 0:
284
+ prompt_output = prompt_expander(
285
+ args.prompt,
286
+ image=img,
287
+ tar_lang=args.prompt_extend_target_lang,
288
+ seed=args.base_seed)
289
+ if prompt_output.status == False:
290
+ logging.info(
291
+ f"Extending prompt failed: {prompt_output.message}")
292
+ logging.info("Falling back to original prompt.")
293
+ input_prompt = args.prompt
294
+ else:
295
+ input_prompt = prompt_output.prompt
296
+ input_prompt = [input_prompt]
297
+ else:
298
+ input_prompt = [None]
299
+ if dist.is_initialized():
300
+ dist.broadcast_object_list(input_prompt, src=0)
301
+ args.prompt = input_prompt[0]
302
+ logging.info(f"Extended prompt: {args.prompt}")
303
+
304
+ if "t2v" in args.task:
305
+ logging.info("Creating WanT2V pipeline.")
306
+ wan_t2v = wan.WanT2V(
307
+ config=cfg,
308
+ checkpoint_dir=args.ckpt_dir,
309
+ device_id=device,
310
+ rank=rank,
311
+ t5_fsdp=args.t5_fsdp,
312
+ dit_fsdp=args.dit_fsdp,
313
+ use_sp=(args.ulysses_size > 1),
314
+ t5_cpu=args.t5_cpu,
315
+ convert_model_dtype=args.convert_model_dtype,
316
+ )
317
+
318
+ logging.info(f"Generating video ...")
319
+ video = wan_t2v.generate(
320
+ args.prompt,
321
+ size=SIZE_CONFIGS[args.size],
322
+ frame_num=args.frame_num,
323
+ shift=args.sample_shift,
324
+ sample_solver=args.sample_solver,
325
+ sampling_steps=args.sample_steps,
326
+ guide_scale=args.sample_guide_scale,
327
+ seed=args.base_seed,
328
+ offload_model=args.offload_model)
329
+ elif "ti2v" in args.task:
330
+ logging.info("Creating WanTI2V pipeline.")
331
+ wan_ti2v = wan.WanTI2V(
332
+ config=cfg,
333
+ checkpoint_dir=args.ckpt_dir,
334
+ device_id=device,
335
+ rank=rank,
336
+ t5_fsdp=args.t5_fsdp,
337
+ dit_fsdp=args.dit_fsdp,
338
+ use_sp=(args.ulysses_size > 1),
339
+ t5_cpu=args.t5_cpu,
340
+ convert_model_dtype=args.convert_model_dtype,
341
+ )
342
+
343
+ logging.info(f"Generating video ...")
344
+ video = wan_ti2v.generate(
345
+ args.prompt,
346
+ img=img,
347
+ size=SIZE_CONFIGS[args.size],
348
+ max_area=MAX_AREA_CONFIGS[args.size],
349
+ frame_num=args.frame_num,
350
+ shift=args.sample_shift,
351
+ sample_solver=args.sample_solver,
352
+ sampling_steps=args.sample_steps,
353
+ guide_scale=args.sample_guide_scale,
354
+ seed=args.base_seed,
355
+ offload_model=args.offload_model)
356
+ else:
357
+ logging.info("Creating WanI2V pipeline.")
358
+ wan_i2v = wan.WanI2V(
359
+ config=cfg,
360
+ checkpoint_dir=args.ckpt_dir,
361
+ device_id=device,
362
+ rank=rank,
363
+ t5_fsdp=args.t5_fsdp,
364
+ dit_fsdp=args.dit_fsdp,
365
+ use_sp=(args.ulysses_size > 1),
366
+ t5_cpu=args.t5_cpu,
367
+ convert_model_dtype=args.convert_model_dtype,
368
+ )
369
+
370
+ logging.info("Generating video ...")
371
+ video = wan_i2v.generate(
372
+ args.prompt,
373
+ img,
374
+ max_area=MAX_AREA_CONFIGS[args.size],
375
+ frame_num=args.frame_num,
376
+ shift=args.sample_shift,
377
+ sample_solver=args.sample_solver,
378
+ sampling_steps=args.sample_steps,
379
+ guide_scale=args.sample_guide_scale,
380
+ seed=args.base_seed,
381
+ offload_model=args.offload_model)
382
+
383
+ if rank == 0:
384
+ if args.save_file is None:
385
+ formatted_time = datetime.now().strftime("%Y%m%d_%H%M%S")
386
+ formatted_prompt = args.prompt.replace(" ", "_").replace("/",
387
+ "_")[:50]
388
+ suffix = '.mp4'
389
+ args.save_file = f"{args.task}_{args.size.replace('*','x') if sys.platform=='win32' else args.size}_{args.ulysses_size}_{formatted_prompt}_{formatted_time}" + suffix
390
+
391
+ logging.info(f"Saving generated video to {args.save_file}")
392
+ cache_video(
393
+ tensor=video[None],
394
+ save_file=args.save_file,
395
+ fps=cfg.sample_fps,
396
+ nrow=1,
397
+ normalize=True,
398
+ value_range=(-1, 1))
399
+ del video
400
+
401
+ torch.cuda.synchronize()
402
+ if dist.is_initialized():
403
+ dist.barrier()
404
+ dist.destroy_process_group()
405
+
406
+ logging.info("Finished.")
407
+
408
+
409
+ if __name__ == "__main__":
410
+ args = _parse_args()
411
+ generate(args)
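For reference, here is a minimal single-GPU sketch of what generate.py does for the ti2v-5B task, using the same wan.WanTI2V API added in this commit. The checkpoint directory, prompt, and output filename below are placeholders, not part of the uploaded files.

```python
# Hedged sketch: mirrors the ti2v branch of generate.py on a single GPU.
# './Wan2.2-TI2V-5B' is a placeholder path to locally downloaded weights.
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS
from wan.utils.utils import cache_video

cfg = WAN_CONFIGS['ti2v-5B']
pipe = wan.WanTI2V(
    config=cfg,
    checkpoint_dir='./Wan2.2-TI2V-5B',  # placeholder checkpoint directory
    device_id=0,
    rank=0,
    t5_fsdp=False,
    dit_fsdp=False,
    use_sp=False,
    t5_cpu=False,
    convert_model_dtype=True,
)
video = pipe.generate(
    'Two anthropomorphic cats in comfy boxing gear fight on a spotlighted stage.',
    img=None,                               # pass a PIL image for image-to-video
    size=SIZE_CONFIGS['1280*704'],
    max_area=MAX_AREA_CONFIGS['1280*704'],
    frame_num=cfg.frame_num,
    shift=cfg.sample_shift,
    sample_solver='unipc',
    sampling_steps=cfg.sample_steps,
    guide_scale=cfg.sample_guide_scale,
    seed=42,
    offload_model=True,
)
cache_video(
    tensor=video[None],                     # add a batch dimension
    save_file='ti2v-5B_demo.mp4',           # placeholder output name
    fps=cfg.sample_fps,
    nrow=1,
    normalize=True,
    value_range=(-1, 1))
```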
gradio_ti2v.py ADDED
@@ -0,0 +1,154 @@
1
+ # wan2.2-main/gradio_ti2v.py
2
+ import gradio as gr
3
+ import torch
4
+ import os
5
+ import sys
6
+ from huggingface_hub import snapshot_download
7
+ from PIL import Image
8
+ import random
9
+ import numpy as np
10
+
11
+ # Add project root to sys.path to allow importing 'wan'
12
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ import wan
15
+ from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
16
+ from wan.utils.utils import cache_video
17
+
18
+ # --- 1. Global Setup and Model Loading ---
19
+
20
+ print("Starting Gradio App for Wan 2.2 TI2V-5B...")
21
+
22
+ # Download model snapshots from Hugging Face Hub
23
+ repo_id = "Wan-AI/Wan2.2-TI2V-5B"
24
+ print(f"Downloading/loading checkpoints for {repo_id}...")
25
+ ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
26
+ print(f"Using checkpoints from {ckpt_dir}")
27
+
28
+ # Load the model configuration
29
+ TASK_NAME = 'ti2v-5B'
30
+ cfg = WAN_CONFIGS[TASK_NAME]
31
+ FIXED_FPS = 24
32
+ MIN_FRAMES_MODEL = 8
33
+ MAX_FRAMES_MODEL = 121
34
+
35
+ # Instantiate the pipeline in the global scope
36
+ print("Initializing WanTI2V pipeline...")
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+ device_id = 0 if torch.cuda.is_available() else -1
39
+ pipeline = wan.WanTI2V(
40
+ config=cfg,
41
+ checkpoint_dir=ckpt_dir,
42
+ device_id=device_id,
43
+ rank=0,
44
+ t5_fsdp=False,
45
+ dit_fsdp=False,
46
+ use_sp=False,
47
+ t5_cpu=False,
48
+ init_on_cpu=True,
49
+ convert_model_dtype=True,
50
+ )
51
+ print("Pipeline initialized and ready.")
52
+
53
+
54
+ # --- 2. Gradio Inference Function ---
55
+ def generate_video(
56
+ image,
57
+ prompt,
58
+ size,
59
+ duration_seconds,
60
+ sampling_steps,
61
+ guide_scale,
62
+ shift,
63
+ seed,
64
+ progress=gr.Progress(track_tqdm=True)
65
+ ):
66
+ """The main function to generate video, called by the Gradio interface."""
67
+ if seed == -1:
68
+ seed = random.randint(0, sys.maxsize)
69
+
70
+ input_image = Image.fromarray(image).convert("RGB") if image is not None else None
71
+
72
+ # Calculate number of frames based on duration
73
+ num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
74
+
75
+ video_tensor = pipeline.generate(
76
+ input_prompt=prompt,
77
+ img=input_image, # Pass None for T2V, Image for I2V
78
+ size=SIZE_CONFIGS[size],
79
+ max_area=MAX_AREA_CONFIGS[size],
80
+ frame_num=num_frames, # Use calculated frames instead of cfg.frame_num
81
+ shift=shift,
82
+ sample_solver='unipc',
83
+ sampling_steps=int(sampling_steps),
84
+ guide_scale=guide_scale,
85
+ seed=seed,
86
+ offload_model=True
87
+ )
88
+
89
+ # Save the video to a temporary file
90
+ video_path = cache_video(
91
+ tensor=video_tensor[None], # Add a batch dimension
92
+ save_file=None, # cache_video will create a temp file
93
+ fps=cfg.sample_fps,
94
+ normalize=True,
95
+ value_range=(-1, 1)
96
+ )
97
+
98
+ return video_path
99
+
100
+
101
+ # --- 3. Gradio Interface ---
102
+ css = ".gradio-container {max-width: 1100px !important} #output_video {height: 500px;} #input_image {height: 500px;}"
103
+
104
+ with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
105
+ gr.Markdown("# Wan 2.2 Text/Image-to-Video Demo (ti2v-5B)")
106
+ gr.Markdown("Generate a video from a text prompt. Optionally, provide an initial image to guide the generation (Image-to-Video).")
107
+
108
+ with gr.Row():
109
+ with gr.Column(scale=2):
110
+ image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
111
+ prompt_input = gr.Textbox(label="Prompt", value="A beautiful waterfall in a lush jungle, cinematic.", lines=3)
112
+ duration_input = gr.Slider(
113
+ minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
114
+ maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
115
+ step=0.1,
116
+ value=2.0,
117
+ label="Duration (seconds)",
118
+ info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
119
+ )
120
+ size_input = gr.Dropdown(label="Output Resolution", choices=list(SUPPORTED_SIZES[TASK_NAME]), value="704*1280")
121
+ with gr.Column(scale=2):
122
+ video_output = gr.Video(label="Generated Video", elem_id="output_video")
123
+
124
+
125
+ with gr.Accordion("Advanced Settings", open=False):
126
+ steps_input = gr.Slider(label="Sampling Steps", minimum=10, maximum=70, value=35, step=1)
127
+ scale_input = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=10.0, value=cfg.sample_guide_scale, step=0.1)
128
+ shift_input = gr.Slider(label="Sample Shift", minimum=1.0, maximum=20.0, value=cfg.sample_shift, step=0.1)
129
+ seed_input = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
130
+
131
+ run_button = gr.Button("Generate Video", variant="primary")
132
+
133
+
134
+ example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
135
+ gr.Examples(
136
+ examples=[
137
+ [None, "A cinematic shot of a boat sailing on a calm sea at sunset.", "1280*704", 2.0],
138
+ [example_image_path, "The cat slowly blinks its eyes.", "704*1280", 1.5],
139
+ [None, "Drone footage flying over a futuristic city with flying cars.", "1280*704", 3.0],
140
+ ],
141
+ inputs=[image_input, prompt_input, size_input, duration_input],
142
+ outputs=video_output,
143
+ fn=generate_video,
144
+ cache_examples=False,
145
+ )
146
+
147
+ run_button.click(
148
+ fn=generate_video,
149
+ inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
150
+ outputs=video_output
151
+ )
152
+
153
+ if __name__ == "__main__":
154
+ demo.launch(share=True)
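As a side note, the duration slider above is converted to a frame count before calling the pipeline; a standalone sketch of that mapping, using the constants defined in this file, looks like this.

```python
# Sketch of the duration -> frame_num mapping used by gradio_ti2v.py:
# seconds are converted to frames at 24 fps and clamped to the model's
# supported range of 8..121 frames.
import numpy as np

FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121

def duration_to_frames(duration_seconds: float) -> int:
    return int(np.clip(int(round(duration_seconds * FIXED_FPS)),
                       MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

print(duration_to_frames(2.0))   # 48
print(duration_to_frames(10.0))  # 121 (about 5 seconds at 24 fps)
```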
requirements.txt ADDED
@@ -0,0 +1,15 @@
1
+ torch>=2.4.0
2
+ torchvision>=0.19.0
3
+ opencv-python>=4.9.0.80
4
+ diffusers>=0.31.0
5
+ transformers>=4.49.0
6
+ tokenizers>=0.20.3
7
+ accelerate>=1.1.1
8
+ tqdm
9
+ imageio
10
+ easydict
11
+ ftfy
12
+ dashscope
13
+ imageio-ffmpeg
14
+ flash_attn
15
+ numpy>=1.23.5,<2
wan/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from . import configs, distributed, modules
3
+ from .image2video import WanI2V
4
+ from .text2video import WanT2V
5
+ from .textimage2video import WanTI2V
wan/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (341 Bytes).
wan/__pycache__/image2video.cpython-310.pyc ADDED
Binary file (12.3 kB).
wan/__pycache__/text2video.cpython-310.pyc ADDED
Binary file (11.1 kB).
wan/__pycache__/textimage2video.cpython-310.pyc ADDED
Binary file (17.5 kB).
wan/configs/__init__.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import copy
3
+ import os
4
+
5
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
6
+
7
+ from .wan_i2v_A14B import i2v_A14B
8
+ from .wan_t2v_A14B import t2v_A14B
9
+ from .wan_ti2v_5B import ti2v_5B
10
+
11
+ WAN_CONFIGS = {
12
+ 't2v-A14B': t2v_A14B,
13
+ 'i2v-A14B': i2v_A14B,
14
+ 'ti2v-5B': ti2v_5B,
15
+ }
16
+
17
+ SIZE_CONFIGS = {
18
+ '720*1280': (720, 1280),
19
+ '1280*720': (1280, 720),
20
+ '480*832': (480, 832),
21
+ '832*480': (832, 480),
22
+ '704*1280': (704, 1280),
23
+ '1280*704': (1280, 704)
24
+ }
25
+
26
+ MAX_AREA_CONFIGS = {
27
+ '720*1280': 720 * 1280,
28
+ '1280*720': 1280 * 720,
29
+ '480*832': 480 * 832,
30
+ '832*480': 832 * 480,
31
+ '704*1280': 704 * 1280,
32
+ '1280*704': 1280 * 704,
33
+ }
34
+
35
+ SUPPORTED_SIZES = {
36
+ 't2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
37
+ 'i2v-A14B': ('720*1280', '1280*720', '480*832', '832*480'),
38
+ 'ti2v-5B': ('704*1280', '1280*704'),
39
+ }
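A short sketch of how these lookup tables are intended to be used together; it mirrors the task/size checks in generate.py's _validate_args, and the task and size values are just examples.

```python
# Sketch: validate a task/size pair and pull the matching config, mirroring
# the checks in generate.py.
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES

task, size = 'ti2v-5B', '1280*704'
assert task in WAN_CONFIGS, f"Unsupported task: {task}"
assert size in SUPPORTED_SIZES[task], (
    f"Unsupported size {size} for task {task}, "
    f"supported sizes are: {', '.join(SUPPORTED_SIZES[task])}")

cfg = WAN_CONFIGS[task]
print(cfg.sample_steps, cfg.sample_fps, SIZE_CONFIGS[size])  # 50 24 (1280, 704)
```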
wan/configs/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (745 Bytes).
wan/configs/__pycache__/shared_config.cpython-310.pyc ADDED
Binary file (856 Bytes).
wan/configs/__pycache__/wan_i2v_A14B.cpython-310.pyc ADDED
Binary file (976 Bytes).
wan/configs/__pycache__/wan_t2v_A14B.cpython-310.pyc ADDED
Binary file (963 Bytes).
wan/configs/__pycache__/wan_ti2v_5B.cpython-310.pyc ADDED
Binary file (871 Bytes).
wan/configs/shared_config.py ADDED
@@ -0,0 +1,20 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ #------------------------ Wan shared config ------------------------#
6
+ wan_shared_cfg = EasyDict()
7
+
8
+ # t5
9
+ wan_shared_cfg.t5_model = 'umt5_xxl'
10
+ wan_shared_cfg.t5_dtype = torch.bfloat16
11
+ wan_shared_cfg.text_len = 512
12
+
13
+ # transformer
14
+ wan_shared_cfg.param_dtype = torch.bfloat16
15
+
16
+ # inference
17
+ wan_shared_cfg.num_train_timesteps = 1000
18
+ wan_shared_cfg.sample_fps = 16
19
+ wan_shared_cfg.sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
20
+ wan_shared_cfg.frame_num = 81
wan/configs/wan_i2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ from easydict import EasyDict
4
+
5
+ from .shared_config import wan_shared_cfg
6
+
7
+ #------------------------ Wan I2V A14B ------------------------#
8
+
9
+ i2v_A14B = EasyDict(__name__='Config: Wan I2V A14B')
10
+ i2v_A14B.update(wan_shared_cfg)
11
+
12
+ i2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ i2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ i2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ i2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ i2v_A14B.patch_size = (1, 2, 2)
21
+ i2v_A14B.dim = 5120
22
+ i2v_A14B.ffn_dim = 13824
23
+ i2v_A14B.freq_dim = 256
24
+ i2v_A14B.num_heads = 40
25
+ i2v_A14B.num_layers = 40
26
+ i2v_A14B.window_size = (-1, -1)
27
+ i2v_A14B.qk_norm = True
28
+ i2v_A14B.cross_attn_norm = True
29
+ i2v_A14B.eps = 1e-6
30
+ i2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ i2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ i2v_A14B.sample_shift = 5.0
35
+ i2v_A14B.sample_steps = 40
36
+ i2v_A14B.boundary = 0.900
37
+ i2v_A14B.sample_guide_scale = (3.5, 3.5) # low noise, high noise
wan/configs/wan_t2v_A14B.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan T2V A14B ------------------------#
7
+
8
+ t2v_A14B = EasyDict(__name__='Config: Wan T2V A14B')
9
+ t2v_A14B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ t2v_A14B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ t2v_A14B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ t2v_A14B.vae_checkpoint = 'Wan2.1_VAE.pth'
17
+ t2v_A14B.vae_stride = (4, 8, 8)
18
+
19
+ # transformer
20
+ t2v_A14B.patch_size = (1, 2, 2)
21
+ t2v_A14B.dim = 5120
22
+ t2v_A14B.ffn_dim = 13824
23
+ t2v_A14B.freq_dim = 256
24
+ t2v_A14B.num_heads = 40
25
+ t2v_A14B.num_layers = 40
26
+ t2v_A14B.window_size = (-1, -1)
27
+ t2v_A14B.qk_norm = True
28
+ t2v_A14B.cross_attn_norm = True
29
+ t2v_A14B.eps = 1e-6
30
+ t2v_A14B.low_noise_checkpoint = 'low_noise_model'
31
+ t2v_A14B.high_noise_checkpoint = 'high_noise_model'
32
+
33
+ # inference
34
+ t2v_A14B.sample_shift = 12.0
35
+ t2v_A14B.sample_steps = 40
36
+ t2v_A14B.boundary = 0.875
37
+ t2v_A14B.sample_guide_scale = (3.0, 4.0) # low noise, high noise
wan/configs/wan_ti2v_5B.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from easydict import EasyDict
3
+
4
+ from .shared_config import wan_shared_cfg
5
+
6
+ #------------------------ Wan TI2V 5B ------------------------#
7
+
8
+ ti2v_5B = EasyDict(__name__='Config: Wan TI2V 5B')
9
+ ti2v_5B.update(wan_shared_cfg)
10
+
11
+ # t5
12
+ ti2v_5B.t5_checkpoint = 'models_t5_umt5-xxl-enc-bf16.pth'
13
+ ti2v_5B.t5_tokenizer = 'google/umt5-xxl'
14
+
15
+ # vae
16
+ ti2v_5B.vae_checkpoint = 'Wan2.2_VAE.pth'
17
+ ti2v_5B.vae_stride = (4, 16, 16)
18
+
19
+ # transformer
20
+ ti2v_5B.patch_size = (1, 2, 2)
21
+ ti2v_5B.dim = 3072
22
+ ti2v_5B.ffn_dim = 14336
23
+ ti2v_5B.freq_dim = 256
24
+ ti2v_5B.num_heads = 24
25
+ ti2v_5B.num_layers = 30
26
+ ti2v_5B.window_size = (-1, -1)
27
+ ti2v_5B.qk_norm = True
28
+ ti2v_5B.cross_attn_norm = True
29
+ ti2v_5B.eps = 1e-6
30
+
31
+ # inference
32
+ ti2v_5B.sample_fps = 24
33
+ ti2v_5B.sample_shift = 5.0
34
+ ti2v_5B.sample_steps = 50
35
+ ti2v_5B.sample_guide_scale = 5.0
36
+ ti2v_5B.frame_num = 121
wan/distributed/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
wan/distributed/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (151 Bytes).
wan/distributed/__pycache__/fsdp.cpython-310.pyc ADDED
Binary file (1.37 kB).
wan/distributed/__pycache__/sequence_parallel.cpython-310.pyc ADDED
Binary file (5.25 kB).
wan/distributed/__pycache__/ulysses.cpython-310.pyc ADDED
Binary file (1.24 kB).
wan/distributed/__pycache__/util.cpython-310.pyc ADDED
Binary file (1.94 kB).
wan/distributed/fsdp.py ADDED
@@ -0,0 +1,43 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ from functools import partial
4
+
5
+ import torch
6
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
7
+ from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
8
+ from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
9
+ from torch.distributed.utils import _free_storage
10
+
11
+
12
+ def shard_model(
13
+ model,
14
+ device_id,
15
+ param_dtype=torch.bfloat16,
16
+ reduce_dtype=torch.float32,
17
+ buffer_dtype=torch.float32,
18
+ process_group=None,
19
+ sharding_strategy=ShardingStrategy.FULL_SHARD,
20
+ sync_module_states=True,
21
+ ):
22
+ model = FSDP(
23
+ module=model,
24
+ process_group=process_group,
25
+ sharding_strategy=sharding_strategy,
26
+ auto_wrap_policy=partial(
27
+ lambda_auto_wrap_policy, lambda_fn=lambda m: m in model.blocks),
28
+ mixed_precision=MixedPrecision(
29
+ param_dtype=param_dtype,
30
+ reduce_dtype=reduce_dtype,
31
+ buffer_dtype=buffer_dtype),
32
+ device_id=device_id,
33
+ sync_module_states=sync_module_states)
34
+ return model
35
+
36
+
37
+ def free_model(model):
38
+ for m in model.modules():
39
+ if isinstance(m, FSDP):
40
+ _free_storage(m._handle.flat_param.data)
41
+ del model
42
+ gc.collect()
43
+ torch.cuda.empty_cache()
wan/distributed/sequence_parallel.py ADDED
@@ -0,0 +1,176 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.cuda.amp as amp
4
+
5
+ from ..modules.model import sinusoidal_embedding_1d
6
+ from .ulysses import distributed_attention
7
+ from .util import gather_forward, get_rank, get_world_size
8
+
9
+
10
+ def pad_freqs(original_tensor, target_len):
11
+ seq_len, s1, s2 = original_tensor.shape
12
+ pad_size = target_len - seq_len
13
+ padding_tensor = torch.ones(
14
+ pad_size,
15
+ s1,
16
+ s2,
17
+ dtype=original_tensor.dtype,
18
+ device=original_tensor.device)
19
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
+ return padded_tensor
21
+
22
+
23
+ @torch.amp.autocast('cuda', enabled=False)
24
+ def rope_apply(x, grid_sizes, freqs):
25
+ """
26
+ x: [B, L, N, C].
27
+ grid_sizes: [B, 3].
28
+ freqs: [M, C // 2].
29
+ """
30
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
+ # split freqs
32
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
+
34
+ # loop over samples
35
+ output = []
36
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
+ seq_len = f * h * w
38
+
39
+ # precompute multipliers
40
+ x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
41
+ s, n, -1, 2))
42
+ freqs_i = torch.cat([
43
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
+ ],
47
+ dim=-1).reshape(seq_len, 1, -1)
48
+
49
+ # apply rotary embedding
50
+ sp_size = get_world_size()
51
+ sp_rank = get_rank()
52
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
53
+ s_per_rank = s
54
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
+ s_per_rank), :, :]
56
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
+ x_i = torch.cat([x_i, x[i, s:]])
58
+
59
+ # append to collection
60
+ output.append(x_i)
61
+ return torch.stack(output).float()
62
+
63
+
64
+ def sp_dit_forward(
65
+ self,
66
+ x,
67
+ t,
68
+ context,
69
+ seq_len,
70
+ y=None,
71
+ ):
72
+ """
73
+ x: A list of videos each with shape [C, T, H, W].
74
+ t: [B].
75
+ context: A list of text embeddings each with shape [L, C].
76
+ """
77
+ if self.model_type == 'i2v':
78
+ assert y is not None
79
+ # params
80
+ device = self.patch_embedding.weight.device
81
+ if self.freqs.device != device:
82
+ self.freqs = self.freqs.to(device)
83
+
84
+ if y is not None:
85
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
86
+
87
+ # embeddings
88
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
89
+ grid_sizes = torch.stack(
90
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
91
+ x = [u.flatten(2).transpose(1, 2) for u in x]
92
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
93
+ assert seq_lens.max() <= seq_len
94
+ x = torch.cat([
95
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
96
+ for u in x
97
+ ])
98
+
99
+ # time embeddings
100
+ if t.dim() == 1:
101
+ t = t.expand(t.size(0), seq_len)
102
+ with torch.amp.autocast('cuda', dtype=torch.float32):
103
+ bt = t.size(0)
104
+ t = t.flatten()
105
+ e = self.time_embedding(
106
+ sinusoidal_embedding_1d(self.freq_dim,
107
+ t).unflatten(0, (bt, seq_len)).float())
108
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
109
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
110
+
111
+ # context
112
+ context_lens = None
113
+ context = self.text_embedding(
114
+ torch.stack([
115
+ torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
116
+ for u in context
117
+ ]))
118
+
119
+ # Context Parallel
120
+ x = torch.chunk(x, get_world_size(), dim=1)[get_rank()]
121
+ e = torch.chunk(e, get_world_size(), dim=1)[get_rank()]
122
+ e0 = torch.chunk(e0, get_world_size(), dim=1)[get_rank()]
123
+
124
+ # arguments
125
+ kwargs = dict(
126
+ e=e0,
127
+ seq_lens=seq_lens,
128
+ grid_sizes=grid_sizes,
129
+ freqs=self.freqs,
130
+ context=context,
131
+ context_lens=context_lens)
132
+
133
+ for block in self.blocks:
134
+ x = block(x, **kwargs)
135
+
136
+ # head
137
+ x = self.head(x, e)
138
+
139
+ # Context Parallel
140
+ x = gather_forward(x, dim=1)
141
+
142
+ # unpatchify
143
+ x = self.unpatchify(x, grid_sizes)
144
+ return [u.float() for u in x]
145
+
146
+
147
+ def sp_attn_forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16):
148
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
149
+ half_dtypes = (torch.float16, torch.bfloat16)
150
+
151
+ def half(x):
152
+ return x if x.dtype in half_dtypes else x.to(dtype)
153
+
154
+ # query, key, value function
155
+ def qkv_fn(x):
156
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
157
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
158
+ v = self.v(x).view(b, s, n, d)
159
+ return q, k, v
160
+
161
+ q, k, v = qkv_fn(x)
162
+ q = rope_apply(q, grid_sizes, freqs)
163
+ k = rope_apply(k, grid_sizes, freqs)
164
+
165
+ x = distributed_attention(
166
+ half(q),
167
+ half(k),
168
+ half(v),
169
+ seq_lens,
170
+ window_size=self.window_size,
171
+ )
172
+
173
+ # output
174
+ x = x.flatten(2)
175
+ x = self.o(x)
176
+ return x
wan/distributed/ulysses.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+ from ..modules.attention import flash_attention
6
+ from .util import all_to_all
7
+
8
+
9
+ def distributed_attention(
10
+ q,
11
+ k,
12
+ v,
13
+ seq_lens,
14
+ window_size=(-1, -1),
15
+ ):
16
+ """
17
+ Performs distributed attention based on DeepSpeed Ulysses attention mechanism.
18
+ please refer to https://arxiv.org/pdf/2309.14509
19
+
20
+ Args:
21
+ q: [B, Lq // p, Nq, C1].
22
+ k: [B, Lk // p, Nk, C1].
23
+ v: [B, Lk // p, Nk, C2]. Nq must be divisible by Nk.
24
+ seq_lens: [B], length of each sequence in batch
25
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
26
+ """
27
+ if not dist.is_initialized():
28
+ raise ValueError("distributed group should be initialized.")
29
+ b = q.shape[0]
30
+
31
+ # gather q/k/v sequence
32
+ q = all_to_all(q, scatter_dim=2, gather_dim=1)
33
+ k = all_to_all(k, scatter_dim=2, gather_dim=1)
34
+ v = all_to_all(v, scatter_dim=2, gather_dim=1)
35
+
36
+ # apply attention
37
+ x = flash_attention(
38
+ q,
39
+ k,
40
+ v,
41
+ k_lens=seq_lens,
42
+ window_size=window_size,
43
+ )
44
+
45
+ # scatter q/k/v sequence
46
+ x = all_to_all(x, scatter_dim=1, gather_dim=2)
47
+ return x
wan/distributed/util.py ADDED
@@ -0,0 +1,51 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+ import torch.distributed as dist
4
+
5
+
6
+ def init_distributed_group():
7
+ """r initialize sequence parallel group.
8
+ """
9
+ if not dist.is_initialized():
10
+ dist.init_process_group(backend='nccl')
11
+
12
+
13
+ def get_rank():
14
+ return dist.get_rank()
15
+
16
+
17
+ def get_world_size():
18
+ return dist.get_world_size()
19
+
20
+
21
+ def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
22
+ """
23
+ `scatter` along one dimension and `gather` along another.
24
+ """
25
+ world_size = get_world_size()
26
+ if world_size > 1:
27
+ inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
28
+ outputs = [torch.empty_like(u) for u in inputs]
29
+ dist.all_to_all(outputs, inputs, group=group, **kwargs)
30
+ x = torch.cat(outputs, dim=gather_dim).contiguous()
31
+ return x
32
+
33
+
34
+ def all_gather(tensor):
35
+ world_size = dist.get_world_size()
36
+ if world_size == 1:
37
+ return [tensor]
38
+ tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
39
+ torch.distributed.all_gather(tensor_list, tensor)
40
+ return tensor_list
41
+
42
+
43
+ def gather_forward(input, dim):
44
+ # skip if world_size == 1
45
+ world_size = dist.get_world_size()
46
+ if world_size == 1:
47
+ return input
48
+
49
+ # gather sequence
50
+ output = all_gather(input)
51
+ return torch.cat(output, dim=dim).contiguous()
wan/image2video.py ADDED
@@ -0,0 +1,431 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch.cuda.amp as amp
15
+ import torch.distributed as dist
16
+ import torchvision.transforms.functional as TF
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_1 import Wan2_1_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+
32
+
33
+ class WanI2V:
34
+
35
+ def __init__(
36
+ self,
37
+ config,
38
+ checkpoint_dir,
39
+ device_id=0,
40
+ rank=0,
41
+ t5_fsdp=False,
42
+ dit_fsdp=False,
43
+ use_sp=False,
44
+ t5_cpu=False,
45
+ init_on_cpu=True,
46
+ convert_model_dtype=False,
47
+ ):
48
+ r"""
49
+ Initializes the image-to-video generation model components.
50
+
51
+ Args:
52
+ config (EasyDict):
53
+ Object containing model parameters initialized from config.py
54
+ checkpoint_dir (`str`):
55
+ Path to directory containing model checkpoints
56
+ device_id (`int`, *optional*, defaults to 0):
57
+ Id of target GPU device
58
+ rank (`int`, *optional*, defaults to 0):
59
+ Process rank for distributed training
60
+ t5_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for T5 model
62
+ dit_fsdp (`bool`, *optional*, defaults to False):
63
+ Enable FSDP sharding for DiT model
64
+ use_sp (`bool`, *optional*, defaults to False):
65
+ Enable distribution strategy of sequence parallel.
66
+ t5_cpu (`bool`, *optional*, defaults to False):
67
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
68
+ init_on_cpu (`bool`, *optional*, defaults to True):
69
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
70
+ convert_model_dtype (`bool`, *optional*, defaults to False):
71
+ Convert DiT model parameters dtype to 'config.param_dtype'.
72
+ Only works without FSDP.
73
+ """
74
+ self.device = torch.device(f"cuda:{device_id}")
75
+ self.config = config
76
+ self.rank = rank
77
+ self.t5_cpu = t5_cpu
78
+ self.init_on_cpu = init_on_cpu
79
+
80
+ self.num_train_timesteps = config.num_train_timesteps
81
+ self.boundary = config.boundary
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None,
95
+ )
96
+
97
+ self.vae_stride = config.vae_stride
98
+ self.patch_size = config.patch_size
99
+ self.vae = Wan2_1_VAE(
100
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
101
+ device=self.device)
102
+
103
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
104
+ self.low_noise_model = WanModel.from_pretrained(
105
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
106
+ self.low_noise_model = self._configure_model(
107
+ model=self.low_noise_model,
108
+ use_sp=use_sp,
109
+ dit_fsdp=dit_fsdp,
110
+ shard_fn=shard_fn,
111
+ convert_model_dtype=convert_model_dtype)
112
+
113
+ self.high_noise_model = WanModel.from_pretrained(
114
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
115
+ self.high_noise_model = self._configure_model(
116
+ model=self.high_noise_model,
117
+ use_sp=use_sp,
118
+ dit_fsdp=dit_fsdp,
119
+ shard_fn=shard_fn,
120
+ convert_model_dtype=convert_model_dtype)
121
+ if use_sp:
122
+ self.sp_size = get_world_size()
123
+ else:
124
+ self.sp_size = 1
125
+
126
+ self.sample_neg_prompt = config.sample_neg_prompt
127
+
128
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
129
+ convert_model_dtype):
130
+ """
131
+ Configures a model object. This includes setting evaluation modes,
132
+ applying distributed parallel strategy, and handling device placement.
133
+
134
+ Args:
135
+ model (torch.nn.Module):
136
+ The model instance to configure.
137
+ use_sp (`bool`):
138
+ Enable distribution strategy of sequence parallel.
139
+ dit_fsdp (`bool`):
140
+ Enable FSDP sharding for DiT model.
141
+ shard_fn (callable):
142
+ The function to apply FSDP sharding.
143
+ convert_model_dtype (`bool`):
144
+ Convert DiT model parameters dtype to 'config.param_dtype'.
145
+ Only works without FSDP.
146
+
147
+ Returns:
148
+ torch.nn.Module:
149
+ The configured model.
150
+ """
151
+ model.eval().requires_grad_(False)
152
+
153
+ if use_sp:
154
+ for block in model.blocks:
155
+ block.self_attn.forward = types.MethodType(
156
+ sp_attn_forward, block.self_attn)
157
+ model.forward = types.MethodType(sp_dit_forward, model)
158
+
159
+ if dist.is_initialized():
160
+ dist.barrier()
161
+
162
+ if dit_fsdp:
163
+ model = shard_fn(model)
164
+ else:
165
+ if convert_model_dtype:
166
+ model.to(self.param_dtype)
167
+ if not self.init_on_cpu:
168
+ model.to(self.device)
169
+
170
+ return model
171
+
172
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
173
+ r"""
174
+ Prepares and returns the required model for the current timestep.
175
+
176
+ Args:
177
+ t (torch.Tensor):
178
+ current timestep.
179
+ boundary (`int`):
180
+ The timestep threshold. If `t` is at or above this value,
181
+ the `high_noise_model` is considered as the required model.
182
+ offload_model (`bool`):
183
+ A flag intended to control the offloading behavior.
184
+
185
+ Returns:
186
+ torch.nn.Module:
187
+ The active model on the target device for the current timestep.
188
+ """
189
+ if t.item() >= boundary:
190
+ required_model_name = 'high_noise_model'
191
+ offload_model_name = 'low_noise_model'
192
+ else:
193
+ required_model_name = 'low_noise_model'
194
+ offload_model_name = 'high_noise_model'
195
+ if offload_model or self.init_on_cpu:
196
+ if next(getattr(
197
+ self,
198
+ offload_model_name).parameters()).device.type == 'cuda':
199
+ getattr(self, offload_model_name).to('cpu')
200
+ if next(getattr(
201
+ self,
202
+ required_model_name).parameters()).device.type == 'cpu':
203
+ getattr(self, required_model_name).to(self.device)
204
+ return getattr(self, required_model_name)
205
+
206
+ def generate(self,
207
+ input_prompt,
208
+ img,
209
+ max_area=720 * 1280,
210
+ frame_num=81,
211
+ shift=5.0,
212
+ sample_solver='unipc',
213
+ sampling_steps=40,
214
+ guide_scale=5.0,
215
+ n_prompt="",
216
+ seed=-1,
217
+ offload_model=True):
218
+ r"""
219
+ Generates video frames from input image and text prompt using diffusion process.
220
+
221
+ Args:
222
+ input_prompt (`str`):
223
+ Text prompt for content generation.
224
+ img (PIL.Image.Image):
225
+ Input image tensor. Shape: [3, H, W]
226
+ max_area (`int`, *optional*, defaults to 720*1280):
227
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
228
+ frame_num (`int`, *optional*, defaults to 81):
229
+ How many frames to sample from a video. The number should be 4n+1
230
+ shift (`float`, *optional*, defaults to 5.0):
231
+ Noise schedule shift parameter. Affects temporal dynamics
232
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
233
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
234
+ Solver used to sample the video.
235
+ sampling_steps (`int`, *optional*, defaults to 40):
236
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
237
+ guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
238
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
239
+ If tuple, the first guide_scale will be used for low noise model and
240
+ the second guide_scale will be used for high noise model.
241
+ n_prompt (`str`, *optional*, defaults to ""):
242
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
243
+ seed (`int`, *optional*, defaults to -1):
244
+ Random seed for noise generation. If -1, use random seed
245
+ offload_model (`bool`, *optional*, defaults to True):
246
+ If True, offloads models to CPU during generation to save VRAM
247
+
248
+ Returns:
249
+ torch.Tensor:
250
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
251
+ - C: Color channels (3 for RGB)
252
+ - N: Number of frames (81)
253
+ - H: Frame height (from max_area)
254
+ - W: Frame width (from max_area)
255
+ """
256
+ # preprocess
257
+ guide_scale = (guide_scale, guide_scale) if isinstance(
258
+ guide_scale, float) else guide_scale
259
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device)
260
+
261
+ F = frame_num
262
+ h, w = img.shape[1:]
263
+ aspect_ratio = h / w
264
+ lat_h = round(
265
+ np.sqrt(max_area * aspect_ratio) // self.vae_stride[1] //
266
+ self.patch_size[1] * self.patch_size[1])
267
+ lat_w = round(
268
+ np.sqrt(max_area / aspect_ratio) // self.vae_stride[2] //
269
+ self.patch_size[2] * self.patch_size[2])
270
+ h = lat_h * self.vae_stride[1]
271
+ w = lat_w * self.vae_stride[2]
272
+
273
+ max_seq_len = ((F - 1) // self.vae_stride[0] + 1) * lat_h * lat_w // (
274
+ self.patch_size[1] * self.patch_size[2])
275
+ max_seq_len = int(math.ceil(max_seq_len / self.sp_size)) * self.sp_size
276
+
277
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
278
+ seed_g = torch.Generator(device=self.device)
279
+ seed_g.manual_seed(seed)
280
+ noise = torch.randn(
281
+ 16,
282
+ 21,
283
+ lat_h,
284
+ lat_w,
285
+ dtype=torch.float32,
286
+ generator=seed_g,
287
+ device=self.device)
288
+
289
+ msk = torch.ones(1, 81, lat_h, lat_w, device=self.device)
290
+ msk[:, 1:] = 0
291
+ msk = torch.concat([
292
+ torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]
293
+ ],
294
+ dim=1)
295
+ msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
296
+ msk = msk.transpose(1, 2)[0]
297
+
298
+ if n_prompt == "":
299
+ n_prompt = self.sample_neg_prompt
300
+
301
+ # preprocess
302
+ if not self.t5_cpu:
303
+ self.text_encoder.model.to(self.device)
304
+ context = self.text_encoder([input_prompt], self.device)
305
+ context_null = self.text_encoder([n_prompt], self.device)
306
+ if offload_model:
307
+ self.text_encoder.model.cpu()
308
+ else:
309
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
310
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
311
+ context = [t.to(self.device) for t in context]
312
+ context_null = [t.to(self.device) for t in context_null]
313
+
314
+ y = self.vae.encode([
315
+ torch.concat([
316
+ torch.nn.functional.interpolate(
317
+ img[None].cpu(), size=(h, w), mode='bicubic').transpose(
318
+ 0, 1),
319
+ torch.zeros(3, 80, h, w)
320
+ ],
321
+ dim=1).to(self.device)
322
+ ])[0]
323
+ y = torch.concat([msk, y])
324
+
325
+ @contextmanager
326
+ def noop_no_sync():
327
+ yield
328
+
329
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
330
+ noop_no_sync)
331
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
332
+ noop_no_sync)
333
+
334
+ # evaluation mode
335
+ with (
336
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
337
+ torch.no_grad(),
338
+ no_sync_low_noise(),
339
+ no_sync_high_noise(),
340
+ ):
341
+ boundary = self.boundary * self.num_train_timesteps
342
+
343
+ if sample_solver == 'unipc':
344
+ sample_scheduler = FlowUniPCMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sample_scheduler.set_timesteps(
349
+ sampling_steps, device=self.device, shift=shift)
350
+ timesteps = sample_scheduler.timesteps
351
+ elif sample_solver == 'dpm++':
352
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
353
+ num_train_timesteps=self.num_train_timesteps,
354
+ shift=1,
355
+ use_dynamic_shifting=False)
356
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
357
+ timesteps, _ = retrieve_timesteps(
358
+ sample_scheduler,
359
+ device=self.device,
360
+ sigmas=sampling_sigmas)
361
+ else:
362
+ raise NotImplementedError("Unsupported solver.")
363
+
364
+ # sample videos
365
+ latent = noise
366
+
367
+ arg_c = {
368
+ 'context': [context[0]],
369
+ 'seq_len': max_seq_len,
370
+ 'y': [y],
371
+ }
372
+
373
+ arg_null = {
374
+ 'context': context_null,
375
+ 'seq_len': max_seq_len,
376
+ 'y': [y],
377
+ }
378
+
379
+ if offload_model:
380
+ torch.cuda.empty_cache()
381
+
382
+ for _, t in enumerate(tqdm(timesteps)):
383
+ latent_model_input = [latent.to(self.device)]
384
+ timestep = [t]
385
+
386
+ timestep = torch.stack(timestep).to(self.device)
387
+
388
+ model = self._prepare_model_for_timestep(
389
+ t, boundary, offload_model)
390
+ sample_guide_scale = guide_scale[1] if t.item(
391
+ ) >= boundary else guide_scale[0]
392
+
393
+ noise_pred_cond = model(
394
+ latent_model_input, t=timestep, **arg_c)[0]
395
+ if offload_model:
396
+ torch.cuda.empty_cache()
397
+ noise_pred_uncond = model(
398
+ latent_model_input, t=timestep, **arg_null)[0]
399
+ if offload_model:
400
+ torch.cuda.empty_cache()
401
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
402
+ noise_pred_cond - noise_pred_uncond)
403
+
404
+ temp_x0 = sample_scheduler.step(
405
+ noise_pred.unsqueeze(0),
406
+ t,
407
+ latent.unsqueeze(0),
408
+ return_dict=False,
409
+ generator=seed_g)[0]
410
+ latent = temp_x0.squeeze(0)
411
+
412
+ x0 = [latent]
413
+ del latent_model_input, timestep
414
+
415
+ if offload_model:
416
+ self.low_noise_model.cpu()
417
+ self.high_noise_model.cpu()
418
+ torch.cuda.empty_cache()
419
+
420
+ if self.rank == 0:
421
+ videos = self.vae.decode(x0)
422
+
423
+ del noise, latent, x0
424
+ del sample_scheduler
425
+ if offload_model:
426
+ gc.collect()
427
+ torch.cuda.synchronize()
428
+ if dist.is_initialized():
429
+ dist.barrier()
430
+
431
+ return videos[0] if self.rank == 0 else None
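A worked example of the latent-size arithmetic in WanI2V.generate above, assuming a 1280x720 input image, max_area=720*1280, frame_num=81, a single GPU (sp_size=1), and the i2v-A14B strides (vae_stride=(4, 8, 8), patch_size=(1, 2, 2)); the numbers in the comments follow directly from the expressions in the method.

```python
# Worked example of the lat_h/lat_w/max_seq_len computation from WanI2V.generate.
import math
import numpy as np

max_area, frame_num = 720 * 1280, 81
vae_stride, patch_size = (4, 8, 8), (1, 2, 2)
h, w = 720, 1280                  # input image height, width
aspect_ratio = h / w              # 0.5625

lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] //
              patch_size[1] * patch_size[1])          # 90
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] //
              patch_size[2] * patch_size[2])          # 160
h, w = lat_h * vae_stride[1], lat_w * vae_stride[2]   # 720, 1280

sp_size = 1                       # > 1 only with sequence parallelism
max_seq_len = ((frame_num - 1) // vae_stride[0] + 1) * lat_h * lat_w // (
    patch_size[1] * patch_size[2])
max_seq_len = int(math.ceil(max_seq_len / sp_size)) * sp_size
print(lat_h, lat_w, h, w, max_seq_len)   # 90 160 720 1280 75600
```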
wan/modules/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .attention import flash_attention
3
+ from .model import WanModel
4
+ from .t5 import T5Decoder, T5Encoder, T5EncoderModel, T5Model
5
+ from .tokenizers import HuggingfaceTokenizer
6
+ from .vae2_1 import Wan2_1_VAE
7
+ from .vae2_2 import Wan2_2_VAE
8
+
9
+ __all__ = [
10
+ 'Wan2_1_VAE',
11
+ 'Wan2_2_VAE',
12
+ 'WanModel',
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ 'HuggingfaceTokenizer',
18
+ 'flash_attention',
19
+ ]
wan/modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (536 Bytes). View file
 
wan/modules/__pycache__/attention.cpython-310.pyc ADDED
Binary file (3.96 kB). View file
 
wan/modules/__pycache__/model.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
wan/modules/__pycache__/t5.cpython-310.pyc ADDED
Binary file (12.9 kB). View file
 
wan/modules/__pycache__/tokenizers.cpython-310.pyc ADDED
Binary file (2.56 kB). View file
 
wan/modules/__pycache__/vae2_1.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
wan/modules/__pycache__/vae2_2.cpython-310.pyc ADDED
Binary file (22.1 kB). View file
 
wan/modules/attention.py ADDED
@@ -0,0 +1,179 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn_interface
6
+ FLASH_ATTN_3_AVAILABLE = True
7
+ except ModuleNotFoundError:
8
+ FLASH_ATTN_3_AVAILABLE = False
9
+
10
+ try:
11
+ import flash_attn
12
+ FLASH_ATTN_2_AVAILABLE = True
13
+ except ModuleNotFoundError:
14
+ FLASH_ATTN_2_AVAILABLE = False
15
+
16
+ import warnings
17
+
18
+ __all__ = [
19
+ 'flash_attention',
20
+ 'attention',
21
+ ]
22
+
23
+
24
+ def flash_attention(
25
+ q,
26
+ k,
27
+ v,
28
+ q_lens=None,
29
+ k_lens=None,
30
+ dropout_p=0.,
31
+ softmax_scale=None,
32
+ q_scale=None,
33
+ causal=False,
34
+ window_size=(-1, -1),
35
+ deterministic=False,
36
+ dtype=torch.bfloat16,
37
+ version=None,
38
+ ):
39
+ """
40
+ q: [B, Lq, Nq, C1].
41
+ k: [B, Lk, Nk, C1].
42
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
43
+ q_lens: [B].
44
+ k_lens: [B].
45
+ dropout_p: float. Dropout probability.
46
+ softmax_scale: float. The scaling of QK^T before applying softmax.
47
+ causal: bool. Whether to apply causal attention mask.
48
+ window_size: (left, right). If not (-1, -1), apply sliding window local attention.
49
+ deterministic: bool. If True, slightly slower and uses more memory.
50
+ dtype: torch.dtype. The dtype that q/k/v are cast to when they are not float16/bfloat16.
51
+ """
52
+ half_dtypes = (torch.float16, torch.bfloat16)
53
+ assert dtype in half_dtypes
54
+ assert q.device.type == 'cuda' and q.size(-1) <= 256
55
+
56
+ # params
57
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
58
+
59
+ def half(x):
60
+ return x if x.dtype in half_dtypes else x.to(dtype)
61
+
62
+ # preprocess query
63
+ if q_lens is None:
64
+ q = half(q.flatten(0, 1))
65
+ q_lens = torch.tensor(
66
+ [lq] * b, dtype=torch.int32).to(
67
+ device=q.device, non_blocking=True)
68
+ else:
69
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
70
+
71
+ # preprocess key, value
72
+ if k_lens is None:
73
+ k = half(k.flatten(0, 1))
74
+ v = half(v.flatten(0, 1))
75
+ k_lens = torch.tensor(
76
+ [lk] * b, dtype=torch.int32).to(
77
+ device=k.device, non_blocking=True)
78
+ else:
79
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
80
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
81
+
82
+ q = q.to(v.dtype)
83
+ k = k.to(v.dtype)
84
+
85
+ if q_scale is not None:
86
+ q = q * q_scale
87
+
88
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
89
+ warnings.warn(
90
+ 'Flash attention 3 is not available, use flash attention 2 instead.'
91
+ )
92
+
93
+ # apply attention
94
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
95
+ # Note: dropout_p, window_size are not supported in FA3 now.
96
+ x = flash_attn_interface.flash_attn_varlen_func(
97
+ q=q,
98
+ k=k,
99
+ v=v,
100
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
101
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
102
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
103
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
104
+ seqused_q=None,
105
+ seqused_k=None,
106
+ max_seqlen_q=lq,
107
+ max_seqlen_k=lk,
108
+ softmax_scale=softmax_scale,
109
+ causal=causal,
110
+ deterministic=deterministic)[0].unflatten(0, (b, lq))
111
+ else:
112
+ assert FLASH_ATTN_2_AVAILABLE
113
+ x = flash_attn.flash_attn_varlen_func(
114
+ q=q,
115
+ k=k,
116
+ v=v,
117
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
118
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
119
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
120
+ 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
+ max_seqlen_q=lq,
122
+ max_seqlen_k=lk,
123
+ dropout_p=dropout_p,
124
+ softmax_scale=softmax_scale,
125
+ causal=causal,
126
+ window_size=window_size,
127
+ deterministic=deterministic).unflatten(0, (b, lq))
128
+
129
+ # output
130
+ return x.type(out_dtype)
131
+
132
+
133
+ def attention(
134
+ q,
135
+ k,
136
+ v,
137
+ q_lens=None,
138
+ k_lens=None,
139
+ dropout_p=0.,
140
+ softmax_scale=None,
141
+ q_scale=None,
142
+ causal=False,
143
+ window_size=(-1, -1),
144
+ deterministic=False,
145
+ dtype=torch.bfloat16,
146
+ fa_version=None,
147
+ ):
148
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
149
+ return flash_attention(
150
+ q=q,
151
+ k=k,
152
+ v=v,
153
+ q_lens=q_lens,
154
+ k_lens=k_lens,
155
+ dropout_p=dropout_p,
156
+ softmax_scale=softmax_scale,
157
+ q_scale=q_scale,
158
+ causal=causal,
159
+ window_size=window_size,
160
+ deterministic=deterministic,
161
+ dtype=dtype,
162
+ version=fa_version,
163
+ )
164
+ else:
165
+ if q_lens is not None or k_lens is not None:
166
+ warnings.warn(
167
+ 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
168
+ )
169
+ attn_mask = None
170
+
171
+ q = q.transpose(1, 2).to(dtype)
172
+ k = k.transpose(1, 2).to(dtype)
173
+ v = v.transpose(1, 2).to(dtype)
174
+
175
+ out = torch.nn.functional.scaled_dot_product_attention(
176
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
177
+
178
+ out = out.transpose(1, 2).contiguous()
179
+ return out
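A minimal usage sketch of the wrapper above (not part of the upload; assumes the repository and its requirements are installed). On a machine without flash-attn, `attention` falls back to `torch.nn.functional.scaled_dot_product_attention` and any `q_lens`/`k_lens` padding masks are ignored:

import torch
from wan.modules.attention import attention  # assumes the repo is importable

b, l, heads, head_dim = 2, 16, 8, 64
q = torch.randn(b, l, heads, head_dim)   # [B, Lq, Nq, C]
k = torch.randn(b, l, heads, head_dim)   # [B, Lk, Nk, C]
v = torch.randn(b, l, heads, head_dim)
out = attention(q, k, v, dtype=torch.float32)  # SDPA fallback, runs on CPU in float32
print(out.shape)                               # torch.Size([2, 16, 8, 64])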
wan/modules/model.py ADDED
@@ -0,0 +1,546 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.modeling_utils import ModelMixin
8
+
9
+ from .attention import flash_attention
10
+
11
+ __all__ = ['WanModel']
12
+
13
+
14
+ def sinusoidal_embedding_1d(dim, position):
15
+ # preprocess
16
+ assert dim % 2 == 0
17
+ half = dim // 2
18
+ position = position.type(torch.float64)
19
+
20
+ # calculation
21
+ sinusoid = torch.outer(
22
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
23
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
24
+ return x
25
+
26
+
27
+ @torch.amp.autocast('cuda', enabled=False)
28
+ def rope_params(max_seq_len, dim, theta=10000):
29
+ assert dim % 2 == 0
30
+ freqs = torch.outer(
31
+ torch.arange(max_seq_len),
32
+ 1.0 / torch.pow(theta,
33
+ torch.arange(0, dim, 2).to(torch.float64).div(dim)))
34
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
35
+ return freqs
36
+
37
+
38
+ @torch.amp.autocast('cuda', enabled=False)
39
+ def rope_apply(x, grid_sizes, freqs):
40
+ n, c = x.size(2), x.size(3) // 2
41
+
42
+ # split freqs
43
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
44
+
45
+ # loop over samples
46
+ output = []
47
+ for i, (f, h, w) in enumerate(grid_sizes.tolist()):
48
+ seq_len = f * h * w
49
+
50
+ # precompute multipliers
51
+ x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(
52
+ seq_len, n, -1, 2))
53
+ freqs_i = torch.cat([
54
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
55
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
56
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
57
+ ],
58
+ dim=-1).reshape(seq_len, 1, -1)
59
+
60
+ # apply rotary embedding
61
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
62
+ x_i = torch.cat([x_i, x[i, seq_len:]])
63
+
64
+ # append to collection
65
+ output.append(x_i)
66
+ return torch.stack(output).float()
67
+
68
+
69
+ class WanRMSNorm(nn.Module):
70
+
71
+ def __init__(self, dim, eps=1e-5):
72
+ super().__init__()
73
+ self.dim = dim
74
+ self.eps = eps
75
+ self.weight = nn.Parameter(torch.ones(dim))
76
+
77
+ def forward(self, x):
78
+ r"""
79
+ Args:
80
+ x(Tensor): Shape [B, L, C]
81
+ """
82
+ return self._norm(x.float()).type_as(x) * self.weight
83
+
84
+ def _norm(self, x):
85
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
86
+
87
+
88
+ class WanLayerNorm(nn.LayerNorm):
89
+
90
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
91
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
92
+
93
+ def forward(self, x):
94
+ r"""
95
+ Args:
96
+ x(Tensor): Shape [B, L, C]
97
+ """
98
+ return super().forward(x.float()).type_as(x)
99
+
100
+
101
+ class WanSelfAttention(nn.Module):
102
+
103
+ def __init__(self,
104
+ dim,
105
+ num_heads,
106
+ window_size=(-1, -1),
107
+ qk_norm=True,
108
+ eps=1e-6):
109
+ assert dim % num_heads == 0
110
+ super().__init__()
111
+ self.dim = dim
112
+ self.num_heads = num_heads
113
+ self.head_dim = dim // num_heads
114
+ self.window_size = window_size
115
+ self.qk_norm = qk_norm
116
+ self.eps = eps
117
+
118
+ # layers
119
+ self.q = nn.Linear(dim, dim)
120
+ self.k = nn.Linear(dim, dim)
121
+ self.v = nn.Linear(dim, dim)
122
+ self.o = nn.Linear(dim, dim)
123
+ self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
124
+ self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
+
126
+ def forward(self, x, seq_lens, grid_sizes, freqs):
127
+ r"""
128
+ Args:
129
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
130
+ seq_lens(Tensor): Shape [B]
131
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
132
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
133
+ """
134
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
135
+
136
+ # query, key, value function
137
+ def qkv_fn(x):
138
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
139
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
140
+ v = self.v(x).view(b, s, n, d)
141
+ return q, k, v
142
+
143
+ q, k, v = qkv_fn(x)
144
+
145
+ x = flash_attention(
146
+ q=rope_apply(q, grid_sizes, freqs),
147
+ k=rope_apply(k, grid_sizes, freqs),
148
+ v=v,
149
+ k_lens=seq_lens,
150
+ window_size=self.window_size)
151
+
152
+ # output
153
+ x = x.flatten(2)
154
+ x = self.o(x)
155
+ return x
156
+
157
+
158
+ class WanCrossAttention(WanSelfAttention):
159
+
160
+ def forward(self, x, context, context_lens):
161
+ r"""
162
+ Args:
163
+ x(Tensor): Shape [B, L1, C]
164
+ context(Tensor): Shape [B, L2, C]
165
+ context_lens(Tensor): Shape [B]
166
+ """
167
+ b, n, d = x.size(0), self.num_heads, self.head_dim
168
+
169
+ # compute query, key, value
170
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
171
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
172
+ v = self.v(context).view(b, -1, n, d)
173
+
174
+ # compute attention
175
+ x = flash_attention(q, k, v, k_lens=context_lens)
176
+
177
+ # output
178
+ x = x.flatten(2)
179
+ x = self.o(x)
180
+ return x
181
+
182
+
183
+ class WanAttentionBlock(nn.Module):
184
+
185
+ def __init__(self,
186
+ dim,
187
+ ffn_dim,
188
+ num_heads,
189
+ window_size=(-1, -1),
190
+ qk_norm=True,
191
+ cross_attn_norm=False,
192
+ eps=1e-6):
193
+ super().__init__()
194
+ self.dim = dim
195
+ self.ffn_dim = ffn_dim
196
+ self.num_heads = num_heads
197
+ self.window_size = window_size
198
+ self.qk_norm = qk_norm
199
+ self.cross_attn_norm = cross_attn_norm
200
+ self.eps = eps
201
+
202
+ # layers
203
+ self.norm1 = WanLayerNorm(dim, eps)
204
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
205
+ eps)
206
+ self.norm3 = WanLayerNorm(
207
+ dim, eps,
208
+ elementwise_affine=True) if cross_attn_norm else nn.Identity()
209
+ self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm,
210
+ eps)
211
+ self.norm2 = WanLayerNorm(dim, eps)
212
+ self.ffn = nn.Sequential(
213
+ nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
214
+ nn.Linear(ffn_dim, dim))
215
+
216
+ # modulation
217
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
218
+
219
+ def forward(
220
+ self,
221
+ x,
222
+ e,
223
+ seq_lens,
224
+ grid_sizes,
225
+ freqs,
226
+ context,
227
+ context_lens,
228
+ ):
229
+ r"""
230
+ Args:
231
+ x(Tensor): Shape [B, L, C]
232
+ e(Tensor): Shape [B, L1, 6, C]
233
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
234
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
235
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
236
+ """
237
+ assert e.dtype == torch.float32
238
+ with torch.amp.autocast('cuda', dtype=torch.float32):
239
+ e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
240
+ assert e[0].dtype == torch.float32
241
+
242
+ # self-attention
243
+ y = self.self_attn(
244
+ self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2),
245
+ seq_lens, grid_sizes, freqs)
246
+ with torch.amp.autocast('cuda', dtype=torch.float32):
247
+ x = x + y * e[2].squeeze(2)
248
+
249
+ # cross-attention & ffn function
250
+ def cross_attn_ffn(x, context, context_lens, e):
251
+ x = x + self.cross_attn(self.norm3(x), context, context_lens)
252
+ y = self.ffn(
253
+ self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2))
254
+ with torch.amp.autocast('cuda', dtype=torch.float32):
255
+ x = x + y * e[5].squeeze(2)
256
+ return x
257
+
258
+ x = cross_attn_ffn(x, context, context_lens, e)
259
+ return x
260
+
261
+
262
+ class Head(nn.Module):
263
+
264
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
265
+ super().__init__()
266
+ self.dim = dim
267
+ self.out_dim = out_dim
268
+ self.patch_size = patch_size
269
+ self.eps = eps
270
+
271
+ # layers
272
+ out_dim = math.prod(patch_size) * out_dim
273
+ self.norm = WanLayerNorm(dim, eps)
274
+ self.head = nn.Linear(dim, out_dim)
275
+
276
+ # modulation
277
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
278
+
279
+ def forward(self, x, e):
280
+ r"""
281
+ Args:
282
+ x(Tensor): Shape [B, L1, C]
283
+ e(Tensor): Shape [B, L1, C]
284
+ """
285
+ assert e.dtype == torch.float32
286
+ with torch.amp.autocast('cuda', dtype=torch.float32):
287
+ e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
288
+ x = (
289
+ self.head(
290
+ self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)))
291
+ return x
292
+
293
+
294
+ class WanModel(ModelMixin, ConfigMixin):
295
+ r"""
296
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
297
+ """
298
+
299
+ ignore_for_config = [
300
+ 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
301
+ ]
302
+ _no_split_modules = ['WanAttentionBlock']
303
+
304
+ @register_to_config
305
+ def __init__(self,
306
+ model_type='t2v',
307
+ patch_size=(1, 2, 2),
308
+ text_len=512,
309
+ in_dim=16,
310
+ dim=2048,
311
+ ffn_dim=8192,
312
+ freq_dim=256,
313
+ text_dim=4096,
314
+ out_dim=16,
315
+ num_heads=16,
316
+ num_layers=32,
317
+ window_size=(-1, -1),
318
+ qk_norm=True,
319
+ cross_attn_norm=True,
320
+ eps=1e-6):
321
+ r"""
322
+ Initialize the diffusion model backbone.
323
+
324
+ Args:
325
+ model_type (`str`, *optional*, defaults to 't2v'):
326
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
327
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
328
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
329
+ text_len (`int`, *optional*, defaults to 512):
330
+ Fixed length for text embeddings
331
+ in_dim (`int`, *optional*, defaults to 16):
332
+ Input video channels (C_in)
333
+ dim (`int`, *optional*, defaults to 2048):
334
+ Hidden dimension of the transformer
335
+ ffn_dim (`int`, *optional*, defaults to 8192):
336
+ Intermediate dimension in feed-forward network
337
+ freq_dim (`int`, *optional*, defaults to 256):
338
+ Dimension for sinusoidal time embeddings
339
+ text_dim (`int`, *optional*, defaults to 4096):
340
+ Input dimension for text embeddings
341
+ out_dim (`int`, *optional*, defaults to 16):
342
+ Output video channels (C_out)
343
+ num_heads (`int`, *optional*, defaults to 16):
344
+ Number of attention heads
345
+ num_layers (`int`, *optional*, defaults to 32):
346
+ Number of transformer blocks
347
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
348
+ Window size for local attention (-1 indicates global attention)
349
+ qk_norm (`bool`, *optional*, defaults to True):
350
+ Enable query/key normalization
351
+ cross_attn_norm (`bool`, *optional*, defaults to True):
352
+ Enable cross-attention normalization
353
+ eps (`float`, *optional*, defaults to 1e-6):
354
+ Epsilon value for normalization layers
355
+ """
356
+
357
+ super().__init__()
358
+
359
+ assert model_type in ['t2v', 'i2v', 'ti2v']
360
+ self.model_type = model_type
361
+
362
+ self.patch_size = patch_size
363
+ self.text_len = text_len
364
+ self.in_dim = in_dim
365
+ self.dim = dim
366
+ self.ffn_dim = ffn_dim
367
+ self.freq_dim = freq_dim
368
+ self.text_dim = text_dim
369
+ self.out_dim = out_dim
370
+ self.num_heads = num_heads
371
+ self.num_layers = num_layers
372
+ self.window_size = window_size
373
+ self.qk_norm = qk_norm
374
+ self.cross_attn_norm = cross_attn_norm
375
+ self.eps = eps
376
+
377
+ # embeddings
378
+ self.patch_embedding = nn.Conv3d(
379
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
380
+ self.text_embedding = nn.Sequential(
381
+ nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
382
+ nn.Linear(dim, dim))
383
+
384
+ self.time_embedding = nn.Sequential(
385
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
386
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
387
+
388
+ # blocks
389
+ self.blocks = nn.ModuleList([
390
+ WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm,
391
+ cross_attn_norm, eps) for _ in range(num_layers)
392
+ ])
393
+
394
+ # head
395
+ self.head = Head(dim, out_dim, patch_size, eps)
396
+
397
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
398
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
399
+ d = dim // num_heads
400
+ self.freqs = torch.cat([
401
+ rope_params(1024, d - 4 * (d // 6)),
402
+ rope_params(1024, 2 * (d // 6)),
403
+ rope_params(1024, 2 * (d // 6))
404
+ ],
405
+ dim=1)
406
+
407
+ # initialize weights
408
+ self.init_weights()
409
+
410
+ def forward(
411
+ self,
412
+ x,
413
+ t,
414
+ context,
415
+ seq_len,
416
+ y=None,
417
+ ):
418
+ r"""
419
+ Forward pass through the diffusion model
420
+
421
+ Args:
422
+ x (List[Tensor]):
423
+ List of input video tensors, each with shape [C_in, F, H, W]
424
+ t (Tensor):
425
+ Diffusion timesteps tensor of shape [B]
426
+ context (List[Tensor]):
427
+ List of text embeddings each with shape [L, C]
428
+ seq_len (`int`):
429
+ Maximum sequence length for positional encoding
430
+ y (List[Tensor], *optional*):
431
+ Conditional video inputs for image-to-video mode, same shape as x
432
+
433
+ Returns:
434
+ List[Tensor]:
435
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
436
+ """
437
+ if self.model_type == 'i2v':
438
+ assert y is not None
439
+ # params
440
+ device = self.patch_embedding.weight.device
441
+ if self.freqs.device != device:
442
+ self.freqs = self.freqs.to(device)
443
+
444
+ if y is not None:
445
+ x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
446
+
447
+ # embeddings
448
+ x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
449
+ grid_sizes = torch.stack(
450
+ [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
451
+ x = [u.flatten(2).transpose(1, 2) for u in x]
452
+ seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
453
+ assert seq_lens.max() <= seq_len
454
+ x = torch.cat([
455
+ torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
456
+ dim=1) for u in x
457
+ ])
458
+
459
+ # time embeddings
460
+ if t.dim() == 1:
461
+ t = t.expand(t.size(0), seq_len)
462
+ with torch.amp.autocast('cuda', dtype=torch.float32):
463
+ bt = t.size(0)
464
+ t = t.flatten()
465
+ e = self.time_embedding(
466
+ sinusoidal_embedding_1d(self.freq_dim,
467
+ t).unflatten(0, (bt, seq_len)).float())
468
+ e0 = self.time_projection(e).unflatten(2, (6, self.dim))
469
+ assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
+
471
+ # context
472
+ context_lens = None
473
+ context = self.text_embedding(
474
+ torch.stack([
475
+ torch.cat(
476
+ [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
477
+ for u in context
478
+ ]))
479
+
480
+ # arguments
481
+ kwargs = dict(
482
+ e=e0,
483
+ seq_lens=seq_lens,
484
+ grid_sizes=grid_sizes,
485
+ freqs=self.freqs,
486
+ context=context,
487
+ context_lens=context_lens)
488
+
489
+ for block in self.blocks:
490
+ x = block(x, **kwargs)
491
+
492
+ # head
493
+ x = self.head(x, e)
494
+
495
+ # unpatchify
496
+ x = self.unpatchify(x, grid_sizes)
497
+ return [u.float() for u in x]
498
+
499
+ def unpatchify(self, x, grid_sizes):
500
+ r"""
501
+ Reconstruct video tensors from patch embeddings.
502
+
503
+ Args:
504
+ x (List[Tensor]):
505
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
506
+ grid_sizes (Tensor):
507
+ Original spatial-temporal grid dimensions before patching,
508
+ shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
509
+
510
+ Returns:
511
+ List[Tensor]:
512
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
513
+ """
514
+
515
+ c = self.out_dim
516
+ out = []
517
+ for u, v in zip(x, grid_sizes.tolist()):
518
+ u = u[:math.prod(v)].view(*v, *self.patch_size, c)
519
+ u = torch.einsum('fhwpqrc->cfphqwr', u)
520
+ u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
521
+ out.append(u)
522
+ return out
523
+
524
+ def init_weights(self):
525
+ r"""
526
+ Initialize model parameters using Xavier initialization.
527
+ """
528
+
529
+ # basic init
530
+ for m in self.modules():
531
+ if isinstance(m, nn.Linear):
532
+ nn.init.xavier_uniform_(m.weight)
533
+ if m.bias is not None:
534
+ nn.init.zeros_(m.bias)
535
+
536
+ # init embeddings
537
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
538
+ for m in self.text_embedding.modules():
539
+ if isinstance(m, nn.Linear):
540
+ nn.init.normal_(m.weight, std=.02)
541
+ for m in self.time_embedding.modules():
542
+ if isinstance(m, nn.Linear):
543
+ nn.init.normal_(m.weight, std=.02)
544
+
545
+ # init output layer
546
+ nn.init.zeros_(self.head.head.weight)
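For orientation, the per-head RoPE table built in `__init__` splits the head dimension d = dim // num_heads into a temporal part (d - 4 * (d // 6)) and two spatial parts (2 * (d // 6) each), which matches the split `rope_apply` performs. A small shape check (a sketch, not part of the upload; assumes the repository and its requirements such as diffusers are importable):

import torch
from wan.modules.model import rope_params, sinusoidal_embedding_1d

d = 128                                    # an example even head dimension
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),   # temporal axis
    rope_params(1024, 2 * (d // 6)),       # height axis
    rope_params(1024, 2 * (d // 6)),       # width axis
], dim=1)
print(freqs.shape)                         # torch.Size([1024, 64]): d // 2 complex freqs per position

t_emb = sinusoidal_embedding_1d(256, torch.arange(8))
print(t_emb.shape)                         # torch.Size([8, 256]): input to time_embedding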
wan/modules/t5.py ADDED
@@ -0,0 +1,513 @@
1
+ # Modified from transformers.models.t5.modeling_t5
2
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
+ import logging
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from .tokenizers import HuggingfaceTokenizer
11
+
12
+ __all__ = [
13
+ 'T5Model',
14
+ 'T5Encoder',
15
+ 'T5Decoder',
16
+ 'T5EncoderModel',
17
+ ]
18
+
19
+
20
+ def fp16_clamp(x):
21
+ if x.dtype == torch.float16 and torch.isinf(x).any():
22
+ clamp = torch.finfo(x.dtype).max - 1000
23
+ x = torch.clamp(x, min=-clamp, max=clamp)
24
+ return x
25
+
26
+
27
+ def init_weights(m):
28
+ if isinstance(m, T5LayerNorm):
29
+ nn.init.ones_(m.weight)
30
+ elif isinstance(m, T5Model):
31
+ nn.init.normal_(m.token_embedding.weight, std=1.0)
32
+ elif isinstance(m, T5FeedForward):
33
+ nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
34
+ nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
35
+ nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
36
+ elif isinstance(m, T5Attention):
37
+ nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
38
+ nn.init.normal_(m.k.weight, std=m.dim**-0.5)
39
+ nn.init.normal_(m.v.weight, std=m.dim**-0.5)
40
+ nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
41
+ elif isinstance(m, T5RelativeEmbedding):
42
+ nn.init.normal_(
43
+ m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
44
+
45
+
46
+ class GELU(nn.Module):
47
+
48
+ def forward(self, x):
49
+ return 0.5 * x * (1.0 + torch.tanh(
50
+ math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
51
+
52
+
53
+ class T5LayerNorm(nn.Module):
54
+
55
+ def __init__(self, dim, eps=1e-6):
56
+ super(T5LayerNorm, self).__init__()
57
+ self.dim = dim
58
+ self.eps = eps
59
+ self.weight = nn.Parameter(torch.ones(dim))
60
+
61
+ def forward(self, x):
62
+ x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
63
+ self.eps)
64
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
65
+ x = x.type_as(self.weight)
66
+ return self.weight * x
67
+
68
+
69
+ class T5Attention(nn.Module):
70
+
71
+ def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
72
+ assert dim_attn % num_heads == 0
73
+ super(T5Attention, self).__init__()
74
+ self.dim = dim
75
+ self.dim_attn = dim_attn
76
+ self.num_heads = num_heads
77
+ self.head_dim = dim_attn // num_heads
78
+
79
+ # layers
80
+ self.q = nn.Linear(dim, dim_attn, bias=False)
81
+ self.k = nn.Linear(dim, dim_attn, bias=False)
82
+ self.v = nn.Linear(dim, dim_attn, bias=False)
83
+ self.o = nn.Linear(dim_attn, dim, bias=False)
84
+ self.dropout = nn.Dropout(dropout)
85
+
86
+ def forward(self, x, context=None, mask=None, pos_bias=None):
87
+ """
88
+ x: [B, L1, C].
89
+ context: [B, L2, C] or None.
90
+ mask: [B, L2] or [B, L1, L2] or None.
91
+ """
92
+ # check inputs
93
+ context = x if context is None else context
94
+ b, n, c = x.size(0), self.num_heads, self.head_dim
95
+
96
+ # compute query, key, value
97
+ q = self.q(x).view(b, -1, n, c)
98
+ k = self.k(context).view(b, -1, n, c)
99
+ v = self.v(context).view(b, -1, n, c)
100
+
101
+ # attention bias
102
+ attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
103
+ if pos_bias is not None:
104
+ attn_bias += pos_bias
105
+ if mask is not None:
106
+ assert mask.ndim in [2, 3]
107
+ mask = mask.view(b, 1, 1,
108
+ -1) if mask.ndim == 2 else mask.unsqueeze(1)
109
+ attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
110
+
111
+ # compute attention (T5 does not use scaling)
112
+ attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
113
+ attn = F.softmax(attn.float(), dim=-1).type_as(attn)
114
+ x = torch.einsum('bnij,bjnc->binc', attn, v)
115
+
116
+ # output
117
+ x = x.reshape(b, -1, n * c)
118
+ x = self.o(x)
119
+ x = self.dropout(x)
120
+ return x
121
+
122
+
123
+ class T5FeedForward(nn.Module):
124
+
125
+ def __init__(self, dim, dim_ffn, dropout=0.1):
126
+ super(T5FeedForward, self).__init__()
127
+ self.dim = dim
128
+ self.dim_ffn = dim_ffn
129
+
130
+ # layers
131
+ self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
132
+ self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
133
+ self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
134
+ self.dropout = nn.Dropout(dropout)
135
+
136
+ def forward(self, x):
137
+ x = self.fc1(x) * self.gate(x)
138
+ x = self.dropout(x)
139
+ x = self.fc2(x)
140
+ x = self.dropout(x)
141
+ return x
142
+
143
+
144
+ class T5SelfAttention(nn.Module):
145
+
146
+ def __init__(self,
147
+ dim,
148
+ dim_attn,
149
+ dim_ffn,
150
+ num_heads,
151
+ num_buckets,
152
+ shared_pos=True,
153
+ dropout=0.1):
154
+ super(T5SelfAttention, self).__init__()
155
+ self.dim = dim
156
+ self.dim_attn = dim_attn
157
+ self.dim_ffn = dim_ffn
158
+ self.num_heads = num_heads
159
+ self.num_buckets = num_buckets
160
+ self.shared_pos = shared_pos
161
+
162
+ # layers
163
+ self.norm1 = T5LayerNorm(dim)
164
+ self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
165
+ self.norm2 = T5LayerNorm(dim)
166
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
167
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
168
+ num_buckets, num_heads, bidirectional=True)
169
+
170
+ def forward(self, x, mask=None, pos_bias=None):
171
+ e = pos_bias if self.shared_pos else self.pos_embedding(
172
+ x.size(1), x.size(1))
173
+ x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
174
+ x = fp16_clamp(x + self.ffn(self.norm2(x)))
175
+ return x
176
+
177
+
178
+ class T5CrossAttention(nn.Module):
179
+
180
+ def __init__(self,
181
+ dim,
182
+ dim_attn,
183
+ dim_ffn,
184
+ num_heads,
185
+ num_buckets,
186
+ shared_pos=True,
187
+ dropout=0.1):
188
+ super(T5CrossAttention, self).__init__()
189
+ self.dim = dim
190
+ self.dim_attn = dim_attn
191
+ self.dim_ffn = dim_ffn
192
+ self.num_heads = num_heads
193
+ self.num_buckets = num_buckets
194
+ self.shared_pos = shared_pos
195
+
196
+ # layers
197
+ self.norm1 = T5LayerNorm(dim)
198
+ self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
199
+ self.norm2 = T5LayerNorm(dim)
200
+ self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
201
+ self.norm3 = T5LayerNorm(dim)
202
+ self.ffn = T5FeedForward(dim, dim_ffn, dropout)
203
+ self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
204
+ num_buckets, num_heads, bidirectional=False)
205
+
206
+ def forward(self,
207
+ x,
208
+ mask=None,
209
+ encoder_states=None,
210
+ encoder_mask=None,
211
+ pos_bias=None):
212
+ e = pos_bias if self.shared_pos else self.pos_embedding(
213
+ x.size(1), x.size(1))
214
+ x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
215
+ x = fp16_clamp(x + self.cross_attn(
216
+ self.norm2(x), context=encoder_states, mask=encoder_mask))
217
+ x = fp16_clamp(x + self.ffn(self.norm3(x)))
218
+ return x
219
+
220
+
221
+ class T5RelativeEmbedding(nn.Module):
222
+
223
+ def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
224
+ super(T5RelativeEmbedding, self).__init__()
225
+ self.num_buckets = num_buckets
226
+ self.num_heads = num_heads
227
+ self.bidirectional = bidirectional
228
+ self.max_dist = max_dist
229
+
230
+ # layers
231
+ self.embedding = nn.Embedding(num_buckets, num_heads)
232
+
233
+ def forward(self, lq, lk):
234
+ device = self.embedding.weight.device
235
+ # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
236
+ # torch.arange(lq).unsqueeze(1).to(device)
237
+ rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
238
+ torch.arange(lq, device=device).unsqueeze(1)
239
+ rel_pos = self._relative_position_bucket(rel_pos)
240
+ rel_pos_embeds = self.embedding(rel_pos)
241
+ rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
242
+ 0) # [1, N, Lq, Lk]
243
+ return rel_pos_embeds.contiguous()
244
+
245
+ def _relative_position_bucket(self, rel_pos):
246
+ # preprocess
247
+ if self.bidirectional:
248
+ num_buckets = self.num_buckets // 2
249
+ rel_buckets = (rel_pos > 0).long() * num_buckets
250
+ rel_pos = torch.abs(rel_pos)
251
+ else:
252
+ num_buckets = self.num_buckets
253
+ rel_buckets = 0
254
+ rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
255
+
256
+ # embeddings for small and large positions
257
+ max_exact = num_buckets // 2
258
+ rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
259
+ math.log(self.max_dist / max_exact) *
260
+ (num_buckets - max_exact)).long()
261
+ rel_pos_large = torch.min(
262
+ rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
263
+ rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
264
+ return rel_buckets
265
+
266
+
267
+ class T5Encoder(nn.Module):
268
+
269
+ def __init__(self,
270
+ vocab,
271
+ dim,
272
+ dim_attn,
273
+ dim_ffn,
274
+ num_heads,
275
+ num_layers,
276
+ num_buckets,
277
+ shared_pos=True,
278
+ dropout=0.1):
279
+ super(T5Encoder, self).__init__()
280
+ self.dim = dim
281
+ self.dim_attn = dim_attn
282
+ self.dim_ffn = dim_ffn
283
+ self.num_heads = num_heads
284
+ self.num_layers = num_layers
285
+ self.num_buckets = num_buckets
286
+ self.shared_pos = shared_pos
287
+
288
+ # layers
289
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
290
+ else nn.Embedding(vocab, dim)
291
+ self.pos_embedding = T5RelativeEmbedding(
292
+ num_buckets, num_heads, bidirectional=True) if shared_pos else None
293
+ self.dropout = nn.Dropout(dropout)
294
+ self.blocks = nn.ModuleList([
295
+ T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
296
+ shared_pos, dropout) for _ in range(num_layers)
297
+ ])
298
+ self.norm = T5LayerNorm(dim)
299
+
300
+ # initialize weights
301
+ self.apply(init_weights)
302
+
303
+ def forward(self, ids, mask=None):
304
+ x = self.token_embedding(ids)
305
+ x = self.dropout(x)
306
+ e = self.pos_embedding(x.size(1),
307
+ x.size(1)) if self.shared_pos else None
308
+ for block in self.blocks:
309
+ x = block(x, mask, pos_bias=e)
310
+ x = self.norm(x)
311
+ x = self.dropout(x)
312
+ return x
313
+
314
+
315
+ class T5Decoder(nn.Module):
316
+
317
+ def __init__(self,
318
+ vocab,
319
+ dim,
320
+ dim_attn,
321
+ dim_ffn,
322
+ num_heads,
323
+ num_layers,
324
+ num_buckets,
325
+ shared_pos=True,
326
+ dropout=0.1):
327
+ super(T5Decoder, self).__init__()
328
+ self.dim = dim
329
+ self.dim_attn = dim_attn
330
+ self.dim_ffn = dim_ffn
331
+ self.num_heads = num_heads
332
+ self.num_layers = num_layers
333
+ self.num_buckets = num_buckets
334
+ self.shared_pos = shared_pos
335
+
336
+ # layers
337
+ self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
338
+ else nn.Embedding(vocab, dim)
339
+ self.pos_embedding = T5RelativeEmbedding(
340
+ num_buckets, num_heads, bidirectional=False) if shared_pos else None
341
+ self.dropout = nn.Dropout(dropout)
342
+ self.blocks = nn.ModuleList([
343
+ T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
344
+ shared_pos, dropout) for _ in range(num_layers)
345
+ ])
346
+ self.norm = T5LayerNorm(dim)
347
+
348
+ # initialize weights
349
+ self.apply(init_weights)
350
+
351
+ def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None):
352
+ b, s = ids.size()
353
+
354
+ # causal mask
355
+ if mask is None:
356
+ mask = torch.tril(torch.ones(1, s, s).to(ids.device))
357
+ elif mask.ndim == 2:
358
+ mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1))
359
+
360
+ # layers
361
+ x = self.token_embedding(ids)
362
+ x = self.dropout(x)
363
+ e = self.pos_embedding(x.size(1),
364
+ x.size(1)) if self.shared_pos else None
365
+ for block in self.blocks:
366
+ x = block(x, mask, encoder_states, encoder_mask, pos_bias=e)
367
+ x = self.norm(x)
368
+ x = self.dropout(x)
369
+ return x
370
+
371
+
372
+ class T5Model(nn.Module):
373
+
374
+ def __init__(self,
375
+ vocab_size,
376
+ dim,
377
+ dim_attn,
378
+ dim_ffn,
379
+ num_heads,
380
+ encoder_layers,
381
+ decoder_layers,
382
+ num_buckets,
383
+ shared_pos=True,
384
+ dropout=0.1):
385
+ super(T5Model, self).__init__()
386
+ self.vocab_size = vocab_size
387
+ self.dim = dim
388
+ self.dim_attn = dim_attn
389
+ self.dim_ffn = dim_ffn
390
+ self.num_heads = num_heads
391
+ self.encoder_layers = encoder_layers
392
+ self.decoder_layers = decoder_layers
393
+ self.num_buckets = num_buckets
394
+
395
+ # layers
396
+ self.token_embedding = nn.Embedding(vocab_size, dim)
397
+ self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn,
398
+ num_heads, encoder_layers, num_buckets,
399
+ shared_pos, dropout)
400
+ self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn,
401
+ num_heads, decoder_layers, num_buckets,
402
+ shared_pos, dropout)
403
+ self.head = nn.Linear(dim, vocab_size, bias=False)
404
+
405
+ # initialize weights
406
+ self.apply(init_weights)
407
+
408
+ def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask):
409
+ x = self.encoder(encoder_ids, encoder_mask)
410
+ x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask)
411
+ x = self.head(x)
412
+ return x
413
+
414
+
415
+ def _t5(name,
416
+ encoder_only=False,
417
+ decoder_only=False,
418
+ return_tokenizer=False,
419
+ tokenizer_kwargs={},
420
+ dtype=torch.float32,
421
+ device='cpu',
422
+ **kwargs):
423
+ # sanity check
424
+ assert not (encoder_only and decoder_only)
425
+
426
+ # params
427
+ if encoder_only:
428
+ model_cls = T5Encoder
429
+ kwargs['vocab'] = kwargs.pop('vocab_size')
430
+ kwargs['num_layers'] = kwargs.pop('encoder_layers')
431
+ _ = kwargs.pop('decoder_layers')
432
+ elif decoder_only:
433
+ model_cls = T5Decoder
434
+ kwargs['vocab'] = kwargs.pop('vocab_size')
435
+ kwargs['num_layers'] = kwargs.pop('decoder_layers')
436
+ _ = kwargs.pop('encoder_layers')
437
+ else:
438
+ model_cls = T5Model
439
+
440
+ # init model
441
+ with torch.device(device):
442
+ model = model_cls(**kwargs)
443
+
444
+ # set device
445
+ model = model.to(dtype=dtype, device=device)
446
+
447
+ # init tokenizer
448
+ if return_tokenizer:
449
+ from .tokenizers import HuggingfaceTokenizer
450
+ tokenizer = HuggingfaceTokenizer(f'google/{name}', **tokenizer_kwargs)
451
+ return model, tokenizer
452
+ else:
453
+ return model
454
+
455
+
456
+ def umt5_xxl(**kwargs):
457
+ cfg = dict(
458
+ vocab_size=256384,
459
+ dim=4096,
460
+ dim_attn=4096,
461
+ dim_ffn=10240,
462
+ num_heads=64,
463
+ encoder_layers=24,
464
+ decoder_layers=24,
465
+ num_buckets=32,
466
+ shared_pos=False,
467
+ dropout=0.1)
468
+ cfg.update(**kwargs)
469
+ return _t5('umt5-xxl', **cfg)
470
+
471
+
472
+ class T5EncoderModel:
473
+
474
+ def __init__(
475
+ self,
476
+ text_len,
477
+ dtype=torch.bfloat16,
478
+ device=torch.cuda.current_device(),
479
+ checkpoint_path=None,
480
+ tokenizer_path=None,
481
+ shard_fn=None,
482
+ ):
483
+ self.text_len = text_len
484
+ self.dtype = dtype
485
+ self.device = device
486
+ self.checkpoint_path = checkpoint_path
487
+ self.tokenizer_path = tokenizer_path
488
+
489
+ # init model
490
+ model = umt5_xxl(
491
+ encoder_only=True,
492
+ return_tokenizer=False,
493
+ dtype=dtype,
494
+ device=device).eval().requires_grad_(False)
495
+ logging.info(f'loading {checkpoint_path}')
496
+ model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
497
+ self.model = model
498
+ if shard_fn is not None:
499
+ self.model = shard_fn(self.model, sync_module_states=False)
500
+ else:
501
+ self.model.to(self.device)
502
+ # init tokenizer
503
+ self.tokenizer = HuggingfaceTokenizer(
504
+ name=tokenizer_path, seq_len=text_len, clean='whitespace')
505
+
506
+ def __call__(self, texts, device):
507
+ ids, mask = self.tokenizer(
508
+ texts, return_mask=True, add_special_tokens=True)
509
+ ids = ids.to(device)
510
+ mask = mask.to(device)
511
+ seq_lens = mask.gt(0).sum(dim=1).long()
512
+ context = self.model(ids, mask)
513
+ return [u[:v] for u, v in zip(context, seq_lens)]
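The encoder's attention bias comes from `T5RelativeEmbedding`, which buckets signed relative distances and looks up a per-head embedding. A minimal shape sketch (not part of the upload; assumes the repository and its dependencies, e.g. transformers and ftfy, are importable):

from wan.modules.t5 import T5RelativeEmbedding  # assumes the repo is importable

rel = T5RelativeEmbedding(num_buckets=32, num_heads=64, bidirectional=True)
bias = rel(8, 8)           # additive attention bias for an 8-token sequence
print(bias.shape)          # torch.Size([1, 64, 8, 8]), broadcast over the batch in T5Attention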
wan/modules/tokenizers.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import html
3
+ import string
4
+
5
+ import ftfy
6
+ import regex as re
7
+ from transformers import AutoTokenizer
8
+
9
+ __all__ = ['HuggingfaceTokenizer']
10
+
11
+
12
+ def basic_clean(text):
13
+ text = ftfy.fix_text(text)
14
+ text = html.unescape(html.unescape(text))
15
+ return text.strip()
16
+
17
+
18
+ def whitespace_clean(text):
19
+ text = re.sub(r'\s+', ' ', text)
20
+ text = text.strip()
21
+ return text
22
+
23
+
24
+ def canonicalize(text, keep_punctuation_exact_string=None):
25
+ text = text.replace('_', ' ')
26
+ if keep_punctuation_exact_string:
27
+ text = keep_punctuation_exact_string.join(
28
+ part.translate(str.maketrans('', '', string.punctuation))
29
+ for part in text.split(keep_punctuation_exact_string))
30
+ else:
31
+ text = text.translate(str.maketrans('', '', string.punctuation))
32
+ text = text.lower()
33
+ text = re.sub(r'\s+', ' ', text)
34
+ return text.strip()
35
+
36
+
37
+ class HuggingfaceTokenizer:
38
+
39
+ def __init__(self, name, seq_len=None, clean=None, **kwargs):
40
+ assert clean in (None, 'whitespace', 'lower', 'canonicalize')
41
+ self.name = name
42
+ self.seq_len = seq_len
43
+ self.clean = clean
44
+
45
+ # init tokenizer
46
+ self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
47
+ self.vocab_size = self.tokenizer.vocab_size
48
+
49
+ def __call__(self, sequence, **kwargs):
50
+ return_mask = kwargs.pop('return_mask', False)
51
+
52
+ # arguments
53
+ _kwargs = {'return_tensors': 'pt'}
54
+ if self.seq_len is not None:
55
+ _kwargs.update({
56
+ 'padding': 'max_length',
57
+ 'truncation': True,
58
+ 'max_length': self.seq_len
59
+ })
60
+ _kwargs.update(**kwargs)
61
+
62
+ # tokenization
63
+ if isinstance(sequence, str):
64
+ sequence = [sequence]
65
+ if self.clean:
66
+ sequence = [self._clean(u) for u in sequence]
67
+ ids = self.tokenizer(sequence, **_kwargs)
68
+
69
+ # output
70
+ if return_mask:
71
+ return ids.input_ids, ids.attention_mask
72
+ else:
73
+ return ids.input_ids
74
+
75
+ def _clean(self, text):
76
+ if self.clean == 'whitespace':
77
+ text = whitespace_clean(basic_clean(text))
78
+ elif self.clean == 'lower':
79
+ text = whitespace_clean(basic_clean(text)).lower()
80
+ elif self.clean == 'canonicalize':
81
+ text = canonicalize(basic_clean(text))
82
+ return text
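A usage sketch mirroring how `T5EncoderModel` calls this wrapper (not part of the upload; assumes the `google/umt5-xxl` tokenizer files can be downloaded or are cached locally):

from wan.modules.tokenizers import HuggingfaceTokenizer  # assumes the repo is importable

tokenizer = HuggingfaceTokenizer('google/umt5-xxl', seq_len=512, clean='whitespace')
ids, mask = tokenizer(['a cinematic shot of a red fox in the snow'],
                      return_mask=True, add_special_tokens=True)
print(ids.shape, int(mask.sum()))  # ids padded/truncated to [1, 512]; mask counts real tokens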
wan/modules/vae2_1.py ADDED
@@ -0,0 +1,663 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ 'Wan2_1_VAE',
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
25
+ self.padding[1], 2 * self.padding[0], 0)
26
+ self.padding = (0, 0, 0)
27
+
28
+ def forward(self, x, cache_x=None):
29
+ padding = list(self._padding)
30
+ if cache_x is not None and self._padding[4] > 0:
31
+ cache_x = cache_x.to(x.device)
32
+ x = torch.cat([cache_x, x], dim=2)
33
+ padding[4] -= cache_x.shape[2]
34
+ x = F.pad(x, padding)
35
+
36
+ return super().forward(x)
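A quick shape check for the causal padding above (a sketch, not part of the file; assumes the repository and einops are importable): the temporal axis is padded only on the past side, so a kernel of size 3 with padding=1 preserves the frame count without looking at future frames.

import torch
from wan.modules.vae2_1 import CausalConv3d  # assumes the repo is importable

conv = CausalConv3d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 5, 32, 32)             # [B, C, T, H, W]
print(conv(x).shape)                         # torch.Size([1, 8, 5, 32, 32])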
37
+
38
+
39
+ class RMS_norm(nn.Module):
40
+
41
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
42
+ super().__init__()
43
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
44
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
45
+
46
+ self.channel_first = channel_first
47
+ self.scale = dim**0.5
48
+ self.gamma = nn.Parameter(torch.ones(shape))
49
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
50
+
51
+ def forward(self, x):
52
+ return F.normalize(
53
+ x, dim=(1 if self.channel_first else
54
+ -1)) * self.scale * self.gamma + self.bias
55
+
56
+
57
+ class Upsample(nn.Upsample):
58
+
59
+ def forward(self, x):
60
+ """
61
+ Fix bfloat16 support for nearest neighbor interpolation.
62
+ """
63
+ return super().forward(x.float()).type_as(x)
64
+
65
+
66
+ class Resample(nn.Module):
67
+
68
+ def __init__(self, dim, mode):
69
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
70
+ 'downsample3d')
71
+ super().__init__()
72
+ self.dim = dim
73
+ self.mode = mode
74
+
75
+ # layers
76
+ if mode == 'upsample2d':
77
+ self.resample = nn.Sequential(
78
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
79
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
80
+ elif mode == 'upsample3d':
81
+ self.resample = nn.Sequential(
82
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
84
+ self.time_conv = CausalConv3d(
85
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
86
+
87
+ elif mode == 'downsample2d':
88
+ self.resample = nn.Sequential(
89
+ nn.ZeroPad2d((0, 1, 0, 1)),
90
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
91
+ elif mode == 'downsample3d':
92
+ self.resample = nn.Sequential(
93
+ nn.ZeroPad2d((0, 1, 0, 1)),
94
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
+ self.time_conv = CausalConv3d(
96
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
97
+
98
+ else:
99
+ self.resample = nn.Identity()
100
+
101
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
102
+ b, c, t, h, w = x.size()
103
+ if self.mode == 'upsample3d':
104
+ if feat_cache is not None:
105
+ idx = feat_idx[0]
106
+ if feat_cache[idx] is None:
107
+ feat_cache[idx] = 'Rep'
108
+ feat_idx[0] += 1
109
+ else:
110
+
111
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
112
+ if cache_x.shape[2] < 2 and feat_cache[
113
+ idx] is not None and feat_cache[idx] != 'Rep':
114
+ # cache last frame of last two chunk
115
+ cache_x = torch.cat([
116
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
117
+ cache_x.device), cache_x
118
+ ],
119
+ dim=2)
120
+ if cache_x.shape[2] < 2 and feat_cache[
121
+ idx] is not None and feat_cache[idx] == 'Rep':
122
+ cache_x = torch.cat([
123
+ torch.zeros_like(cache_x).to(cache_x.device),
124
+ cache_x
125
+ ],
126
+ dim=2)
127
+ if feat_cache[idx] == 'Rep':
128
+ x = self.time_conv(x)
129
+ else:
130
+ x = self.time_conv(x, feat_cache[idx])
131
+ feat_cache[idx] = cache_x
132
+ feat_idx[0] += 1
133
+
134
+ x = x.reshape(b, 2, c, t, h, w)
135
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
136
+ 3)
137
+ x = x.reshape(b, c, t * 2, h, w)
138
+ t = x.shape[2]
139
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
140
+ x = self.resample(x)
141
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
142
+
143
+ if self.mode == 'downsample3d':
144
+ if feat_cache is not None:
145
+ idx = feat_idx[0]
146
+ if feat_cache[idx] is None:
147
+ feat_cache[idx] = x.clone()
148
+ feat_idx[0] += 1
149
+ else:
150
+
151
+ cache_x = x[:, :, -1:, :, :].clone()
152
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
153
+ # # cache last frame of last two chunk
154
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
155
+
156
+ x = self.time_conv(
157
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
158
+ feat_cache[idx] = cache_x
159
+ feat_idx[0] += 1
160
+ return x
161
+
162
+ def init_weight(self, conv):
163
+ conv_weight = conv.weight
164
+ nn.init.zeros_(conv_weight)
165
+ c1, c2, t, h, w = conv_weight.size()
166
+ one_matrix = torch.eye(c1, c2)
167
+ init_matrix = one_matrix
168
+ nn.init.zeros_(conv_weight)
169
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
170
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
171
+ conv.weight.data.copy_(conv_weight)
172
+ nn.init.zeros_(conv.bias.data)
173
+
174
+ def init_weight2(self, conv):
175
+ conv_weight = conv.weight.data
176
+ nn.init.zeros_(conv_weight)
177
+ c1, c2, t, h, w = conv_weight.size()
178
+ init_matrix = torch.eye(c1 // 2, c2)
179
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
180
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
181
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
182
+ conv.weight.data.copy_(conv_weight)
183
+ nn.init.zeros_(conv.bias.data)
184
+
185
+
186
+ class ResidualBlock(nn.Module):
187
+
188
+ def __init__(self, in_dim, out_dim, dropout=0.0):
189
+ super().__init__()
190
+ self.in_dim = in_dim
191
+ self.out_dim = out_dim
192
+
193
+ # layers
194
+ self.residual = nn.Sequential(
195
+ RMS_norm(in_dim, images=False), nn.SiLU(),
196
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
197
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
198
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
199
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
200
+ if in_dim != out_dim else nn.Identity()
201
+
202
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
203
+ h = self.shortcut(x)
204
+ for layer in self.residual:
205
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
206
+ idx = feat_idx[0]
207
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
208
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
209
+ # cache last frame of last two chunk
210
+ cache_x = torch.cat([
211
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
212
+ cache_x.device), cache_x
213
+ ],
214
+ dim=2)
215
+ x = layer(x, feat_cache[idx])
216
+ feat_cache[idx] = cache_x
217
+ feat_idx[0] += 1
218
+ else:
219
+ x = layer(x)
220
+ return x + h
221
+
222
+
223
+ class AttentionBlock(nn.Module):
224
+ """
225
+ Causal self-attention with a single head.
226
+ """
227
+
228
+ def __init__(self, dim):
229
+ super().__init__()
230
+ self.dim = dim
231
+
232
+ # layers
233
+ self.norm = RMS_norm(dim)
234
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
235
+ self.proj = nn.Conv2d(dim, dim, 1)
236
+
237
+ # zero out the last layer params
238
+ nn.init.zeros_(self.proj.weight)
239
+
240
+ def forward(self, x):
241
+ identity = x
242
+ b, c, t, h, w = x.size()
243
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
244
+ x = self.norm(x)
245
+ # compute query, key, value
246
+ q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
247
+ -1).permute(0, 1, 3,
248
+ 2).contiguous().chunk(
249
+ 3, dim=-1)
250
+
251
+ # apply attention
252
+ x = F.scaled_dot_product_attention(
253
+ q,
254
+ k,
255
+ v,
256
+ )
257
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
258
+
259
+ # output
260
+ x = self.proj(x)
261
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
262
+ return x + identity
263
+
264
+
265
+ class Encoder3d(nn.Module):
266
+
267
+ def __init__(self,
268
+ dim=128,
269
+ z_dim=4,
270
+ dim_mult=[1, 2, 4, 4],
271
+ num_res_blocks=2,
272
+ attn_scales=[],
273
+ temperal_downsample=[True, True, False],
274
+ dropout=0.0):
275
+ super().__init__()
276
+ self.dim = dim
277
+ self.z_dim = z_dim
278
+ self.dim_mult = dim_mult
279
+ self.num_res_blocks = num_res_blocks
280
+ self.attn_scales = attn_scales
281
+ self.temperal_downsample = temperal_downsample
282
+
283
+ # dimensions
284
+ dims = [dim * u for u in [1] + dim_mult]
285
+ scale = 1.0
286
+
287
+ # init block
288
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
289
+
290
+ # downsample blocks
291
+ downsamples = []
292
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
293
+ # residual (+attention) blocks
294
+ for _ in range(num_res_blocks):
295
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
296
+ if scale in attn_scales:
297
+ downsamples.append(AttentionBlock(out_dim))
298
+ in_dim = out_dim
299
+
300
+ # downsample block
301
+ if i != len(dim_mult) - 1:
302
+ mode = 'downsample3d' if temperal_downsample[
303
+ i] else 'downsample2d'
304
+ downsamples.append(Resample(out_dim, mode=mode))
305
+ scale /= 2.0
306
+ self.downsamples = nn.Sequential(*downsamples)
307
+
308
+ # middle blocks
309
+ self.middle = nn.Sequential(
310
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
311
+ ResidualBlock(out_dim, out_dim, dropout))
312
+
313
+ # output blocks
314
+ self.head = nn.Sequential(
315
+ RMS_norm(out_dim, images=False), nn.SiLU(),
316
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
317
+
318
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
319
+ if feat_cache is not None:
320
+ idx = feat_idx[0]
321
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
322
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
323
+ # cache last frame of last two chunk
324
+ cache_x = torch.cat([
325
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
326
+ cache_x.device), cache_x
327
+ ],
328
+ dim=2)
329
+ x = self.conv1(x, feat_cache[idx])
330
+ feat_cache[idx] = cache_x
331
+ feat_idx[0] += 1
332
+ else:
333
+ x = self.conv1(x)
334
+
335
+ ## downsamples
336
+ for layer in self.downsamples:
337
+ if feat_cache is not None:
338
+ x = layer(x, feat_cache, feat_idx)
339
+ else:
340
+ x = layer(x)
341
+
342
+ ## middle
343
+ for layer in self.middle:
344
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
345
+ x = layer(x, feat_cache, feat_idx)
346
+ else:
347
+ x = layer(x)
348
+
349
+ ## head
350
+ for layer in self.head:
351
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
352
+ idx = feat_idx[0]
353
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
354
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
355
+ # cache last frame of last two chunk
356
+ cache_x = torch.cat([
357
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
358
+ cache_x.device), cache_x
359
+ ],
360
+ dim=2)
361
+ x = layer(x, feat_cache[idx])
362
+ feat_cache[idx] = cache_x
363
+ feat_idx[0] += 1
364
+ else:
365
+ x = layer(x)
366
+ return x
367
+
368
+
369
+ class Decoder3d(nn.Module):
370
+
371
+ def __init__(self,
372
+ dim=128,
373
+ z_dim=4,
374
+ dim_mult=[1, 2, 4, 4],
375
+ num_res_blocks=2,
376
+ attn_scales=[],
377
+ temperal_upsample=[False, True, True],
378
+ dropout=0.0):
379
+ super().__init__()
380
+ self.dim = dim
381
+ self.z_dim = z_dim
382
+ self.dim_mult = dim_mult
383
+ self.num_res_blocks = num_res_blocks
384
+ self.attn_scales = attn_scales
385
+ self.temperal_upsample = temperal_upsample
386
+
387
+ # dimensions
388
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
389
+ scale = 1.0 / 2**(len(dim_mult) - 2)
390
+
391
+ # init block
392
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
393
+
394
+ # middle blocks
395
+ self.middle = nn.Sequential(
396
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
397
+ ResidualBlock(dims[0], dims[0], dropout))
398
+
399
+ # upsample blocks
400
+ upsamples = []
401
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
402
+ # residual (+attention) blocks
403
+ if i == 1 or i == 2 or i == 3:
404
+ in_dim = in_dim // 2
405
+ for _ in range(num_res_blocks + 1):
406
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
407
+ if scale in attn_scales:
408
+ upsamples.append(AttentionBlock(out_dim))
409
+ in_dim = out_dim
410
+
411
+ # upsample block
412
+ if i != len(dim_mult) - 1:
413
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
414
+ upsamples.append(Resample(out_dim, mode=mode))
415
+ scale *= 2.0
416
+ self.upsamples = nn.Sequential(*upsamples)
417
+
418
+ # output blocks
419
+ self.head = nn.Sequential(
420
+ RMS_norm(out_dim, images=False), nn.SiLU(),
421
+ CausalConv3d(out_dim, 3, 3, padding=1))
422
+
423
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
424
+ ## conv1
425
+ if feat_cache is not None:
426
+ idx = feat_idx[0]
427
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
428
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
429
+ # cache the last frame of the last two chunks
430
+ cache_x = torch.cat([
431
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
432
+ cache_x.device), cache_x
433
+ ],
434
+ dim=2)
435
+ x = self.conv1(x, feat_cache[idx])
436
+ feat_cache[idx] = cache_x
437
+ feat_idx[0] += 1
438
+ else:
439
+ x = self.conv1(x)
440
+
441
+ ## middle
442
+ for layer in self.middle:
443
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
444
+ x = layer(x, feat_cache, feat_idx)
445
+ else:
446
+ x = layer(x)
447
+
448
+ ## upsamples
449
+ for layer in self.upsamples:
450
+ if feat_cache is not None:
451
+ x = layer(x, feat_cache, feat_idx)
452
+ else:
453
+ x = layer(x)
454
+
455
+ ## head
456
+ for layer in self.head:
457
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
458
+ idx = feat_idx[0]
459
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
460
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
461
+ # cache the last frame of the last two chunks
462
+ cache_x = torch.cat([
463
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
464
+ cache_x.device), cache_x
465
+ ],
466
+ dim=2)
467
+ x = layer(x, feat_cache[idx])
468
+ feat_cache[idx] = cache_x
469
+ feat_idx[0] += 1
470
+ else:
471
+ x = layer(x)
472
+ return x
473
+
474
+
475
+ def count_conv3d(model):
476
+ count = 0
477
+ for m in model.modules():
478
+ if isinstance(m, CausalConv3d):
479
+ count += 1
480
+ return count
481
+
482
+
483
+ class WanVAE_(nn.Module):
484
+
485
+ def __init__(self,
486
+ dim=128,
487
+ z_dim=4,
488
+ dim_mult=[1, 2, 4, 4],
489
+ num_res_blocks=2,
490
+ attn_scales=[],
491
+ temperal_downsample=[True, True, False],
492
+ dropout=0.0):
493
+ super().__init__()
494
+ self.dim = dim
495
+ self.z_dim = z_dim
496
+ self.dim_mult = dim_mult
497
+ self.num_res_blocks = num_res_blocks
498
+ self.attn_scales = attn_scales
499
+ self.temperal_downsample = temperal_downsample
500
+ self.temperal_upsample = temperal_downsample[::-1]
501
+
502
+ # modules
503
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
504
+ attn_scales, self.temperal_downsample, dropout)
505
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
506
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
507
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
508
+ attn_scales, self.temperal_upsample, dropout)
509
+
510
+ def forward(self, x):
511
+ mu, log_var = self.encode(x)
512
+ z = self.reparameterize(mu, log_var)
513
+ x_recon = self.decode(z)
514
+ return x_recon, mu, log_var
515
+
516
+ def encode(self, x, scale):
517
+ self.clear_cache()
518
+ ## cache
519
+ t = x.shape[2]
520
+ iter_ = 1 + (t - 1) // 4
521
+ ## split the encoder input x along the temporal axis into chunks of 1, 4, 4, 4, ...
522
+ for i in range(iter_):
523
+ self._enc_conv_idx = [0]
524
+ if i == 0:
525
+ out = self.encoder(
526
+ x[:, :, :1, :, :],
527
+ feat_cache=self._enc_feat_map,
528
+ feat_idx=self._enc_conv_idx)
529
+ else:
530
+ out_ = self.encoder(
531
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
532
+ feat_cache=self._enc_feat_map,
533
+ feat_idx=self._enc_conv_idx)
534
+ out = torch.cat([out, out_], 2)
535
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
536
+ if isinstance(scale[0], torch.Tensor):
537
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
538
+ 1, self.z_dim, 1, 1, 1)
539
+ else:
540
+ mu = (mu - scale[0]) * scale[1]
541
+ self.clear_cache()
542
+ return mu
543
+
544
+ def decode(self, z, scale):
545
+ self.clear_cache()
546
+ # z: [b,c,t,h,w]
547
+ if isinstance(scale[0], torch.Tensor):
548
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
549
+ 1, self.z_dim, 1, 1, 1)
550
+ else:
551
+ z = z / scale[1] + scale[0]
552
+ iter_ = z.shape[2]
553
+ x = self.conv2(z)
554
+ for i in range(iter_):
555
+ self._conv_idx = [0]
556
+ if i == 0:
557
+ out = self.decoder(
558
+ x[:, :, i:i + 1, :, :],
559
+ feat_cache=self._feat_map,
560
+ feat_idx=self._conv_idx)
561
+ else:
562
+ out_ = self.decoder(
563
+ x[:, :, i:i + 1, :, :],
564
+ feat_cache=self._feat_map,
565
+ feat_idx=self._conv_idx)
566
+ out = torch.cat([out, out_], 2)
567
+ self.clear_cache()
568
+ return out
569
+
570
+ def reparameterize(self, mu, log_var):
571
+ std = torch.exp(0.5 * log_var)
572
+ eps = torch.randn_like(std)
573
+ return eps * std + mu
574
+
575
+ def sample(self, imgs, deterministic=False):
576
+ mu, log_var = self.encode(imgs)
577
+ if deterministic:
578
+ return mu
579
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
580
+ return mu + std * torch.randn_like(std)
581
+
582
+ def clear_cache(self):
583
+ self._conv_num = count_conv3d(self.decoder)
584
+ self._conv_idx = [0]
585
+ self._feat_map = [None] * self._conv_num
586
+ #cache encode
587
+ self._enc_conv_num = count_conv3d(self.encoder)
588
+ self._enc_conv_idx = [0]
589
+ self._enc_feat_map = [None] * self._enc_conv_num
590
+
591
+
592
+ def _video_vae(pretrained_path=None, z_dim=None, device='cpu', **kwargs):
593
+ """
594
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
595
+ """
596
+ # params
597
+ cfg = dict(
598
+ dim=96,
599
+ z_dim=z_dim,
600
+ dim_mult=[1, 2, 4, 4],
601
+ num_res_blocks=2,
602
+ attn_scales=[],
603
+ temperal_downsample=[False, True, True],
604
+ dropout=0.0)
605
+ cfg.update(**kwargs)
606
+
607
+ # init model
608
+ with torch.device('meta'):
609
+ model = WanVAE_(**cfg)
610
+
611
+ # load checkpoint
612
+ logging.info(f'loading {pretrained_path}')
613
+ model.load_state_dict(
614
+ torch.load(pretrained_path, map_location=device), assign=True)
615
+
616
+ return model
617
+
618
+
619
+ class Wan2_1_VAE:
620
+
621
+ def __init__(self,
622
+ z_dim=16,
623
+ vae_pth='cache/vae_step_411000.pth',
624
+ dtype=torch.float,
625
+ device="cuda"):
626
+ self.dtype = dtype
627
+ self.device = device
628
+
629
+ mean = [
630
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
631
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
632
+ ]
633
+ std = [
634
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
635
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
636
+ ]
637
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
638
+ self.std = torch.tensor(std, dtype=dtype, device=device)
639
+ self.scale = [self.mean, 1.0 / self.std]
640
+
641
+ # init model
642
+ self.model = _video_vae(
643
+ pretrained_path=vae_pth,
644
+ z_dim=z_dim,
645
+ ).eval().requires_grad_(False).to(device)
646
+
647
+ def encode(self, videos):
648
+ """
649
+ videos: A list of videos each with shape [C, T, H, W].
650
+ """
651
+ with amp.autocast(dtype=self.dtype):
652
+ return [
653
+ self.model.encode(u.unsqueeze(0), self.scale).float().squeeze(0)
654
+ for u in videos
655
+ ]
656
+
657
+ def decode(self, zs):
658
+ with amp.autocast(dtype=self.dtype):
659
+ return [
660
+ self.model.decode(u.unsqueeze(0),
661
+ self.scale).float().clamp_(-1, 1).squeeze(0)
662
+ for u in zs
663
+ ]
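A minimal round-trip sketch of the list-in/list-out API above (editorial example, not part of the upload). It assumes a local copy of the 2.1 VAE checkpoint at the constructor's default path, and the shape comments assume the stock 4x temporal / 8x spatial stride:

import torch
from wan.modules.vae2_1 import Wan2_1_VAE

vae = Wan2_1_VAE(vae_pth='cache/vae_step_411000.pth', device='cuda')
clip = torch.rand(3, 17, 480, 832, device='cuda') * 2 - 1   # [C, T, H, W] in [-1, 1], T = 4k + 1
z = vae.encode([clip])[0]    # expected latent: [16, 1 + (17 - 1) // 4, 480 // 8, 832 // 8] = [16, 5, 60, 104]
recon = vae.decode([z])[0]   # expected reconstruction: [3, 17, 480, 832], clamped to [-1, 1]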
wan/modules/vae2_2.py ADDED
@@ -0,0 +1,1051 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import logging
3
+
4
+ import torch
5
+ import torch.cuda.amp as amp
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+
10
+ __all__ = [
11
+ "Wan2_2_VAE",
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3d convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (
25
+ self.padding[2],
26
+ self.padding[2],
27
+ self.padding[1],
28
+ self.padding[1],
29
+ 2 * self.padding[0],
30
+ 0,
31
+ )
32
+ self.padding = (0, 0, 0)
33
+
34
+ def forward(self, x, cache_x=None):
35
+ padding = list(self._padding)
36
+ if cache_x is not None and self._padding[4] > 0:
37
+ cache_x = cache_x.to(x.device)
38
+ x = torch.cat([cache_x, x], dim=2)
39
+ padding[4] -= cache_x.shape[2]
40
+ x = F.pad(x, padding)
41
+
42
+ return super().forward(x)
43
+
44
+
45
+ class RMS_norm(nn.Module):
46
+
47
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
48
+ super().__init__()
49
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
50
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
51
+
52
+ self.channel_first = channel_first
53
+ self.scale = dim**0.5
54
+ self.gamma = nn.Parameter(torch.ones(shape))
55
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
56
+
57
+ def forward(self, x):
58
+ return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
59
+ self.scale * self.gamma + self.bias)
60
+
61
+
62
+ class Upsample(nn.Upsample):
63
+
64
+ def forward(self, x):
65
+ """
66
+ Fix bfloat16 support for nearest neighbor interpolation.
67
+ """
68
+ return super().forward(x.float()).type_as(x)
69
+
70
+
71
+ class Resample(nn.Module):
72
+
73
+ def __init__(self, dim, mode):
74
+ assert mode in (
75
+ "none",
76
+ "upsample2d",
77
+ "upsample3d",
78
+ "downsample2d",
79
+ "downsample3d",
80
+ )
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.mode = mode
84
+
85
+ # layers
86
+ if mode == "upsample2d":
87
+ self.resample = nn.Sequential(
88
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
89
+ nn.Conv2d(dim, dim, 3, padding=1),
90
+ )
91
+ elif mode == "upsample3d":
92
+ self.resample = nn.Sequential(
93
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
+ nn.Conv2d(dim, dim, 3, padding=1),
95
+ # nn.Conv2d(dim, dim//2, 3, padding=1)
96
+ )
97
+ self.time_conv = CausalConv3d(
98
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
99
+ elif mode == "downsample2d":
100
+ self.resample = nn.Sequential(
101
+ nn.ZeroPad2d((0, 1, 0, 1)),
102
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
103
+ elif mode == "downsample3d":
104
+ self.resample = nn.Sequential(
105
+ nn.ZeroPad2d((0, 1, 0, 1)),
106
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
107
+ self.time_conv = CausalConv3d(
108
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
109
+ else:
110
+ self.resample = nn.Identity()
111
+
112
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
113
+ b, c, t, h, w = x.size()
114
+ if self.mode == "upsample3d":
115
+ if feat_cache is not None:
116
+ idx = feat_idx[0]
117
+ if feat_cache[idx] is None:
118
+ feat_cache[idx] = "Rep"
119
+ feat_idx[0] += 1
120
+ else:
121
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
122
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
123
+ feat_cache[idx] != "Rep"):
124
+ # cache the last frame of the last two chunks
125
+ cache_x = torch.cat(
126
+ [
127
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
128
+ cache_x.device),
129
+ cache_x,
130
+ ],
131
+ dim=2,
132
+ )
133
+ if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
134
+ feat_cache[idx] == "Rep"):
135
+ cache_x = torch.cat(
136
+ [
137
+ torch.zeros_like(cache_x).to(cache_x.device),
138
+ cache_x
139
+ ],
140
+ dim=2,
141
+ )
142
+ if feat_cache[idx] == "Rep":
143
+ x = self.time_conv(x)
144
+ else:
145
+ x = self.time_conv(x, feat_cache[idx])
146
+ feat_cache[idx] = cache_x
147
+ feat_idx[0] += 1
148
+ x = x.reshape(b, 2, c, t, h, w)
149
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
150
+ 3)
151
+ x = x.reshape(b, c, t * 2, h, w)
152
+ t = x.shape[2]
153
+ x = rearrange(x, "b c t h w -> (b t) c h w")
154
+ x = self.resample(x)
155
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
156
+
157
+ if self.mode == "downsample3d":
158
+ if feat_cache is not None:
159
+ idx = feat_idx[0]
160
+ if feat_cache[idx] is None:
161
+ feat_cache[idx] = x.clone()
162
+ feat_idx[0] += 1
163
+ else:
164
+ cache_x = x[:, :, -1:, :, :].clone()
165
+ x = self.time_conv(
166
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
167
+ feat_cache[idx] = cache_x
168
+ feat_idx[0] += 1
169
+ return x
170
+
171
+ def init_weight(self, conv):
172
+ conv_weight = conv.weight.detach().clone()
173
+ nn.init.zeros_(conv_weight)
174
+ c1, c2, t, h, w = conv_weight.size()
175
+ one_matrix = torch.eye(c1, c2)
176
+ init_matrix = one_matrix
177
+ nn.init.zeros_(conv_weight)
178
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
179
+ conv.weight = nn.Parameter(conv_weight)
180
+ nn.init.zeros_(conv.bias.data)
181
+
182
+ def init_weight2(self, conv):
183
+ conv_weight = conv.weight.data.detach().clone()
184
+ nn.init.zeros_(conv_weight)
185
+ c1, c2, t, h, w = conv_weight.size()
186
+ init_matrix = torch.eye(c1 // 2, c2)
187
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
188
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
189
+ conv.weight = nn.Parameter(conv_weight)
190
+ nn.init.zeros_(conv.bias.data)
191
+
192
+
193
+ class ResidualBlock(nn.Module):
194
+
195
+ def __init__(self, in_dim, out_dim, dropout=0.0):
196
+ super().__init__()
197
+ self.in_dim = in_dim
198
+ self.out_dim = out_dim
199
+
200
+ # layers
201
+ self.residual = nn.Sequential(
202
+ RMS_norm(in_dim, images=False),
203
+ nn.SiLU(),
204
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
205
+ RMS_norm(out_dim, images=False),
206
+ nn.SiLU(),
207
+ nn.Dropout(dropout),
208
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
209
+ )
210
+ self.shortcut = (
211
+ CausalConv3d(in_dim, out_dim, 1)
212
+ if in_dim != out_dim else nn.Identity())
213
+
214
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
215
+ h = self.shortcut(x)
216
+ for layer in self.residual:
217
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
218
+ idx = feat_idx[0]
219
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
220
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
221
+ # cache the last frame of the last two chunks
222
+ cache_x = torch.cat(
223
+ [
224
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
225
+ cache_x.device),
226
+ cache_x,
227
+ ],
228
+ dim=2,
229
+ )
230
+ x = layer(x, feat_cache[idx])
231
+ feat_cache[idx] = cache_x
232
+ feat_idx[0] += 1
233
+ else:
234
+ x = layer(x)
235
+ return x + h
236
+
237
+
238
+ class AttentionBlock(nn.Module):
239
+ """
240
+ Single-head self-attention applied independently to each frame.
241
+ """
242
+
243
+ def __init__(self, dim):
244
+ super().__init__()
245
+ self.dim = dim
246
+
247
+ # layers
248
+ self.norm = RMS_norm(dim)
249
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
250
+ self.proj = nn.Conv2d(dim, dim, 1)
251
+
252
+ # zero out the last layer params
253
+ nn.init.zeros_(self.proj.weight)
254
+
255
+ def forward(self, x):
256
+ identity = x
257
+ b, c, t, h, w = x.size()
258
+ x = rearrange(x, "b c t h w -> (b t) c h w")
259
+ x = self.norm(x)
260
+ # compute query, key, value
261
+ q, k, v = (
262
+ self.to_qkv(x).reshape(b * t, 1, c * 3,
263
+ -1).permute(0, 1, 3,
264
+ 2).contiguous().chunk(3, dim=-1))
265
+
266
+ # apply attention
267
+ x = F.scaled_dot_product_attention(
268
+ q,
269
+ k,
270
+ v,
271
+ )
272
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
273
+
274
+ # output
275
+ x = self.proj(x)
276
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
277
+ return x + identity
278
+
279
+
280
+ def patchify(x, patch_size):
281
+ if patch_size == 1:
282
+ return x
283
+ if x.dim() == 4:
284
+ x = rearrange(
285
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
286
+ elif x.dim() == 5:
287
+ x = rearrange(
288
+ x,
289
+ "b c f (h q) (w r) -> b (c r q) f h w",
290
+ q=patch_size,
291
+ r=patch_size,
292
+ )
293
+ else:
294
+ raise ValueError(f"Invalid input shape: {x.shape}")
295
+
296
+ return x
297
+
298
+
299
+ def unpatchify(x, patch_size):
300
+ if patch_size == 1:
301
+ return x
302
+
303
+ if x.dim() == 4:
304
+ x = rearrange(
305
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
306
+ elif x.dim() == 5:
307
+ x = rearrange(
308
+ x,
309
+ "b (c r q) f h w -> b c f (h q) (w r)",
310
+ q=patch_size,
311
+ r=patch_size,
312
+ )
313
+ return x
314
+
315
+
316
+ class AvgDown3D(nn.Module):
317
+
318
+ def __init__(
319
+ self,
320
+ in_channels,
321
+ out_channels,
322
+ factor_t,
323
+ factor_s=1,
324
+ ):
325
+ super().__init__()
326
+ self.in_channels = in_channels
327
+ self.out_channels = out_channels
328
+ self.factor_t = factor_t
329
+ self.factor_s = factor_s
330
+ self.factor = self.factor_t * self.factor_s * self.factor_s
331
+
332
+ assert in_channels * self.factor % out_channels == 0
333
+ self.group_size = in_channels * self.factor // out_channels
334
+
335
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
336
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
337
+ pad = (0, 0, 0, 0, pad_t, 0)
338
+ x = F.pad(x, pad)
339
+ B, C, T, H, W = x.shape
340
+ x = x.view(
341
+ B,
342
+ C,
343
+ T // self.factor_t,
344
+ self.factor_t,
345
+ H // self.factor_s,
346
+ self.factor_s,
347
+ W // self.factor_s,
348
+ self.factor_s,
349
+ )
350
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
351
+ x = x.view(
352
+ B,
353
+ C * self.factor,
354
+ T // self.factor_t,
355
+ H // self.factor_s,
356
+ W // self.factor_s,
357
+ )
358
+ x = x.view(
359
+ B,
360
+ self.out_channels,
361
+ self.group_size,
362
+ T // self.factor_t,
363
+ H // self.factor_s,
364
+ W // self.factor_s,
365
+ )
366
+ x = x.mean(dim=2)
367
+ return x
368
+
369
+
370
+ class DupUp3D(nn.Module):
371
+
372
+ def __init__(
373
+ self,
374
+ in_channels: int,
375
+ out_channels: int,
376
+ factor_t,
377
+ factor_s=1,
378
+ ):
379
+ super().__init__()
380
+ self.in_channels = in_channels
381
+ self.out_channels = out_channels
382
+
383
+ self.factor_t = factor_t
384
+ self.factor_s = factor_s
385
+ self.factor = self.factor_t * self.factor_s * self.factor_s
386
+
387
+ assert out_channels * self.factor % in_channels == 0
388
+ self.repeats = out_channels * self.factor // in_channels
389
+
390
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
391
+ x = x.repeat_interleave(self.repeats, dim=1)
392
+ x = x.view(
393
+ x.size(0),
394
+ self.out_channels,
395
+ self.factor_t,
396
+ self.factor_s,
397
+ self.factor_s,
398
+ x.size(2),
399
+ x.size(3),
400
+ x.size(4),
401
+ )
402
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
403
+ x = x.view(
404
+ x.size(0),
405
+ self.out_channels,
406
+ x.size(2) * self.factor_t,
407
+ x.size(4) * self.factor_s,
408
+ x.size(6) * self.factor_s,
409
+ )
410
+ if first_chunk:
411
+ x = x[:, :, self.factor_t - 1:, :, :]
412
+ return x
413
+
414
+
415
+ class Down_ResidualBlock(nn.Module):
416
+
417
+ def __init__(self,
418
+ in_dim,
419
+ out_dim,
420
+ dropout,
421
+ mult,
422
+ temperal_downsample=False,
423
+ down_flag=False):
424
+ super().__init__()
425
+
426
+ # Shortcut path with downsample
427
+ self.avg_shortcut = AvgDown3D(
428
+ in_dim,
429
+ out_dim,
430
+ factor_t=2 if temperal_downsample else 1,
431
+ factor_s=2 if down_flag else 1,
432
+ )
433
+
434
+ # Main path with residual blocks and downsample
435
+ downsamples = []
436
+ for _ in range(mult):
437
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
438
+ in_dim = out_dim
439
+
440
+ # Add the final downsample block
441
+ if down_flag:
442
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
443
+ downsamples.append(Resample(out_dim, mode=mode))
444
+
445
+ self.downsamples = nn.Sequential(*downsamples)
446
+
447
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
448
+ x_copy = x.clone()
449
+ for module in self.downsamples:
450
+ x = module(x, feat_cache, feat_idx)
451
+
452
+ return x + self.avg_shortcut(x_copy)
453
+
454
+
455
+ class Up_ResidualBlock(nn.Module):
456
+
457
+ def __init__(self,
458
+ in_dim,
459
+ out_dim,
460
+ dropout,
461
+ mult,
462
+ temperal_upsample=False,
463
+ up_flag=False):
464
+ super().__init__()
465
+ # Shortcut path with upsample
466
+ if up_flag:
467
+ self.avg_shortcut = DupUp3D(
468
+ in_dim,
469
+ out_dim,
470
+ factor_t=2 if temperal_upsample else 1,
471
+ factor_s=2 if up_flag else 1,
472
+ )
473
+ else:
474
+ self.avg_shortcut = None
475
+
476
+ # Main path with residual blocks and upsample
477
+ upsamples = []
478
+ for _ in range(mult):
479
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
480
+ in_dim = out_dim
481
+
482
+ # Add the final upsample block
483
+ if up_flag:
484
+ mode = "upsample3d" if temperal_upsample else "upsample2d"
485
+ upsamples.append(Resample(out_dim, mode=mode))
486
+
487
+ self.upsamples = nn.Sequential(*upsamples)
488
+
489
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
490
+ x_main = x.clone()
491
+ for module in self.upsamples:
492
+ x_main = module(x_main, feat_cache, feat_idx)
493
+ if self.avg_shortcut is not None:
494
+ x_shortcut = self.avg_shortcut(x, first_chunk)
495
+ return x_main + x_shortcut
496
+ else:
497
+ return x_main
498
+
499
+
500
+ class Encoder3d(nn.Module):
501
+
502
+ def __init__(
503
+ self,
504
+ dim=128,
505
+ z_dim=4,
506
+ dim_mult=[1, 2, 4, 4],
507
+ num_res_blocks=2,
508
+ attn_scales=[],
509
+ temperal_downsample=[True, True, False],
510
+ dropout=0.0,
511
+ ):
512
+ super().__init__()
513
+ self.dim = dim
514
+ self.z_dim = z_dim
515
+ self.dim_mult = dim_mult
516
+ self.num_res_blocks = num_res_blocks
517
+ self.attn_scales = attn_scales
518
+ self.temperal_downsample = temperal_downsample
519
+
520
+ # dimensions
521
+ dims = [dim * u for u in [1] + dim_mult]
522
+ scale = 1.0
523
+
524
+ # init block
525
+ self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
526
+
527
+ # downsample blocks
528
+ downsamples = []
529
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
530
+ t_down_flag = (
531
+ temperal_downsample[i]
532
+ if i < len(temperal_downsample) else False)
533
+ downsamples.append(
534
+ Down_ResidualBlock(
535
+ in_dim=in_dim,
536
+ out_dim=out_dim,
537
+ dropout=dropout,
538
+ mult=num_res_blocks,
539
+ temperal_downsample=t_down_flag,
540
+ down_flag=i != len(dim_mult) - 1,
541
+ ))
542
+ scale /= 2.0
543
+ self.downsamples = nn.Sequential(*downsamples)
544
+
545
+ # middle blocks
546
+ self.middle = nn.Sequential(
547
+ ResidualBlock(out_dim, out_dim, dropout),
548
+ AttentionBlock(out_dim),
549
+ ResidualBlock(out_dim, out_dim, dropout),
550
+ )
551
+
552
+ # output blocks
553
+ self.head = nn.Sequential(
554
+ RMS_norm(out_dim, images=False),
555
+ nn.SiLU(),
556
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
557
+ )
558
+
559
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
560
+
561
+ if feat_cache is not None:
562
+ idx = feat_idx[0]
563
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
564
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
565
+ cache_x = torch.cat(
566
+ [
567
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
568
+ cache_x.device),
569
+ cache_x,
570
+ ],
571
+ dim=2,
572
+ )
573
+ x = self.conv1(x, feat_cache[idx])
574
+ feat_cache[idx] = cache_x
575
+ feat_idx[0] += 1
576
+ else:
577
+ x = self.conv1(x)
578
+
579
+ ## downsamples
580
+ for layer in self.downsamples:
581
+ if feat_cache is not None:
582
+ x = layer(x, feat_cache, feat_idx)
583
+ else:
584
+ x = layer(x)
585
+
586
+ ## middle
587
+ for layer in self.middle:
588
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
589
+ x = layer(x, feat_cache, feat_idx)
590
+ else:
591
+ x = layer(x)
592
+
593
+ ## head
594
+ for layer in self.head:
595
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
596
+ idx = feat_idx[0]
597
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
598
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
599
+ cache_x = torch.cat(
600
+ [
601
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
602
+ cache_x.device),
603
+ cache_x,
604
+ ],
605
+ dim=2,
606
+ )
607
+ x = layer(x, feat_cache[idx])
608
+ feat_cache[idx] = cache_x
609
+ feat_idx[0] += 1
610
+ else:
611
+ x = layer(x)
612
+
613
+ return x
614
+
615
+
616
+ class Decoder3d(nn.Module):
617
+
618
+ def __init__(
619
+ self,
620
+ dim=128,
621
+ z_dim=4,
622
+ dim_mult=[1, 2, 4, 4],
623
+ num_res_blocks=2,
624
+ attn_scales=[],
625
+ temperal_upsample=[False, True, True],
626
+ dropout=0.0,
627
+ ):
628
+ super().__init__()
629
+ self.dim = dim
630
+ self.z_dim = z_dim
631
+ self.dim_mult = dim_mult
632
+ self.num_res_blocks = num_res_blocks
633
+ self.attn_scales = attn_scales
634
+ self.temperal_upsample = temperal_upsample
635
+
636
+ # dimensions
637
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
638
+ scale = 1.0 / 2**(len(dim_mult) - 2)
639
+ # init block
640
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
641
+
642
+ # middle blocks
643
+ self.middle = nn.Sequential(
644
+ ResidualBlock(dims[0], dims[0], dropout),
645
+ AttentionBlock(dims[0]),
646
+ ResidualBlock(dims[0], dims[0], dropout),
647
+ )
648
+
649
+ # upsample blocks
650
+ upsamples = []
651
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
652
+ t_up_flag = temperal_upsample[i] if i < len(
653
+ temperal_upsample) else False
654
+ upsamples.append(
655
+ Up_ResidualBlock(
656
+ in_dim=in_dim,
657
+ out_dim=out_dim,
658
+ dropout=dropout,
659
+ mult=num_res_blocks + 1,
660
+ temperal_upsample=t_up_flag,
661
+ up_flag=i != len(dim_mult) - 1,
662
+ ))
663
+ self.upsamples = nn.Sequential(*upsamples)
664
+
665
+ # output blocks
666
+ self.head = nn.Sequential(
667
+ RMS_norm(out_dim, images=False),
668
+ nn.SiLU(),
669
+ CausalConv3d(out_dim, 12, 3, padding=1),
670
+ )
671
+
672
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
673
+ if feat_cache is not None:
674
+ idx = feat_idx[0]
675
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
676
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
677
+ cache_x = torch.cat(
678
+ [
679
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
680
+ cache_x.device),
681
+ cache_x,
682
+ ],
683
+ dim=2,
684
+ )
685
+ x = self.conv1(x, feat_cache[idx])
686
+ feat_cache[idx] = cache_x
687
+ feat_idx[0] += 1
688
+ else:
689
+ x = self.conv1(x)
690
+
691
+ for layer in self.middle:
692
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
693
+ x = layer(x, feat_cache, feat_idx)
694
+ else:
695
+ x = layer(x)
696
+
697
+ ## upsamples
698
+ for layer in self.upsamples:
699
+ if feat_cache is not None:
700
+ x = layer(x, feat_cache, feat_idx, first_chunk)
701
+ else:
702
+ x = layer(x)
703
+
704
+ ## head
705
+ for layer in self.head:
706
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
707
+ idx = feat_idx[0]
708
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
709
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
710
+ cache_x = torch.cat(
711
+ [
712
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
713
+ cache_x.device),
714
+ cache_x,
715
+ ],
716
+ dim=2,
717
+ )
718
+ x = layer(x, feat_cache[idx])
719
+ feat_cache[idx] = cache_x
720
+ feat_idx[0] += 1
721
+ else:
722
+ x = layer(x)
723
+ return x
724
+
725
+
726
+ def count_conv3d(model):
727
+ count = 0
728
+ for m in model.modules():
729
+ if isinstance(m, CausalConv3d):
730
+ count += 1
731
+ return count
732
+
733
+
734
+ class WanVAE_(nn.Module):
735
+
736
+ def __init__(
737
+ self,
738
+ dim=160,
739
+ dec_dim=256,
740
+ z_dim=16,
741
+ dim_mult=[1, 2, 4, 4],
742
+ num_res_blocks=2,
743
+ attn_scales=[],
744
+ temperal_downsample=[True, True, False],
745
+ dropout=0.0,
746
+ ):
747
+ super().__init__()
748
+ self.dim = dim
749
+ self.z_dim = z_dim
750
+ self.dim_mult = dim_mult
751
+ self.num_res_blocks = num_res_blocks
752
+ self.attn_scales = attn_scales
753
+ self.temperal_downsample = temperal_downsample
754
+ self.temperal_upsample = temperal_downsample[::-1]
755
+
756
+ # modules
757
+ self.encoder = Encoder3d(
758
+ dim,
759
+ z_dim * 2,
760
+ dim_mult,
761
+ num_res_blocks,
762
+ attn_scales,
763
+ self.temperal_downsample,
764
+ dropout,
765
+ )
766
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
767
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
768
+ self.decoder = Decoder3d(
769
+ dec_dim,
770
+ z_dim,
771
+ dim_mult,
772
+ num_res_blocks,
773
+ attn_scales,
774
+ self.temperal_upsample,
775
+ dropout,
776
+ )
777
+
778
+ def forward(self, x, scale=[0, 1]):
779
+ mu = self.encode(x, scale)
780
+ x_recon = self.decode(mu, scale)
781
+ return x_recon, mu
782
+
783
+ def encode(self, x, scale):
784
+ self.clear_cache()
785
+ x = patchify(x, patch_size=2)
786
+ t = x.shape[2]
787
+ iter_ = 1 + (t - 1) // 4
788
+ for i in range(iter_):
789
+ self._enc_conv_idx = [0]
790
+ if i == 0:
791
+ out = self.encoder(
792
+ x[:, :, :1, :, :],
793
+ feat_cache=self._enc_feat_map,
794
+ feat_idx=self._enc_conv_idx,
795
+ )
796
+ else:
797
+ out_ = self.encoder(
798
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
799
+ feat_cache=self._enc_feat_map,
800
+ feat_idx=self._enc_conv_idx,
801
+ )
802
+ out = torch.cat([out, out_], 2)
803
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
804
+ if isinstance(scale[0], torch.Tensor):
805
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
806
+ 1, self.z_dim, 1, 1, 1)
807
+ else:
808
+ mu = (mu - scale[0]) * scale[1]
809
+ self.clear_cache()
810
+ return mu
811
+
812
+ def decode(self, z, scale):
813
+ self.clear_cache()
814
+ if isinstance(scale[0], torch.Tensor):
815
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
816
+ 1, self.z_dim, 1, 1, 1)
817
+ else:
818
+ z = z / scale[1] + scale[0]
819
+ iter_ = z.shape[2]
820
+ x = self.conv2(z)
821
+ for i in range(iter_):
822
+ self._conv_idx = [0]
823
+ if i == 0:
824
+ out = self.decoder(
825
+ x[:, :, i:i + 1, :, :],
826
+ feat_cache=self._feat_map,
827
+ feat_idx=self._conv_idx,
828
+ first_chunk=True,
829
+ )
830
+ else:
831
+ out_ = self.decoder(
832
+ x[:, :, i:i + 1, :, :],
833
+ feat_cache=self._feat_map,
834
+ feat_idx=self._conv_idx,
835
+ )
836
+ out = torch.cat([out, out_], 2)
837
+ out = unpatchify(out, patch_size=2)
838
+ self.clear_cache()
839
+ return out
840
+
841
+ def reparameterize(self, mu, log_var):
842
+ std = torch.exp(0.5 * log_var)
843
+ eps = torch.randn_like(std)
844
+ return eps * std + mu
845
+
846
+ def sample(self, imgs, deterministic=False):
847
+ mu, log_var = self.encode(imgs)
848
+ if deterministic:
849
+ return mu
850
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
851
+ return mu + std * torch.randn_like(std)
852
+
853
+ def clear_cache(self):
854
+ self._conv_num = count_conv3d(self.decoder)
855
+ self._conv_idx = [0]
856
+ self._feat_map = [None] * self._conv_num
857
+ # cache encode
858
+ self._enc_conv_num = count_conv3d(self.encoder)
859
+ self._enc_conv_idx = [0]
860
+ self._enc_feat_map = [None] * self._enc_conv_num
861
+
862
+
863
+ def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
864
+ # params
865
+ cfg = dict(
866
+ dim=dim,
867
+ z_dim=z_dim,
868
+ dim_mult=[1, 2, 4, 4],
869
+ num_res_blocks=2,
870
+ attn_scales=[],
871
+ temperal_downsample=[True, True, True],
872
+ dropout=0.0,
873
+ )
874
+ cfg.update(**kwargs)
875
+
876
+ # init model
877
+ with torch.device("meta"):
878
+ model = WanVAE_(**cfg)
879
+
880
+ # load checkpoint
881
+ logging.info(f"loading {pretrained_path}")
882
+ model.load_state_dict(
883
+ torch.load(pretrained_path, map_location=device), assign=True)
884
+
885
+ return model
886
+
887
+
888
+ class Wan2_2_VAE:
889
+
890
+ def __init__(
891
+ self,
892
+ z_dim=48,
893
+ c_dim=160,
894
+ vae_pth=None,
895
+ dim_mult=[1, 2, 4, 4],
896
+ temperal_downsample=[False, True, True],
897
+ dtype=torch.float,
898
+ device="cuda",
899
+ ):
900
+
901
+ self.dtype = dtype
902
+ self.device = device
903
+
904
+ mean = torch.tensor(
905
+ [
906
+ -0.2289,
907
+ -0.0052,
908
+ -0.1323,
909
+ -0.2339,
910
+ -0.2799,
911
+ 0.0174,
912
+ 0.1838,
913
+ 0.1557,
914
+ -0.1382,
915
+ 0.0542,
916
+ 0.2813,
917
+ 0.0891,
918
+ 0.1570,
919
+ -0.0098,
920
+ 0.0375,
921
+ -0.1825,
922
+ -0.2246,
923
+ -0.1207,
924
+ -0.0698,
925
+ 0.5109,
926
+ 0.2665,
927
+ -0.2108,
928
+ -0.2158,
929
+ 0.2502,
930
+ -0.2055,
931
+ -0.0322,
932
+ 0.1109,
933
+ 0.1567,
934
+ -0.0729,
935
+ 0.0899,
936
+ -0.2799,
937
+ -0.1230,
938
+ -0.0313,
939
+ -0.1649,
940
+ 0.0117,
941
+ 0.0723,
942
+ -0.2839,
943
+ -0.2083,
944
+ -0.0520,
945
+ 0.3748,
946
+ 0.0152,
947
+ 0.1957,
948
+ 0.1433,
949
+ -0.2944,
950
+ 0.3573,
951
+ -0.0548,
952
+ -0.1681,
953
+ -0.0667,
954
+ ],
955
+ dtype=dtype,
956
+ device=device,
957
+ )
958
+ std = torch.tensor(
959
+ [
960
+ 0.4765,
961
+ 1.0364,
962
+ 0.4514,
963
+ 1.1677,
964
+ 0.5313,
965
+ 0.4990,
966
+ 0.4818,
967
+ 0.5013,
968
+ 0.8158,
969
+ 1.0344,
970
+ 0.5894,
971
+ 1.0901,
972
+ 0.6885,
973
+ 0.6165,
974
+ 0.8454,
975
+ 0.4978,
976
+ 0.5759,
977
+ 0.3523,
978
+ 0.7135,
979
+ 0.6804,
980
+ 0.5833,
981
+ 1.4146,
982
+ 0.8986,
983
+ 0.5659,
984
+ 0.7069,
985
+ 0.5338,
986
+ 0.4889,
987
+ 0.4917,
988
+ 0.4069,
989
+ 0.4999,
990
+ 0.6866,
991
+ 0.4093,
992
+ 0.5709,
993
+ 0.6065,
994
+ 0.6415,
995
+ 0.4944,
996
+ 0.5726,
997
+ 1.2042,
998
+ 0.5458,
999
+ 1.6887,
1000
+ 0.3971,
1001
+ 1.0600,
1002
+ 0.3943,
1003
+ 0.5537,
1004
+ 0.5444,
1005
+ 0.4089,
1006
+ 0.7468,
1007
+ 0.7744,
1008
+ ],
1009
+ dtype=dtype,
1010
+ device=device,
1011
+ )
1012
+ self.scale = [mean, 1.0 / std]
1013
+
1014
+ # init model
1015
+ self.model = (
1016
+ _video_vae(
1017
+ pretrained_path=vae_pth,
1018
+ z_dim=z_dim,
1019
+ dim=c_dim,
1020
+ dim_mult=dim_mult,
1021
+ temperal_downsample=temperal_downsample,
1022
+ ).eval().requires_grad_(False).to(device))
1023
+
1024
+ def encode(self, videos):
1025
+ try:
1026
+ if not isinstance(videos, list):
1027
+ raise TypeError("videos should be a list")
1028
+ with amp.autocast(dtype=self.dtype):
1029
+ return [
1030
+ self.model.encode(u.unsqueeze(0),
1031
+ self.scale).float().squeeze(0)
1032
+ for u in videos
1033
+ ]
1034
+ except TypeError as e:
1035
+ logging.info(e)
1036
+ return None
1037
+
1038
+ def decode(self, zs):
1039
+ try:
1040
+ if not isinstance(zs, list):
1041
+ raise TypeError("zs should be a list")
1042
+ with amp.autocast(dtype=self.dtype):
1043
+ return [
1044
+ self.model.decode(u.unsqueeze(0),
1045
+ self.scale).float().clamp_(-1,
1046
+ 1).squeeze(0)
1047
+ for u in zs
1048
+ ]
1049
+ except TypeError as e:
1050
+ logging.info(e)
1051
+ return None
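Unlike vae2_1.py, this VAE folds 2x2 pixel blocks into channels (patchify) before the convolutional trunk, so three spatial downsample stages yield an overall 16x spatial and 4x temporal stride. A small self-contained check (editorial example, not part of the upload) that patchify/unpatchify are exact inverses on the 5-D video layout:

import torch
from wan.modules.vae2_2 import patchify, unpatchify

x = torch.randn(1, 3, 9, 64, 64)    # [B, C, F, H, W]
p = patchify(x, patch_size=2)       # -> [1, 12, 9, 32, 32]: each 2x2 spatial block folded into channels
assert torch.equal(unpatchify(p, patch_size=2), x)   # pure rearrange, so the round trip is exact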
wan/text2video.py ADDED
@@ -0,0 +1,378 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ from tqdm import tqdm
16
+
17
+ from .distributed.fsdp import shard_model
18
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
19
+ from .distributed.util import get_world_size
20
+ from .modules.model import WanModel
21
+ from .modules.t5 import T5EncoderModel
22
+ from .modules.vae2_1 import Wan2_1_VAE
23
+ from .utils.fm_solvers import (
24
+ FlowDPMSolverMultistepScheduler,
25
+ get_sampling_sigmas,
26
+ retrieve_timesteps,
27
+ )
28
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
29
+
30
+
31
+ class WanT2V:
32
+
33
+ def __init__(
34
+ self,
35
+ config,
36
+ checkpoint_dir,
37
+ device_id=0,
38
+ rank=0,
39
+ t5_fsdp=False,
40
+ dit_fsdp=False,
41
+ use_sp=False,
42
+ t5_cpu=False,
43
+ init_on_cpu=True,
44
+ convert_model_dtype=False,
45
+ ):
46
+ r"""
47
+ Initializes the Wan text-to-video generation model components.
48
+
49
+ Args:
50
+ config (EasyDict):
51
+ Object containing model parameters initialized from config.py
52
+ checkpoint_dir (`str`):
53
+ Path to directory containing model checkpoints
54
+ device_id (`int`, *optional*, defaults to 0):
55
+ Id of target GPU device
56
+ rank (`int`, *optional*, defaults to 0):
57
+ Process rank for distributed training
58
+ t5_fsdp (`bool`, *optional*, defaults to False):
59
+ Enable FSDP sharding for T5 model
60
+ dit_fsdp (`bool`, *optional*, defaults to False):
61
+ Enable FSDP sharding for DiT model
62
+ use_sp (`bool`, *optional*, defaults to False):
63
+ Enable distribution strategy of sequence parallel.
64
+ t5_cpu (`bool`, *optional*, defaults to False):
65
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
66
+ init_on_cpu (`bool`, *optional*, defaults to True):
67
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
68
+ convert_model_dtype (`bool`, *optional*, defaults to False):
69
+ Convert DiT model parameters dtype to 'config.param_dtype'.
70
+ Only works without FSDP.
71
+ """
72
+ self.device = torch.device(f"cuda:{device_id}")
73
+ self.config = config
74
+ self.rank = rank
75
+ self.t5_cpu = t5_cpu
76
+ self.init_on_cpu = init_on_cpu
77
+
78
+ self.num_train_timesteps = config.num_train_timesteps
79
+ self.boundary = config.boundary
80
+ self.param_dtype = config.param_dtype
81
+
82
+ if t5_fsdp or dit_fsdp or use_sp:
83
+ self.init_on_cpu = False
84
+
85
+ shard_fn = partial(shard_model, device_id=device_id)
86
+ self.text_encoder = T5EncoderModel(
87
+ text_len=config.text_len,
88
+ dtype=config.t5_dtype,
89
+ device=torch.device('cpu'),
90
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
91
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
92
+ shard_fn=shard_fn if t5_fsdp else None)
93
+
94
+ self.vae_stride = config.vae_stride
95
+ self.patch_size = config.patch_size
96
+ self.vae = Wan2_1_VAE(
97
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
98
+ device=self.device)
99
+
100
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
101
+ self.low_noise_model = WanModel.from_pretrained(
102
+ checkpoint_dir, subfolder=config.low_noise_checkpoint)
103
+ self.low_noise_model = self._configure_model(
104
+ model=self.low_noise_model,
105
+ use_sp=use_sp,
106
+ dit_fsdp=dit_fsdp,
107
+ shard_fn=shard_fn,
108
+ convert_model_dtype=convert_model_dtype)
109
+
110
+ self.high_noise_model = WanModel.from_pretrained(
111
+ checkpoint_dir, subfolder=config.high_noise_checkpoint)
112
+ self.high_noise_model = self._configure_model(
113
+ model=self.high_noise_model,
114
+ use_sp=use_sp,
115
+ dit_fsdp=dit_fsdp,
116
+ shard_fn=shard_fn,
117
+ convert_model_dtype=convert_model_dtype)
118
+ if use_sp:
119
+ self.sp_size = get_world_size()
120
+ else:
121
+ self.sp_size = 1
122
+
123
+ self.sample_neg_prompt = config.sample_neg_prompt
124
+
125
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
126
+ convert_model_dtype):
127
+ """
128
+ Configures a model object. This includes setting evaluation modes,
129
+ applying distributed parallel strategy, and handling device placement.
130
+
131
+ Args:
132
+ model (torch.nn.Module):
133
+ The model instance to configure.
134
+ use_sp (`bool`):
135
+ Enable distribution strategy of sequence parallel.
136
+ dit_fsdp (`bool`):
137
+ Enable FSDP sharding for DiT model.
138
+ shard_fn (callable):
139
+ The function to apply FSDP sharding.
140
+ convert_model_dtype (`bool`):
141
+ Convert DiT model parameters dtype to 'config.param_dtype'.
142
+ Only works without FSDP.
143
+
144
+ Returns:
145
+ torch.nn.Module:
146
+ The configured model.
147
+ """
148
+ model.eval().requires_grad_(False)
149
+
150
+ if use_sp:
151
+ for block in model.blocks:
152
+ block.self_attn.forward = types.MethodType(
153
+ sp_attn_forward, block.self_attn)
154
+ model.forward = types.MethodType(sp_dit_forward, model)
155
+
156
+ if dist.is_initialized():
157
+ dist.barrier()
158
+
159
+ if dit_fsdp:
160
+ model = shard_fn(model)
161
+ else:
162
+ if convert_model_dtype:
163
+ model.to(self.param_dtype)
164
+ if not self.init_on_cpu:
165
+ model.to(self.device)
166
+
167
+ return model
168
+
169
+ def _prepare_model_for_timestep(self, t, boundary, offload_model):
170
+ r"""
171
+ Prepares and returns the required model for the current timestep.
172
+
173
+ Args:
174
+ t (torch.Tensor):
175
+ current timestep.
176
+ boundary (`int`):
177
+ The timestep threshold. If `t` is at or above this value,
178
+ the `high_noise_model` is considered as the required model.
179
+ offload_model (`bool`):
180
+ A flag intended to control the offloading behavior.
181
+
182
+ Returns:
183
+ torch.nn.Module:
184
+ The active model on the target device for the current timestep.
185
+ """
186
+ if t.item() >= boundary:
187
+ required_model_name = 'high_noise_model'
188
+ offload_model_name = 'low_noise_model'
189
+ else:
190
+ required_model_name = 'low_noise_model'
191
+ offload_model_name = 'high_noise_model'
192
+ if offload_model or self.init_on_cpu:
193
+ if next(getattr(
194
+ self,
195
+ offload_model_name).parameters()).device.type == 'cuda':
196
+ getattr(self, offload_model_name).to('cpu')
197
+ if next(getattr(
198
+ self,
199
+ required_model_name).parameters()).device.type == 'cpu':
200
+ getattr(self, required_model_name).to(self.device)
201
+ return getattr(self, required_model_name)
202
+
203
+ def generate(self,
204
+ input_prompt,
205
+ size=(1280, 720),
206
+ frame_num=81,
207
+ shift=5.0,
208
+ sample_solver='unipc',
209
+ sampling_steps=50,
210
+ guide_scale=5.0,
211
+ n_prompt="",
212
+ seed=-1,
213
+ offload_model=True):
214
+ r"""
215
+ Generates video frames from text prompt using diffusion process.
216
+
217
+ Args:
218
+ input_prompt (`str`):
219
+ Text prompt for content generation
220
+ size (`tuple[int]`, *optional*, defaults to (1280,720)):
221
+ Controls video resolution, (width,height).
222
+ frame_num (`int`, *optional*, defaults to 81):
223
+ How many frames to sample from a video. The number should be 4n+1
224
+ shift (`float`, *optional*, defaults to 5.0):
225
+ Noise schedule shift parameter. Affects temporal dynamics
226
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
227
+ Solver used to sample the video.
228
+ sampling_steps (`int`, *optional*, defaults to 50):
229
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
230
+ guide_scale (`float` or tuple[`float`], *optional*, defaults 5.0):
231
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
232
+ If tuple, the first guide_scale will be used for low noise model and
233
+ the second guide_scale will be used for high noise model.
234
+ n_prompt (`str`, *optional*, defaults to ""):
235
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
236
+ seed (`int`, *optional*, defaults to -1):
237
+ Random seed for noise generation. If -1, use random seed.
238
+ offload_model (`bool`, *optional*, defaults to True):
239
+ If True, offloads models to CPU during generation to save VRAM
240
+
241
+ Returns:
242
+ torch.Tensor:
243
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
244
+ - C: Color channels (3 for RGB)
245
+ - N: Number of frames (81)
246
+ - H: Frame height (from size)
247
+ - W: Frame width (from size)
248
+ """
249
+ # preprocess
250
+ guide_scale = (guide_scale, guide_scale) if isinstance(
251
+ guide_scale, float) else guide_scale
252
+ F = frame_num
253
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
254
+ size[1] // self.vae_stride[1],
255
+ size[0] // self.vae_stride[2])
256
+
257
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
258
+ (self.patch_size[1] * self.patch_size[2]) *
259
+ target_shape[1] / self.sp_size) * self.sp_size
260
+
261
+ if n_prompt == "":
262
+ n_prompt = self.sample_neg_prompt
263
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
264
+ seed_g = torch.Generator(device=self.device)
265
+ seed_g.manual_seed(seed)
266
+
267
+ if not self.t5_cpu:
268
+ self.text_encoder.model.to(self.device)
269
+ context = self.text_encoder([input_prompt], self.device)
270
+ context_null = self.text_encoder([n_prompt], self.device)
271
+ if offload_model:
272
+ self.text_encoder.model.cpu()
273
+ else:
274
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
275
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
276
+ context = [t.to(self.device) for t in context]
277
+ context_null = [t.to(self.device) for t in context_null]
278
+
279
+ noise = [
280
+ torch.randn(
281
+ target_shape[0],
282
+ target_shape[1],
283
+ target_shape[2],
284
+ target_shape[3],
285
+ dtype=torch.float32,
286
+ device=self.device,
287
+ generator=seed_g)
288
+ ]
289
+
290
+ @contextmanager
291
+ def noop_no_sync():
292
+ yield
293
+
294
+ no_sync_low_noise = getattr(self.low_noise_model, 'no_sync',
295
+ noop_no_sync)
296
+ no_sync_high_noise = getattr(self.high_noise_model, 'no_sync',
297
+ noop_no_sync)
298
+
299
+ # evaluation mode
300
+ with (
301
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
302
+ torch.no_grad(),
303
+ no_sync_low_noise(),
304
+ no_sync_high_noise(),
305
+ ):
306
+ boundary = self.boundary * self.num_train_timesteps
307
+
308
+ if sample_solver == 'unipc':
309
+ sample_scheduler = FlowUniPCMultistepScheduler(
310
+ num_train_timesteps=self.num_train_timesteps,
311
+ shift=1,
312
+ use_dynamic_shifting=False)
313
+ sample_scheduler.set_timesteps(
314
+ sampling_steps, device=self.device, shift=shift)
315
+ timesteps = sample_scheduler.timesteps
316
+ elif sample_solver == 'dpm++':
317
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
318
+ num_train_timesteps=self.num_train_timesteps,
319
+ shift=1,
320
+ use_dynamic_shifting=False)
321
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
322
+ timesteps, _ = retrieve_timesteps(
323
+ sample_scheduler,
324
+ device=self.device,
325
+ sigmas=sampling_sigmas)
326
+ else:
327
+ raise NotImplementedError("Unsupported solver.")
328
+
329
+ # sample videos
330
+ latents = noise
331
+
332
+ arg_c = {'context': context, 'seq_len': seq_len}
333
+ arg_null = {'context': context_null, 'seq_len': seq_len}
334
+
335
+ for _, t in enumerate(tqdm(timesteps)):
336
+ latent_model_input = latents
337
+ timestep = [t]
338
+
339
+ timestep = torch.stack(timestep)
340
+
341
+ model = self._prepare_model_for_timestep(
342
+ t, boundary, offload_model)
343
+ sample_guide_scale = guide_scale[1] if t.item(
344
+ ) >= boundary else guide_scale[0]
345
+
346
+ noise_pred_cond = model(
347
+ latent_model_input, t=timestep, **arg_c)[0]
348
+ noise_pred_uncond = model(
349
+ latent_model_input, t=timestep, **arg_null)[0]
350
+
351
+ noise_pred = noise_pred_uncond + sample_guide_scale * (
352
+ noise_pred_cond - noise_pred_uncond)
353
+
354
+ temp_x0 = sample_scheduler.step(
355
+ noise_pred.unsqueeze(0),
356
+ t,
357
+ latents[0].unsqueeze(0),
358
+ return_dict=False,
359
+ generator=seed_g)[0]
360
+ latents = [temp_x0.squeeze(0)]
361
+
362
+ x0 = latents
363
+ if offload_model:
364
+ self.low_noise_model.cpu()
365
+ self.high_noise_model.cpu()
366
+ torch.cuda.empty_cache()
367
+ if self.rank == 0:
368
+ videos = self.vae.decode(x0)
369
+
370
+ del noise, latents
371
+ del sample_scheduler
372
+ if offload_model:
373
+ gc.collect()
374
+ torch.cuda.synchronize()
375
+ if dist.is_initialized():
376
+ dist.barrier()
377
+
378
+ return videos[0] if self.rank == 0 else None
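For orientation, a hedged single-GPU usage sketch of the two-expert pipeline above (editorial example, not part of the upload). It assumes that `wan.configs` exposes a `WAN_CONFIGS` dict with a 't2v-A14B' entry and that the checkpoint directory matches the layout referenced by that config; both names are placeholders to adapt to your setup:

import torch
from wan.configs import WAN_CONFIGS   # assumed export from wan/configs/__init__.py
from wan.text2video import WanT2V

cfg = WAN_CONFIGS['t2v-A14B']         # hypothetical key name
pipe = WanT2V(config=cfg, checkpoint_dir='./Wan2.2-T2V-A14B',
              device_id=0, convert_model_dtype=True)
video = pipe.generate(
    'a corgi surfing a small wave at sunset',
    size=(1280, 720), frame_num=81, sampling_steps=50,
    guide_scale=(3.0, 4.0),           # low-noise / high-noise scales, as described in the docstring
    seed=42, offload_model=True)
# `video` is a [C, N, H, W] float tensor in roughly [-1, 1]; rescale to uint8 before writing it out,
# e.g. with torchvision.io.write_video.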
wan/textimage2video.py ADDED
@@ -0,0 +1,619 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import gc
3
+ import logging
4
+ import math
5
+ import os
6
+ import random
7
+ import sys
8
+ import types
9
+ from contextlib import contextmanager
10
+ from functools import partial
11
+
12
+ import torch
13
+ import torch.cuda.amp as amp
14
+ import torch.distributed as dist
15
+ import torchvision.transforms.functional as TF
16
+ from PIL import Image
17
+ from tqdm import tqdm
18
+
19
+ from .distributed.fsdp import shard_model
20
+ from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward
21
+ from .distributed.util import get_world_size
22
+ from .modules.model import WanModel
23
+ from .modules.t5 import T5EncoderModel
24
+ from .modules.vae2_2 import Wan2_2_VAE
25
+ from .utils.fm_solvers import (
26
+ FlowDPMSolverMultistepScheduler,
27
+ get_sampling_sigmas,
28
+ retrieve_timesteps,
29
+ )
30
+ from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
31
+ from .utils.utils import best_output_size, masks_like
32
+
33
+
34
+ class WanTI2V:
35
+
36
+ def __init__(
37
+ self,
38
+ config,
39
+ checkpoint_dir,
40
+ device_id=0,
41
+ rank=0,
42
+ t5_fsdp=False,
43
+ dit_fsdp=False,
44
+ use_sp=False,
45
+ t5_cpu=False,
46
+ init_on_cpu=True,
47
+ convert_model_dtype=False,
48
+ ):
49
+ r"""
50
+ Initializes the Wan text/image-to-video generation model components.
51
+
52
+ Args:
53
+ config (EasyDict):
54
+ Object containing model parameters initialized from config.py
55
+ checkpoint_dir (`str`):
56
+ Path to directory containing model checkpoints
57
+ device_id (`int`, *optional*, defaults to 0):
58
+ Id of target GPU device
59
+ rank (`int`, *optional*, defaults to 0):
60
+ Process rank for distributed training
61
+ t5_fsdp (`bool`, *optional*, defaults to False):
62
+ Enable FSDP sharding for T5 model
63
+ dit_fsdp (`bool`, *optional*, defaults to False):
64
+ Enable FSDP sharding for DiT model
65
+ use_sp (`bool`, *optional*, defaults to False):
66
+ Enable distribution strategy of sequence parallel.
67
+ t5_cpu (`bool`, *optional*, defaults to False):
68
+ Whether to place T5 model on CPU. Only works without t5_fsdp.
69
+ init_on_cpu (`bool`, *optional*, defaults to True):
70
+ Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
71
+ convert_model_dtype (`bool`, *optional*, defaults to False):
72
+ Convert DiT model parameters dtype to 'config.param_dtype'.
73
+ Only works without FSDP.
74
+ """
75
+ self.device = torch.device(f"cuda:{device_id}")
76
+ self.config = config
77
+ self.rank = rank
78
+ self.t5_cpu = t5_cpu
79
+ self.init_on_cpu = init_on_cpu
80
+
81
+ self.num_train_timesteps = config.num_train_timesteps
82
+ self.param_dtype = config.param_dtype
83
+
84
+ if t5_fsdp or dit_fsdp or use_sp:
85
+ self.init_on_cpu = False
86
+
87
+ shard_fn = partial(shard_model, device_id=device_id)
88
+ self.text_encoder = T5EncoderModel(
89
+ text_len=config.text_len,
90
+ dtype=config.t5_dtype,
91
+ device=torch.device('cpu'),
92
+ checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
93
+ tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
94
+ shard_fn=shard_fn if t5_fsdp else None)
95
+
96
+ self.vae_stride = config.vae_stride
97
+ self.patch_size = config.patch_size
98
+ self.vae = Wan2_2_VAE(
99
+ vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
100
+ device=self.device)
101
+
102
+ logging.info(f"Creating WanModel from {checkpoint_dir}")
103
+ self.model = WanModel.from_pretrained(checkpoint_dir)
104
+ self.model = self._configure_model(
105
+ model=self.model,
106
+ use_sp=use_sp,
107
+ dit_fsdp=dit_fsdp,
108
+ shard_fn=shard_fn,
109
+ convert_model_dtype=convert_model_dtype)
110
+
111
+ if use_sp:
112
+ self.sp_size = get_world_size()
113
+ else:
114
+ self.sp_size = 1
115
+
116
+ self.sample_neg_prompt = config.sample_neg_prompt
117
+
118
+ def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
119
+ convert_model_dtype):
120
+ """
121
+ Configures a model object. This includes setting evaluation modes,
122
+ applying distributed parallel strategy, and handling device placement.
123
+
124
+ Args:
125
+ model (torch.nn.Module):
126
+ The model instance to configure.
127
+ use_sp (`bool`):
128
+ Enable distribution strategy of sequence parallel.
129
+ dit_fsdp (`bool`):
130
+ Enable FSDP sharding for DiT model.
131
+ shard_fn (callable):
132
+ The function to apply FSDP sharding.
133
+ convert_model_dtype (`bool`):
134
+ Convert DiT model parameters dtype to 'config.param_dtype'.
135
+ Only works without FSDP.
136
+
137
+ Returns:
138
+ torch.nn.Module:
139
+ The configured model.
140
+ """
141
+ model.eval().requires_grad_(False)
142
+
143
+ if use_sp:
144
+ for block in model.blocks:
145
+ block.self_attn.forward = types.MethodType(
146
+ sp_attn_forward, block.self_attn)
147
+ model.forward = types.MethodType(sp_dit_forward, model)
148
+
149
+ if dist.is_initialized():
150
+ dist.barrier()
151
+
152
+ if dit_fsdp:
153
+ model = shard_fn(model)
154
+ else:
155
+ if convert_model_dtype:
156
+ model.to(self.param_dtype)
157
+ if not self.init_on_cpu:
158
+ model.to(self.device)
159
+
160
+ return model
161
+
162
+ def generate(self,
163
+ input_prompt,
164
+ img=None,
165
+ size=(1280, 704),
166
+ max_area=704 * 1280,
167
+ frame_num=81,
168
+ shift=5.0,
169
+ sample_solver='unipc',
170
+ sampling_steps=50,
171
+ guide_scale=5.0,
172
+ n_prompt="",
173
+ seed=-1,
174
+ offload_model=True):
175
+ r"""
176
+ Generates video frames from a text prompt, and optionally an input image, using a diffusion process.
177
+
178
+ Args:
179
+ input_prompt (`str`):
180
+ Text prompt for content generation
181
+ img (PIL.Image.Image):
182
+ Optional input image (PIL). If provided, the i2v path is used; otherwise the t2v path.
183
+ size (`tuple[int]`, *optional*, defaults to (1280,704)):
184
+ Controls video resolution, (width,height).
185
+ max_area (`int`, *optional*, defaults to 704*1280):
186
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
187
+ frame_num (`int`, *optional*, defaults to 81):
188
+ Number of frames to generate. Must be of the form 4n+1.
189
+ shift (`float`, *optional*, defaults to 5.0):
190
+ Noise schedule shift parameter. Affects temporal dynamics
191
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
192
+ Solver used to sample the video.
193
+ sampling_steps (`int`, *optional*, defaults to 50):
194
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
195
+ guide_scale (`float`, *optional*, defaults to 5.0):
196
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
197
+ n_prompt (`str`, *optional*, defaults to ""):
198
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
199
+ seed (`int`, *optional*, defaults to -1):
200
+ Random seed for noise generation. If -1, use random seed.
201
+ offload_model (`bool`, *optional*, defaults to True):
202
+ If True, offloads models to CPU during generation to save VRAM
203
+
204
+ Returns:
205
+ torch.Tensor:
206
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
207
+ - C: Color channels (3 for RGB)
208
+ - N: Number of frames (81)
209
+ - H: Frame height (from size)
210
+ - W: Frame width (from size)
211
+ """
212
+ # i2v
213
+ if img is not None:
214
+ return self.i2v(
215
+ input_prompt=input_prompt,
216
+ img=img,
217
+ max_area=max_area,
218
+ frame_num=frame_num,
219
+ shift=shift,
220
+ sample_solver=sample_solver,
221
+ sampling_steps=sampling_steps,
222
+ guide_scale=guide_scale,
223
+ n_prompt=n_prompt,
224
+ seed=seed,
225
+ offload_model=offload_model)
226
+ # t2v
227
+ return self.t2v(
228
+ input_prompt=input_prompt,
229
+ size=size,
230
+ frame_num=frame_num,
231
+ shift=shift,
232
+ sample_solver=sample_solver,
233
+ sampling_steps=sampling_steps,
234
+ guide_scale=guide_scale,
235
+ n_prompt=n_prompt,
236
+ seed=seed,
237
+ offload_model=offload_model)
238
+
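The generate() wrapper above only dispatches: when an image is given it takes the i2v path (resolution derived from max_area), otherwise the t2v path (resolution taken from size). A hedged usage sketch, with `pipe` as constructed in the earlier sketch and the example image bundled in this commit:

from PIL import Image

# Text-to-video: no image, resolution comes from `size`.
video_t2v = pipe.generate(
    'a corgi surfing a wave at sunset',
    size=(1280, 704), frame_num=121, sampling_steps=50, seed=42)

# Image-to-video: passing `img` switches to the i2v branch; `size` is ignored there.
first_frame = Image.open('examples/i2v_input.JPG').convert('RGB')
video_i2v = pipe.generate(
    'the camera slowly pushes in', img=first_frame,
    max_area=704 * 1280, frame_num=121, seed=42)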
239
+ def t2v(self,
240
+ input_prompt,
241
+ size=(1280, 704),
242
+ frame_num=121,
243
+ shift=5.0,
244
+ sample_solver='unipc',
245
+ sampling_steps=50,
246
+ guide_scale=5.0,
247
+ n_prompt="",
248
+ seed=-1,
249
+ offload_model=True):
250
+ r"""
251
+ Generates video frames from a text prompt using a diffusion process.
252
+
253
+ Args:
254
+ input_prompt (`str`):
255
+ Text prompt for content generation
256
+ size (`tuple[int]`, *optional*, defaults to (1280,704)):
257
+ Controls video resolution, (width,height).
258
+ frame_num (`int`, *optional*, defaults to 121):
259
+ Number of frames to generate. Must be of the form 4n+1.
260
+ shift (`float`, *optional*, defaults to 5.0):
261
+ Noise schedule shift parameter. Affects temporal dynamics
262
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
263
+ Solver used to sample the video.
264
+ sampling_steps (`int`, *optional*, defaults to 50):
265
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
266
+ guide_scale (`float`, *optional*, defaults to 5.0):
267
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
268
+ n_prompt (`str`, *optional*, defaults to ""):
269
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
270
+ seed (`int`, *optional*, defaults to -1):
271
+ Random seed for noise generation. If -1, use random seed.
272
+ offload_model (`bool`, *optional*, defaults to True):
273
+ If True, offloads models to CPU during generation to save VRAM
274
+
275
+ Returns:
276
+ torch.Tensor:
277
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
278
+ - C: Color channels (3 for RGB)
279
+ - N: Number of frames (frame_num, 121 by default)
280
+ - H: Frame height (from size)
281
+ - W: Frame width (from size)
282
+ """
283
+ # preprocess
284
+ F = frame_num
285
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
286
+ size[1] // self.vae_stride[1],
287
+ size[0] // self.vae_stride[2])
288
+
289
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
290
+ (self.patch_size[1] * self.patch_size[2]) *
291
+ target_shape[1] / self.sp_size) * self.sp_size
292
+
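To make the latent bookkeeping above concrete, here is the arithmetic for the default t2v settings, assuming the 5B TI2V config values z_dim=48, vae_stride=(4, 16, 16) and patch_size=(1, 2, 2) (these come from the config files in this commit, not from this function):

# Worked example for size=(1280, 704), frame_num=121, sp_size=1 (assumed config values).
z_dim, vae_stride, patch_size = 48, (4, 16, 16), (1, 2, 2)
F, W, H = 121, 1280, 704
target_shape = (z_dim, (F - 1) // vae_stride[0] + 1, H // vae_stride[1], W // vae_stride[2])
# -> (48, 31, 44, 80)
seq_len = (target_shape[2] * target_shape[3]) // (patch_size[1] * patch_size[2]) * target_shape[1]
# -> (44 * 80) / 4 * 31 = 880 * 31 = 27280 tokens per sample seen by the DiT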
293
+ if n_prompt == "":
294
+ n_prompt = self.sample_neg_prompt
295
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
296
+ seed_g = torch.Generator(device=self.device)
297
+ seed_g.manual_seed(seed)
298
+
299
+ if not self.t5_cpu:
300
+ self.text_encoder.model.to(self.device)
301
+ context = self.text_encoder([input_prompt], self.device)
302
+ context_null = self.text_encoder([n_prompt], self.device)
303
+ if offload_model:
304
+ self.text_encoder.model.cpu()
305
+ else:
306
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
307
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
308
+ context = [t.to(self.device) for t in context]
309
+ context_null = [t.to(self.device) for t in context_null]
310
+
311
+ noise = [
312
+ torch.randn(
313
+ target_shape[0],
314
+ target_shape[1],
315
+ target_shape[2],
316
+ target_shape[3],
317
+ dtype=torch.float32,
318
+ device=self.device,
319
+ generator=seed_g)
320
+ ]
321
+
322
+ @contextmanager
323
+ def noop_no_sync():
324
+ yield
325
+
326
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
327
+
328
+ # evaluation mode
329
+ with (
330
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
331
+ torch.no_grad(),
332
+ no_sync(),
333
+ ):
334
+
335
+ if sample_solver == 'unipc':
336
+ sample_scheduler = FlowUniPCMultistepScheduler(
337
+ num_train_timesteps=self.num_train_timesteps,
338
+ shift=1,
339
+ use_dynamic_shifting=False)
340
+ sample_scheduler.set_timesteps(
341
+ sampling_steps, device=self.device, shift=shift)
342
+ timesteps = sample_scheduler.timesteps
343
+ elif sample_solver == 'dpm++':
344
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
345
+ num_train_timesteps=self.num_train_timesteps,
346
+ shift=1,
347
+ use_dynamic_shifting=False)
348
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
349
+ timesteps, _ = retrieve_timesteps(
350
+ sample_scheduler,
351
+ device=self.device,
352
+ sigmas=sampling_sigmas)
353
+ else:
354
+ raise NotImplementedError("Unsupported solver.")
355
+
356
+ # sample videos
357
+ latents = noise
358
+ mask1, mask2 = masks_like(noise, zero=False)
359
+
360
+ arg_c = {'context': context, 'seq_len': seq_len}
361
+ arg_null = {'context': context_null, 'seq_len': seq_len}
362
+
363
+ if offload_model or self.init_on_cpu:
364
+ self.model.to(self.device)
365
+ torch.cuda.empty_cache()
366
+
367
+ for _, t in enumerate(tqdm(timesteps)):
368
+ latent_model_input = latents
369
+ timestep = [t]
370
+
371
+ timestep = torch.stack(timestep)
372
+
373
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
374
+ temp_ts = torch.cat([
375
+ temp_ts,
376
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
377
+ ])
378
+ timestep = temp_ts.unsqueeze(0)
379
+
380
+ noise_pred_cond = self.model(
381
+ latent_model_input, t=timestep, **arg_c)[0]
382
+ noise_pred_uncond = self.model(
383
+ latent_model_input, t=timestep, **arg_null)[0]
384
+
385
+ noise_pred = noise_pred_uncond + guide_scale * (
386
+ noise_pred_cond - noise_pred_uncond)
387
+
388
+ temp_x0 = sample_scheduler.step(
389
+ noise_pred.unsqueeze(0),
390
+ t,
391
+ latents[0].unsqueeze(0),
392
+ return_dict=False,
393
+ generator=seed_g)[0]
394
+ latents = [temp_x0.squeeze(0)]
395
+ x0 = latents
396
+ if offload_model:
397
+ self.model.cpu()
398
+ torch.cuda.synchronize()
399
+ torch.cuda.empty_cache()
400
+ if self.rank == 0:
401
+ videos = self.vae.decode(x0)
402
+
403
+ del noise, latents
404
+ del sample_scheduler
405
+ if offload_model:
406
+ gc.collect()
407
+ torch.cuda.synchronize()
408
+ if dist.is_initialized():
409
+ dist.barrier()
410
+
411
+ return videos[0] if self.rank == 0 else None
412
+
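One easy-to-miss detail of the sampling loop above (and of the i2v loop below): the scalar timestep is expanded to one value per DiT token through mask2, so conditioned tokens can carry a different timestep than generated ones. In t2v the mask is all ones, so every token gets t; in i2v, masks_like(..., zero=True) appears to zero the conditioning frame, so its tokens get timestep 0. A rough shape sketch under the same assumed 5B config values:

import torch

seq_len, t = 27280, torch.tensor(999.0)
mask = torch.ones(48, 31, 44, 80)                 # stands in for mask2[0]; all ones in t2v
temp_ts = (mask[0][:, ::2, ::2] * t).flatten()    # one value per 2x2 spatial patch: 31*22*40 = 27280
temp_ts = torch.cat([temp_ts, temp_ts.new_ones(seq_len - temp_ts.size(0)) * t])  # pad to seq_len (no-op here)
timestep = temp_ts.unsqueeze(0)                   # [1, seq_len], passed to the DiT as `t`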
413
+ def i2v(self,
414
+ input_prompt,
415
+ img,
416
+ max_area=704 * 1280,
417
+ frame_num=121,
418
+ shift=5.0,
419
+ sample_solver='unipc',
420
+ sampling_steps=40,
421
+ guide_scale=5.0,
422
+ n_prompt="",
423
+ seed=-1,
424
+ offload_model=True):
425
+ r"""
426
+ Generates video frames from an input image and a text prompt using a diffusion process.
427
+
428
+ Args:
429
+ input_prompt (`str`):
430
+ Text prompt for content generation.
431
+ img (PIL.Image.Image):
432
+ Input PIL image; converted internally to a [3, H, W] tensor in [-1, 1].
433
+ max_area (`int`, *optional*, defaults to 704*1280):
434
+ Maximum pixel area for latent space calculation. Controls video resolution scaling
435
+ frame_num (`int`, *optional*, defaults to 121):
436
+ Number of frames to generate. Must be of the form 4n+1.
437
+ shift (`float`, *optional*, defaults to 5.0):
438
+ Noise schedule shift parameter. Affects temporal dynamics
439
+ [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
440
+ sample_solver (`str`, *optional*, defaults to 'unipc'):
441
+ Solver used to sample the video.
442
+ sampling_steps (`int`, *optional*, defaults to 40):
443
+ Number of diffusion sampling steps. Higher values improve quality but slow generation
444
+ guide_scale (`float`, *optional*, defaults to 5.0):
445
+ Classifier-free guidance scale. Controls prompt adherence vs. creativity.
446
+ n_prompt (`str`, *optional*, defaults to ""):
447
+ Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
448
+ seed (`int`, *optional*, defaults to -1):
449
+ Random seed for noise generation. If -1, use random seed
450
+ offload_model (`bool`, *optional*, defaults to True):
451
+ If True, offloads models to CPU during generation to save VRAM
452
+
453
+ Returns:
454
+ torch.Tensor:
455
+ Generated video frames tensor. Dimensions: (C, N, H, W) where:
456
+ - C: Color channels (3 for RGB)
457
+ - N: Number of frames (121)
458
+ - H: Frame height (from max_area)
459
+ - W: Frame width (from max_area)
460
+ """
461
+ # preprocess
462
+ ih, iw = img.height, img.width
463
+ dh, dw = self.patch_size[1] * self.vae_stride[1], self.patch_size[
464
+ 2] * self.vae_stride[2]
465
+ ow, oh = best_output_size(iw, ih, dw, dh, max_area)
466
+
467
+ scale = max(ow / iw, oh / ih)
468
+ img = img.resize((round(iw * scale), round(ih * scale)), Image.LANCZOS)
469
+
470
+ # center-crop
471
+ x1 = (img.width - ow) // 2
472
+ y1 = (img.height - oh) // 2
473
+ img = img.crop((x1, y1, x1 + ow, y1 + oh))
474
+ assert img.width == ow and img.height == oh
475
+
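A worked example of the resize-and-crop above, with (ow, oh) assumed to be (1280, 704) since best_output_size lives in wan/utils/utils.py and is not shown in this excerpt:

# Hypothetical numbers: 1920x1080 input, assumed best_output_size result (1280, 704).
iw, ih, ow, oh = 1920, 1080, 1280, 704
scale = max(ow / iw, oh / ih)                   # 2/3: scale so both axes cover the target
rw, rh = round(iw * scale), round(ih * scale)   # 1280 x 720 after the LANCZOS resize
x1, y1 = (rw - ow) // 2, (rh - oh) // 2         # (0, 8): center-crop offsets
# crop box (0, 8, 1280, 712) -> exactly 1280 x 704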
476
+ # to tensor
477
+ img = TF.to_tensor(img).sub_(0.5).div_(0.5).to(self.device).unsqueeze(1)
478
+
479
+ F = frame_num
480
+ seq_len = ((F - 1) // self.vae_stride[0] + 1) * (
481
+ oh // self.vae_stride[1]) * (ow // self.vae_stride[2]) // (
482
+ self.patch_size[1] * self.patch_size[2])
483
+ seq_len = int(math.ceil(seq_len / self.sp_size)) * self.sp_size
484
+
485
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
486
+ seed_g = torch.Generator(device=self.device)
487
+ seed_g.manual_seed(seed)
488
+ noise = torch.randn(
489
+ self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1,
490
+ oh // self.vae_stride[1],
491
+ ow // self.vae_stride[2],
492
+ dtype=torch.float32,
493
+ generator=seed_g,
494
+ device=self.device)
495
+
496
+ if n_prompt == "":
497
+ n_prompt = self.sample_neg_prompt
498
+
499
+ # preprocess
500
+ if not self.t5_cpu:
501
+ self.text_encoder.model.to(self.device)
502
+ context = self.text_encoder([input_prompt], self.device)
503
+ context_null = self.text_encoder([n_prompt], self.device)
504
+ if offload_model:
505
+ self.text_encoder.model.cpu()
506
+ else:
507
+ context = self.text_encoder([input_prompt], torch.device('cpu'))
508
+ context_null = self.text_encoder([n_prompt], torch.device('cpu'))
509
+ context = [t.to(self.device) for t in context]
510
+ context_null = [t.to(self.device) for t in context_null]
511
+
512
+ z = self.vae.encode([img])
513
+
514
+ @contextmanager
515
+ def noop_no_sync():
516
+ yield
517
+
518
+ no_sync = getattr(self.model, 'no_sync', noop_no_sync)
519
+
520
+ # evaluation mode
521
+ with (
522
+ torch.amp.autocast('cuda', dtype=self.param_dtype),
523
+ torch.no_grad(),
524
+ no_sync(),
525
+ ):
526
+
527
+ if sample_solver == 'unipc':
528
+ sample_scheduler = FlowUniPCMultistepScheduler(
529
+ num_train_timesteps=self.num_train_timesteps,
530
+ shift=1,
531
+ use_dynamic_shifting=False)
532
+ sample_scheduler.set_timesteps(
533
+ sampling_steps, device=self.device, shift=shift)
534
+ timesteps = sample_scheduler.timesteps
535
+ elif sample_solver == 'dpm++':
536
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
537
+ num_train_timesteps=self.num_train_timesteps,
538
+ shift=1,
539
+ use_dynamic_shifting=False)
540
+ sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
541
+ timesteps, _ = retrieve_timesteps(
542
+ sample_scheduler,
543
+ device=self.device,
544
+ sigmas=sampling_sigmas)
545
+ else:
546
+ raise NotImplementedError("Unsupported solver.")
547
+
548
+ # sample videos
549
+ latent = noise
550
+ mask1, mask2 = masks_like([noise], zero=True)
551
+ latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
552
+
553
+ arg_c = {
554
+ 'context': [context[0]],
555
+ 'seq_len': seq_len,
556
+ }
557
+
558
+ arg_null = {
559
+ 'context': context_null,
560
+ 'seq_len': seq_len,
561
+ }
562
+
563
+ if offload_model or self.init_on_cpu:
564
+ self.model.to(self.device)
565
+ torch.cuda.empty_cache()
566
+
567
+ for _, t in enumerate(tqdm(timesteps)):
568
+ latent_model_input = [latent.to(self.device)]
569
+ timestep = [t]
570
+
571
+ timestep = torch.stack(timestep).to(self.device)
572
+
573
+ temp_ts = (mask2[0][0][:, ::2, ::2] * timestep).flatten()
574
+ temp_ts = torch.cat([
575
+ temp_ts,
576
+ temp_ts.new_ones(seq_len - temp_ts.size(0)) * timestep
577
+ ])
578
+ timestep = temp_ts.unsqueeze(0)
579
+
580
+ noise_pred_cond = self.model(
581
+ latent_model_input, t=timestep, **arg_c)[0]
582
+ if offload_model:
583
+ torch.cuda.empty_cache()
584
+ noise_pred_uncond = self.model(
585
+ latent_model_input, t=timestep, **arg_null)[0]
586
+ if offload_model:
587
+ torch.cuda.empty_cache()
588
+ noise_pred = noise_pred_uncond + guide_scale * (
589
+ noise_pred_cond - noise_pred_uncond)
590
+
591
+ temp_x0 = sample_scheduler.step(
592
+ noise_pred.unsqueeze(0),
593
+ t,
594
+ latent.unsqueeze(0),
595
+ return_dict=False,
596
+ generator=seed_g)[0]
597
+ latent = temp_x0.squeeze(0)
598
+ latent = (1. - mask2[0]) * z[0] + mask2[0] * latent
599
+
600
+ x0 = [latent]
601
+ del latent_model_input, timestep
602
+
603
+ if offload_model:
604
+ self.model.cpu()
605
+ torch.cuda.synchronize()
606
+ torch.cuda.empty_cache()
607
+
608
+ if self.rank == 0:
609
+ videos = self.vae.decode(x0)
610
+
611
+ del noise, latent, x0
612
+ del sample_scheduler
613
+ if offload_model:
614
+ gc.collect()
615
+ torch.cuda.synchronize()
616
+ if dist.is_initialized():
617
+ dist.barrier()
618
+
619
+ return videos[0] if self.rank == 0 else None
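End of the TI2V pipeline. For completeness, a hedged sketch of turning the returned (C, N, H, W) tensor, roughly in [-1, 1] after VAE decoding, into displayable uint8 frames; the writer in the last line is an external assumption (imageio), not a helper defined by this commit:

import torch

video = pipe.generate('a corgi surfing a wave at sunset', frame_num=121, seed=42)
if video is not None:                        # non-zero ranks return None
    frames = (video.clamp(-1, 1) + 1) / 2    # map to [0, 1]
    frames = (frames * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy()  # (N, H, W, C)
    # e.g. imageio.mimwrite('out.mp4', list(frames), fps=24)   # fps is an assumption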
wan/utils/__init__.py ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ from .fm_solvers import (
3
+ FlowDPMSolverMultistepScheduler,
4
+ get_sampling_sigmas,
5
+ retrieve_timesteps,
6
+ )
7
+ from .fm_solvers_unipc import FlowUniPCMultistepScheduler
8
+
9
+ __all__ = [
10
+ 'get_sampling_sigmas', 'retrieve_timesteps',
11
+ 'FlowDPMSolverMultistepScheduler', 'FlowUniPCMultistepScheduler'
12
+ ]
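These exports are exactly the solver utilities the TI2V pipeline above consumes. A small sketch mirroring its two scheduler branches (num_train_timesteps=1000 and shift=5.0 are the values the pipeline appears to pass from its config, assumed here):

import torch
from wan.utils import (FlowDPMSolverMultistepScheduler, FlowUniPCMultistepScheduler,
                       get_sampling_sigmas, retrieve_timesteps)

device, steps, shift = torch.device('cuda:0'), 50, 5.0

# UniPC: the shift is applied when the timestep grid is built.
unipc = FlowUniPCMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
unipc.set_timesteps(steps, device=device, shift=shift)

# DPM++: the shift is baked into the sigma schedule instead.
dpm = FlowDPMSolverMultistepScheduler(num_train_timesteps=1000, shift=1, use_dynamic_shifting=False)
sampling_sigmas = get_sampling_sigmas(steps, shift)
timesteps, _ = retrieve_timesteps(dpm, device=device, sigmas=sampling_sigmas)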
wan/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (401 Bytes). View file
 
wan/utils/__pycache__/fm_solvers.cpython-310.pyc ADDED
Binary file (26.1 kB). View file