oKen38461 committed on
Commit e9de5f5 · 1 Parent(s): dcaedff

Import WanGP source code (without large assets)

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +8 -35
  2. .gitignore +42 -0
  3. LICENSE.txt +17 -0
  4. README.md +128 -8
  5. docs/CHANGELOG.md +153 -0
  6. docs/CLI.md +226 -0
  7. docs/GETTING_STARTED.md +194 -0
  8. docs/INSTALLATION.md +170 -0
  9. docs/LORAS.md +224 -0
  10. docs/MODELS.md +268 -0
  11. docs/TROUBLESHOOTING.md +338 -0
  12. docs/VACE.md +190 -0
  13. fantasytalking/infer.py +36 -0
  14. fantasytalking/model.py +162 -0
  15. fantasytalking/utils.py +52 -0
  16. hyvideo/__init__.py +0 -0
  17. hyvideo/config.py +534 -0
  18. hyvideo/constants.py +164 -0
  19. hyvideo/data_kits/audio_dataset.py +170 -0
  20. hyvideo/data_kits/audio_preprocessor.py +72 -0
  21. hyvideo/data_kits/data_tools.py +41 -0
  22. hyvideo/data_kits/face_align/__init__.py +1 -0
  23. hyvideo/data_kits/face_align/align.py +34 -0
  24. hyvideo/data_kits/face_align/detface.py +283 -0
  25. hyvideo/diffusion/__init__.py +2 -0
  26. hyvideo/diffusion/pipelines/__init__.py +2 -0
  27. hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py +1421 -0
  28. hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py +1359 -0
  29. hyvideo/diffusion/schedulers/__init__.py +1 -0
  30. hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py +255 -0
  31. hyvideo/hunyuan.py +1000 -0
  32. hyvideo/modules/__init__.py +26 -0
  33. hyvideo/modules/activation_layers.py +23 -0
  34. hyvideo/modules/attenion.py +362 -0
  35. hyvideo/modules/audio_adapters.py +220 -0
  36. hyvideo/modules/embed_layers.py +158 -0
  37. hyvideo/modules/mlp_layers.py +131 -0
  38. hyvideo/modules/models.py +1109 -0
  39. hyvideo/modules/modulate_layers.py +136 -0
  40. hyvideo/modules/norm_layers.py +88 -0
  41. hyvideo/modules/original models.py +760 -0
  42. hyvideo/modules/placement.py +389 -0
  43. hyvideo/modules/posemb_layers.py +475 -0
  44. hyvideo/modules/token_refiner.py +237 -0
  45. hyvideo/modules/utils.py +43 -0
  46. hyvideo/prompt_rewrite.py +51 -0
  47. hyvideo/text_encoder/__init__.py +552 -0
  48. hyvideo/utils/__init__.py +0 -0
  49. hyvideo/utils/data_utils.py +90 -0
  50. hyvideo/utils/file_utils.py +70 -0
.gitattributes CHANGED
@@ -1,35 +1,8 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
2
+ *.avi filter=lfs diff=lfs merge=lfs -text
3
+ *.mov filter=lfs diff=lfs merge=lfs -text
4
+ *.gif filter=lfs diff=lfs merge=lfs -text
5
+ *.png filter=lfs diff=lfs merge=lfs -text
6
+ *.jpg filter=lfs diff=lfs merge=lfs -text
7
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
8
+ *.JPG filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,42 @@
1
+ .*
2
+ *.py[cod]
3
+ # *.jpg
4
+ *.jpeg
5
+ # *.png
6
+ *.gif
7
+ *.bmp
8
+ # *.mp4
9
+ *.mov
10
+ *.mkv
11
+ *.log
12
+ *.zip
13
+ *.pt
14
+ *.pth
15
+ *.ckpt
16
+ *.safetensors
17
+ *.json
18
+ # *.txt
19
+ *.backup
20
+ *.pkl
21
+ *.html
22
+ *.pdf
23
+ *.whl
24
+ *.exe
25
+ cache
26
+ __pycache__/
27
+ storage/
28
+ samples/
29
+ !.gitignore
30
+ !requirements.txt
31
+ .DS_Store
32
+ *DS_Store
33
+ google/
34
+ Wan2.1-T2V-14B/
35
+ Wan2.1-T2V-1.3B/
36
+ Wan2.1-I2V-14B-480P/
37
+ Wan2.1-I2V-14B-720P/
38
+ outputs/
39
+ gradio_outputs/
40
+ ckpts/
41
+ loras/
42
+ loras_i2v/
LICENSE.txt ADDED
@@ -0,0 +1,17 @@
1
+ FREE for Non Commercial USE
2
+
3
+ You are free to:
4
+ - Share — copy and redistribute the material in any medium or format
5
+ - Adapt — remix, transform, and build upon the material
6
+ The licensor cannot revoke these freedoms as long as you follow the license terms.
7
+
8
+ Under the following terms:
9
+ - Attribution — You must give appropriate credit, provide a link to the license, and indicate if changes were made. You may do so in any reasonable manner, but not in any way that suggests the licensor endorses you or your use.
10
+ - NonCommercial — You may not use the material for commercial purposes.
11
+
12
+ - No additional restrictions — You may not apply legal terms or technological measures that legally restrict others from doing anything the license permits.
13
+ Notices:
14
+
15
+ - You do not have to comply with the license for elements of the material in the public domain or where your use is permitted by an applicable exception or limitation.
16
+
17
+ No warranties are given. The license may not give you all of the permissions necessary for your intended use. For example, other rights such as publicity, privacy, or moral rights may limit how you use the material.
README.md CHANGED
@@ -1,12 +1,132 @@
1
  ---
2
- title: WanGP HunyuanVideoAvatar
3
- emoji: 🏢
4
- colorFrom: green
5
- colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 5.34.0
8
- app_file: app.py
9
- pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: WanGP
3
+ emoji:
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 5.33.0
8
+ app_file: wgp.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ short_description: A great AI video app created by DeepBeepMeep
12
  ---
13
 
14
+ # WanGP
15
+
16
+ Note from @jbilcke-hf: this repository is a fork of the original WanGP repo by DeepBeepMeep.
17
+ It has been modified to run inside a Hugging Face space.
18
+
19
+ -----
20
+ <p align="center">
21
+ <b>WanGP by DeepBeepMeep: The best Open Source Video Generative Models Accessible to the GPU Poor</b>
22
+ </p>
23
+
24
+ WanGP supports the Wan (and derived models), Hunyuan Video and LTX Video models with:
25
+ - Low VRAM requirements (as low as 6 GB of VRAM is sufficient for certain models)
26
+ - Support for old GPUs (RTX 10XX, 20xx, ...)
27
+ - Very Fast on the latest GPUs
28
+ - Easy to use, full web-based interface
29
+ - Auto download of the required model adapted to your specific architecture
30
+ - Integrated tools to facilitate Video Generation: Mask Editor, Prompt Enhancer, Temporal and Spatial Generation
31
+ - Lora support to customize each model
32
+ - Queuing system: make your shopping list of videos to generate and come back later
33
+
34
+ **Discord Server to get Help from Other Users and show your Best Videos:** https://discord.gg/g7efUW9jGV
35
+
36
+ **Follow DeepBeepMeep on Twitter/X to get the Latest News**: https://x.com/deepbeepmeep
37
+
38
+ ## 🔥 Latest Updates
39
+ ### May 28 2025: WanGP v5.41
40
+ 👋 Bonus release: Support for the **AccVideo** Lora to speed up video generation by 2x in Wan models. Check the Loras documentation for the AccVideo usage instructions.\
41
+ You will need to do a *pip install -r requirements.txt*
42
+
43
+ ### May 28 2025: WanGP v5.4
44
+ 👋 World Exclusive: **Hunyuan Video Avatar** support! You won't need 80 GB or even 32 GB of VRAM; just 10 GB of VRAM is sufficient to generate up to 15s of high quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.\
45
+ Here is a link to the original repo, where you will find some very interesting documentation and examples: https://github.com/Tencent-Hunyuan/HunyuanVideo-Avatar. Kudos to the Hunyuan Video Avatar team for the best model of its kind.\
46
+ Also many thanks to Reevoy24 for his repackaging and for completing the documentation
47
+
48
+ ### May 28 2025: WanGP v5.31
49
+ 👋 Added **Phantom 14B**, a model that you can use to transfer objects / people into a video. My preference goes to Vace, which remains the king of controlnets.
50
+ VACE improvements: Better sliding window transitions, image mask support in Matanyone, new Extend Video feature, and enhanced background removal options.
51
+
52
+ ### May 26, 2025: WanGP v5.3
53
+ 👋 Settings management revolution! Now you can:
54
+ - Select any generated video and click *Use Selected Video Settings* to instantly reuse its configuration
55
+ - Drag & drop videos to automatically extract their settings metadata
56
+ - Export/import settings as JSON files for easy sharing and backup
57
+
58
+ ### May 20, 2025: WanGP v5.2
59
+ 👋 **CausVid support** - Generate videos in just 4-12 steps with the new distilled Wan model! Also added experimental MoviiGen for 1080p generation (20GB+ VRAM required). Check the Loras documentation to get the usage instructions of CausVid.
60
+
61
+ ### May 18, 2025: WanGP v5.1
62
+ 👋 **LTX Video 13B Distilled** - Generate high-quality videos in less than one minute!
63
+
64
+ ### May 17, 2025: WanGP v5.0
65
+ 👋 **One App to Rule Them All!** Added Hunyuan Video and LTX Video support, plus Vace 14B and integrated prompt enhancer.
66
+
67
+ See full changelog: **[Changelog](docs/CHANGELOG.md)**
68
+
69
+ ## 📋 Table of Contents
70
+
71
+ - [🚀 Quick Start](#-quick-start)
72
+ - [📦 Installation](#-installation)
73
+ - [🎯 Usage](#-usage)
74
+ - [📚 Documentation](#-documentation)
75
+ - [🔗 Related Projects](#-related-projects)
76
+
77
+ ## 🚀 Quick Start
78
+
79
+ **One-click installation:** Get started instantly with [Pinokio App](https://pinokio.computer/)
80
+
81
+ **Manual installation:**
82
+ ```bash
83
+ git clone https://github.com/deepbeepmeep/Wan2GP.git
84
+ cd Wan2GP
85
+ conda create -n wan2gp python=3.10.9
86
+ conda activate wan2gp
87
+ pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
88
+ pip install -r requirements.txt
89
+ ```
90
+
91
+ **Run the application:**
92
+ ```bash
93
+ python wgp.py # Text-to-video (default)
94
+ python wgp.py --i2v # Image-to-video
95
+ ```
96
+
97
+ ## 📦 Installation
98
+
99
+ For detailed installation instructions for different GPU generations:
100
+ - **[Installation Guide](docs/INSTALLATION.md)** - Complete setup instructions for RTX 10XX to RTX 50XX
101
+
102
+ ## 🎯 Usage
103
+
104
+ ### Basic Usage
105
+ - **[Getting Started Guide](docs/GETTING_STARTED.md)** - First steps and basic usage
106
+ - **[Models Overview](docs/MODELS.md)** - Available models and their capabilities
107
+
108
+ ### Advanced Features
109
+ - **[Loras Guide](docs/LORAS.md)** - Using and managing Loras for customization
110
+ - **[VACE ControlNet](docs/VACE.md)** - Advanced video control and manipulation
111
+ - **[Command Line Reference](docs/CLI.md)** - All available command line options
112
+
113
+ ## 📚 Documentation
114
+
115
+ - **[Changelog](docs/CHANGELOG.md)** - Latest updates and version history
116
+ - **[Troubleshooting](docs/TROUBLESHOOTING.md)** - Common issues and solutions
117
+
118
+ ## 🔗 Related Projects
119
+
120
+ ### Other Models for the GPU Poor
121
+ - **[HunyuanVideoGP](https://github.com/deepbeepmeep/HunyuanVideoGP)** - One of the best open source Text to Video generators
122
+ - **[Hunyuan3D-2GP](https://github.com/deepbeepmeep/Hunyuan3D-2GP)** - Image to 3D and text to 3D tool
123
+ - **[FluxFillGP](https://github.com/deepbeepmeep/FluxFillGP)** - Inpainting/outpainting tools based on Flux
124
+ - **[Cosmos1GP](https://github.com/deepbeepmeep/Cosmos1GP)** - Text to world generator and image/video to world
125
+ - **[OminiControlGP](https://github.com/deepbeepmeep/OminiControlGP)** - Flux-derived application for object transfer
126
+ - **[YuE GP](https://github.com/deepbeepmeep/YuEGP)** - Song generator with instruments and singer's voice
127
+
128
+ ---
129
+
130
+ <p align="center">
131
+ Made with ❤️ by DeepBeepMeep
132
+ </p>
docs/CHANGELOG.md ADDED
@@ -0,0 +1,153 @@
1
+ # Changelog
2
+
3
+ ## 🔥 Latest News
4
+ ### May 28 2025: WanGP v5.41
5
+ 👋 Bonus release: Support for **AccVideo** Lora to speed up x2 Video generations in Wan models. Check the Loras documentation to get the usage instructions of AccVideo.
6
+
7
+ ### May 28 2025: WanGP v5.4
8
+ 👋 World Exclusive: Hunyuan Video Avatar support! You won't need 80 GB or even 32 GB of VRAM; just 10 GB of VRAM is sufficient to generate up to 15s of high quality speech / song driven video at high speed with no quality degradation. Support for TeaCache included.
9
+
10
+ ### May 26, 2025: WanGP v5.3
11
+ 👋 Happy with a video generation and want to do more with the same settings, but you can't remember what you did, or find it too tedious to copy/paste each setting one by one from the file metadata? Rejoice! There are now multiple ways to turn this tedious process into a one-click task:
12
+ - Select a video recently generated in the Video Gallery and click *Use Selected Video Settings*
13
+ - Click *Drop File Here* and select a video you saved somewhere; if the settings metadata were saved with the video, they will be extracted automatically
14
+ - Click *Export Settings to File* to save the current settings to your hard drive. You can reuse them later by clicking *Drop File Here* and selecting a settings JSON file
15
+
16
+ ### May 23, 2025: WanGP v5.21
17
+ 👋 Improvements for Vace: better transitions between Sliding Windows, Support for Image masks in Matanyone, new Extend Video for Vace, different types of automated background removal
18
+
19
+ ### May 20, 2025: WanGP v5.2
20
+ 👋 Added support for Wan CausVid which is a distilled Wan model that can generate nice looking videos in only 4 to 12 steps. The great thing is that Kijai (Kudos to him!) has created a CausVid Lora that can be combined with any existing Wan t2v model 14B like Wan Vace 14B. See [LORAS.md](LORAS.md) for instructions on how to use CausVid.
21
+
22
+ Also, as an experiment, I have added support for MoviiGen, the first model that claims to be capable of generating 1080p videos (if you have enough VRAM (20GB...) and are ready to wait for a long time...). Don't hesitate to share your impressions on the Discord server.
23
+
24
+ ### May 18, 2025: WanGP v5.1
25
+ 👋 Bonus Day, added LTX Video 13B Distilled: generate in less than one minute, very high quality Videos!
26
+
27
+ ### May 17, 2025: WanGP v5.0
28
+ 👋 One App to Rule Them All! Added support for the other great open source architectures:
29
+ - **Hunyuan Video**: text 2 video (one of the best, if not the best t2v), image 2 video and the recently released Hunyuan Custom (very good identity preservation when injecting a person into a video)
30
+ - **LTX Video 13B** (released last week): very long video support and fast 720p generation. The WanGP version has been greatly optimized and reduces LTX Video VRAM requirements by a factor of 4!
31
+
32
+ Also:
33
+ - Added support for the best Control Video Model, released 2 days ago: Vace 14B
34
+ - New Integrated prompt enhancer to increase the quality of the generated videos
35
+
36
+ *You will need one more `pip install -r requirements.txt`*
37
+
38
+ ### May 5, 2025: WanGP v4.5
39
+ 👋 FantasySpeaking model, you can animate a talking head using a voice track. This works not only on people but also on objects. Also better seamless transitions between Vace sliding windows for very long videos. New high quality processing features (mixed 16/32 bits calculation and 32 bits VAE)
40
+
41
+ ### April 27, 2025: WanGP v4.4
42
+ 👋 Phantom model support, very good model to transfer people or objects into video, works quite well at 720p and with the number of steps > 30
43
+
44
+ ### April 25, 2025: WanGP v4.3
45
+ 👋 Added preview mode and support for Sky Reels v2 Diffusion Forcing for high quality "infinite length videos". Note that Skyreel uses causal attention that is only supported by Sdpa attention so even if you choose another type of attention, some of the processes will use Sdpa attention.
46
+
47
+ ### April 18, 2025: WanGP v4.2
48
+ 👋 FLF2V model support, official support from Wan for image2video start and end frames specialized for 720p.
49
+
50
+ ### April 17, 2025: WanGP v4.1
51
+ 👋 Recam Master model support, view a video from a different angle. The video to process must be at least 81 frames long and you should set at least 15 steps denoising to get good results.
52
+
53
+ ### April 13, 2025: WanGP v4.0
54
+ 👋 Lots of goodies for you!
55
+ - A new UI, tabs were replaced by a Dropdown box to easily switch models
56
+ - A new queuing system that lets you stack in a queue as many text2video, image2video tasks, ... as you want. Each task can rely on complete different generation parameters (different number of frames, steps, loras, ...). Many thanks to **Tophness** for being a big contributor on this new feature
57
+ - Temporal upsampling (Rife) and spatial upsampling (Lanczos) for a smoother video (32 fps or 64 fps) and to enlarge your video by x2 or x4. Check these new advanced options.
58
+ - Wan Vace Control Net support: with Vace you can inject in the scene people or objects, animate a person, perform inpainting or outpainting, continue a video, ... See [VACE.md](VACE.md) for introduction guide.
59
+ - Integrated *Matanyone* tool directly inside WanGP so that you can create easily inpainting masks used in Vace
60
+ - Sliding Window generation for Vace, create windows that can last dozens of seconds
61
+ - New optimizations for old generation GPUs: Generate 5s (81 frames, 15 steps) of Vace 1.3B with only 5GB and in only 6 minutes on a RTX 2080Ti and 5s of t2v 14B in less than 10 minutes.
62
+
63
+ ### March 27, 2025
64
+ 👋 Added support for the new Wan Fun InP models (image2video). The 14B Fun InP has probably better end image support, but unfortunately existing loras do not work so well with it. The great novelty is the Fun InP image2video 1.3B model: Image-to-Video is now accessible to even lower hardware configurations. It is not as good as the 14B models but very impressive for its size. Many thanks to the VideoX-Fun team (https://github.com/aigc-apps/VideoX-Fun)
65
+
66
+ ### March 26, 2025
67
+ 👋 Good news! Official support for RTX 50xx please check the [installation instructions](INSTALLATION.md).
68
+
69
+ ### March 24, 2025: Wan2.1GP v3.2
70
+ 👋
71
+ - Added Classifier-Free Guidance Zero Star. The video should match better the text prompt (especially with text2video) at no performance cost: many thanks to the **CFG Zero * Team**. Don't hesitate to give them a star if you appreciate the results: https://github.com/WeichenFan/CFG-Zero-star
72
+ - Added back support for PyTorch compilation with Loras. It seems it had been broken for some time
73
+ - Added possibility to keep a number of pregenerated videos in the Video Gallery (useful to compare outputs of different settings)
74
+
75
+ *You will need one more `pip install -r requirements.txt`*
76
+
77
+ ### March 19, 2025: Wan2.1GP v3.1
78
+ 👋 Faster launch and RAM optimizations (should require less RAM to run)
79
+
80
+ *You will need one more `pip install -r requirements.txt`*
81
+
82
+ ### March 18, 2025: Wan2.1GP v3.0
83
+ 👋
84
+ - New tab-based interface; you can switch between i2v and t2v without restarting the app
85
+ - Experimental Dual Frames mode for i2v, you can also specify an End frame. It doesn't always work, so you will need a few attempts.
86
+ - You can save default settings in the files *i2v_settings.json* and *t2v_settings.json* that will be used when launching the app (you can also specify the path to different settings files)
87
+ - Slight acceleration with loras
88
+
89
+ *You will need one more `pip install -r requirements.txt`*
90
+
91
+ Many thanks to *Tophness* who created the framework (and did a big part of the work) of the multitabs and saved settings features
92
+
93
+ ### March 18, 2025: Wan2.1GP v2.11
94
+ 👋 Added more command line parameters to prefill the generation settings + customizable output directory and choice of type of metadata for generated videos. Many thanks to *Tophness* for his contributions.
95
+
96
+ *You will need one more `pip install -r requirements.txt` to reflect new dependencies*
97
+
98
+ ### March 18, 2025: Wan2.1GP v2.1
99
+ 👋 More Loras!: added support for 'Safetensors' and 'Replicate' Lora formats.
100
+
101
+ *You will need to refresh the requirements with a `pip install -r requirements.txt`*
102
+
103
+ ### March 17, 2025: Wan2.1GP v2.0
104
+ 👋 The Lora festival continues:
105
+ - Clearer user interface
106
+ - Download 30 Loras in one click to try them all (expand the info section)
107
+ - Very easy to use Loras as now Lora presets can input the subject (or other needed terms) of the Lora so that you don't have to modify manually a prompt
108
+ - Added basic macro prompt language to prefill prompts with different values. With one prompt template, you can generate multiple prompts.
109
+ - New Multiple images prompts: you can now combine any number of images with any number of text prompts (need to launch the app with --multiple-images)
110
+ - New command lines options to launch directly the 1.3B t2v model or the 14B t2v model
111
+
112
+ ### March 14, 2025: Wan2.1GP v1.7
113
+ 👋
114
+ - Lora Fest special edition: very fast loading/unload of loras for those Loras collectors around. You can also now add/remove loras in the Lora folder without restarting the app.
115
+ - Added experimental Skip Layer Guidance (advanced settings), that should improve the image quality at no extra cost. Many thanks to the *AmericanPresidentJimmyCarter* for the original implementation
116
+
117
+ *You will need to refresh the requirements `pip install -r requirements.txt`*
118
+
119
+ ### March 13, 2025: Wan2.1GP v1.6
120
+ 👋 Better Loras support, accelerated loading Loras.
121
+
122
+ *You will need to refresh the requirements `pip install -r requirements.txt`*
123
+
124
+ ### March 10, 2025: Wan2.1GP v1.5
125
+ 👋 Official Teacache support + Smart Teacache (find automatically best parameters for a requested speed multiplier), 10% speed boost with no quality loss, improved lora presets (they can now include prompts and comments to guide the user)
126
+
127
+ ### March 7, 2025: Wan2.1GP v1.4
128
+ 👋 Fix PyTorch compilation, now it is really 20% faster when activated
129
+
130
+ ### March 4, 2025: Wan2.1GP v1.3
131
+ 👋 Support for Image to Video with multiples images for different images/prompts combinations (requires *--multiple-images* switch), and added command line *--preload x* to preload in VRAM x MB of the main diffusion model if you find there is too much unused VRAM and you want to (slightly) accelerate the generation process.
132
+
133
+ *If you upgrade you will need to do a `pip install -r requirements.txt` again.*
134
+
135
+ ### March 4, 2025: Wan2.1GP v1.2
136
+ 👋 Implemented tiling on VAE encoding and decoding. No more VRAM peaks at the beginning and at the end
137
+
138
+ ### March 3, 2025: Wan2.1GP v1.1
139
+ 👋 Added Tea Cache support for faster generations: optimization of kijai's implementation (https://github.com/kijai/ComfyUI-WanVideoWrapper/) of teacache (https://github.com/ali-vilab/TeaCache)
140
+
141
+ ### March 2, 2025: Wan2.1GP by DeepBeepMeep v1
142
+ 👋 Brings:
143
+ - Support for all Wan including the Image to Video model
144
+ - Reduced memory consumption by a factor of 2, with the possibility to generate more than 10s of video at 720p with an RTX 4090 and 10s of video at 480p with less than 12GB of VRAM. Many thanks to RIFLEx (https://github.com/thu-ml/RIFLEx) for their algorithm that allows generating nice-looking videos longer than 5s.
145
+ - The usual perks: web interface, multiple generations, loras support, sage attention, auto download of models, ...
146
+
147
+ ## Original Wan Releases
148
+
149
+ ### February 25, 2025
150
+ 👋 We've released the inference code and weights of Wan2.1.
151
+
152
+ ### February 27, 2025
153
+ 👋 Wan2.1 has been integrated into [ComfyUI](https://comfyanonymous.github.io/ComfyUI_examples/wan/). Enjoy!
docs/CLI.md ADDED
@@ -0,0 +1,226 @@
1
+ # Command Line Reference
2
+
3
+ This document covers all available command line options for WanGP.
4
+
5
+ ## Basic Usage
6
+
7
+ ```bash
8
+ # Default launch
9
+ python wgp.py
10
+
11
+ # Specific model modes
12
+ python wgp.py --i2v # Image-to-video
13
+ python wgp.py --t2v # Text-to-video (default)
14
+ python wgp.py --t2v-14B # 14B text-to-video model
15
+ python wgp.py --t2v-1-3B # 1.3B text-to-video model
16
+ python wgp.py --i2v-14B # 14B image-to-video model
17
+ python wgp.py --i2v-1-3B # Fun InP 1.3B image-to-video model
18
+ python wgp.py --vace-1-3B # VACE ControlNet 1.3B model
19
+ ```
20
+
21
+ ## Model and Performance Options
22
+
23
+ ### Model Configuration
24
+ ```bash
25
+ --quantize-transformer BOOL # Enable/disable transformer quantization (default: True)
26
+ --compile # Enable PyTorch compilation (requires Triton)
27
+ --attention MODE # Force attention mode: sdpa, flash, sage, sage2
28
+ --profile NUMBER # Performance profile 1-5 (default: 4)
29
+ --preload NUMBER # Preload N MB of diffusion model in VRAM
30
+ --fp16 # Force fp16 instead of bf16 models
31
+ --gpu DEVICE # Run on specific GPU device (e.g., "cuda:1")
32
+ ```
33
+
34
+ ### Performance Profiles
35
+ - **Profile 1**: Load entire current model in VRAM and keep all unused models in reserved RAM for fast VRAM transfers
36
+ - **Profile 2**: Load model parts as needed, keep all unused models in reserved RAM for fast VRAM transfers
37
+ - **Profile 3**: Load entire current model in VRAM (requires 24GB for 14B model)
38
+ - **Profile 4**: Default and recommended, load model parts as needed, most flexible option
39
+ - **Profile 5**: Minimum RAM usage
40
+
41
+ ### Memory Management
42
+ ```bash
43
+ --perc-reserved-mem-max FLOAT # Max percentage of RAM for reserved memory (< 0.5)
44
+ ```
45
+
46
+ ## Lora Configuration
47
+
48
+ ```bash
49
+ --lora-dir PATH # Path to Wan t2v loras directory
50
+ --lora-dir-i2v PATH # Path to Wan i2v loras directory
51
+ --lora-dir-hunyuan PATH # Path to Hunyuan t2v loras directory
52
+ --lora-dir-hunyuan-i2v PATH # Path to Hunyuan i2v loras directory
53
+ --lora-dir-ltxv PATH # Path to LTX Video loras directory
54
+ --lora-preset PRESET # Load lora preset file (.lset) on startup
55
+ --check-loras # Filter incompatible loras (slower startup)
56
+ ```
57
+
58
+ ## Generation Settings
59
+
60
+ ### Basic Generation
61
+ ```bash
62
+ --seed NUMBER # Set default seed value
63
+ --frames NUMBER # Set default number of frames to generate
64
+ --steps NUMBER # Set default number of denoising steps
65
+ --advanced # Launch with advanced mode enabled
66
+ ```
67
+
68
+ ### Advanced Generation
69
+ ```bash
70
+ --teacache MULTIPLIER # TeaCache speed multiplier: 0, 1.5, 1.75, 2.0, 2.25, 2.5
71
+ ```
72
+
73
+ ## Interface and Server Options
74
+
75
+ ### Server Configuration
76
+ ```bash
77
+ --server-port PORT # Gradio server port (default: 7860)
78
+ --server-name NAME # Gradio server name (default: localhost)
79
+ --listen # Make server accessible on network
80
+ --share # Create shareable HuggingFace URL for remote access
81
+ --open-browser # Open browser automatically when launching
82
+ ```
83
+
84
+ ### Interface Options
85
+ ```bash
86
+ --lock-config # Prevent modifying video engine configuration from interface
87
+ --theme THEME_NAME # UI theme: "default" or "gradio"
88
+ ```
89
+
90
+ ## File and Directory Options
91
+
92
+ ```bash
93
+ --settings PATH # Path to folder containing default settings for all models
94
+ --verbose LEVEL # Information level 0-2 (default: 1)
95
+ ```
96
+
97
+ ## Examples
98
+
99
+ ### Basic Usage Examples
100
+ ```bash
101
+ # Launch with specific model and loras
102
+ python wgp.py --t2v-14B --lora-preset mystyle.lset
103
+
104
+ # High-performance setup with compilation
105
+ python wgp.py --compile --attention sage2 --profile 3
106
+
107
+ # Low VRAM setup
108
+ python wgp.py --t2v-1-3B --profile 4 --attention sdpa
109
+
110
+ # Multiple images with custom lora directory
111
+ python wgp.py --i2v --multiple-images --lora-dir /path/to/shared/loras
112
+ ```
113
+
114
+ ### Server Configuration Examples
115
+ ```bash
116
+ # Network accessible server
117
+ python wgp.py --listen --server-port 8080
118
+
119
+ # Shareable server with custom theme
120
+ python wgp.py --share --theme gradio --open-browser
121
+
122
+ # Locked configuration for public use
123
+ python wgp.py --lock-config --share
124
+ ```
125
+
126
+ ### Advanced Performance Examples
127
+ ```bash
128
+ # Maximum performance (requires high-end GPU)
129
+ python wgp.py --compile --attention sage2 --profile 3 --preload 2000
130
+
131
+ # Optimized for RTX 2080Ti
132
+ python wgp.py --profile 4 --attention sdpa --teacache 2.0
133
+
134
+ # Memory-efficient setup
135
+ python wgp.py --fp16 --profile 4 --perc-reserved-mem-max 0.3
136
+ ```
137
+
138
+ ### TeaCache Configuration
139
+ ```bash
140
+ # Different speed multipliers
141
+ python wgp.py --teacache 1.5 # 1.5x speed, minimal quality loss
142
+ python wgp.py --teacache 2.0 # 2x speed, some quality loss
143
+ python wgp.py --teacache 2.5 # 2.5x speed, noticeable quality loss
144
+ python wgp.py --teacache 0 # Disable TeaCache
145
+ ```
146
+
147
+ ## Attention Modes
148
+
149
+ ### SDPA (Default)
150
+ ```bash
151
+ python wgp.py --attention sdpa
152
+ ```
153
+ - Available by default with PyTorch
154
+ - Good compatibility with all GPUs
155
+ - Moderate performance
156
+
157
+ ### Sage Attention
158
+ ```bash
159
+ python wgp.py --attention sage
160
+ ```
161
+ - Requires Triton installation
162
+ - 30% faster than SDPA
163
+ - Small quality cost
164
+
165
+ ### Sage2 Attention
166
+ ```bash
167
+ python wgp.py --attention sage2
168
+ ```
169
+ - Requires Triton and SageAttention 2.x
170
+ - 40% faster than SDPA
171
+ - Best performance option
172
+
173
+ ### Flash Attention
174
+ ```bash
175
+ python wgp.py --attention flash
176
+ ```
177
+ - May require CUDA kernel compilation
178
+ - Good performance
179
+ - Can be complex to install on Windows
180
+
181
+ ## Troubleshooting Command Lines
182
+
183
+ ### Fallback to Basic Setup
184
+ ```bash
185
+ # If advanced features don't work
186
+ python wgp.py --attention sdpa --profile 4 --fp16
187
+ ```
188
+
189
+ ### Debug Mode
190
+ ```bash
191
+ # Maximum verbosity for troubleshooting
192
+ python wgp.py --verbose 2 --check-loras
193
+ ```
194
+
195
+ ### Memory Issue Debugging
196
+ ```bash
197
+ # Minimal memory usage
198
+ python wgp.py --profile 4 --attention sdpa --perc-reserved-mem-max 0.2
199
+ ```
200
+
201
+
202
+
203
+ ## Configuration Files
204
+
205
+ ### Settings Files
206
+ Load custom settings:
207
+ ```bash
208
+ python wgp.py --settings /path/to/settings/folder
209
+ ```
210
+
211
+ ### Lora Presets
212
+ Create and share lora configurations:
213
+ ```bash
214
+ # Load specific preset
215
+ python wgp.py --lora-preset anime_style.lset
216
+
217
+ # With custom lora directory
218
+ python wgp.py --lora-preset mystyle.lset --lora-dir /shared/loras
219
+ ```
220
+
221
+ ## Environment Variables
222
+
223
+ While not command line options, these environment variables can affect behavior:
224
+ - `CUDA_VISIBLE_DEVICES` - Limit visible GPUs
225
+ - `PYTORCH_CUDA_ALLOC_CONF` - CUDA memory allocation settings
226
+ - `TRITON_CACHE_DIR` - Triton cache directory (for Sage attention)
docs/GETTING_STARTED.md ADDED
@@ -0,0 +1,194 @@
1
+ # Getting Started with WanGP
2
+
3
+ This guide will help you get started with WanGP video generation quickly and easily.
4
+
5
+ ## Prerequisites
6
+
7
+ Before starting, ensure you have:
8
+ - A compatible GPU (RTX 10XX or newer recommended)
9
+ - Python 3.10.9 installed
10
+ - At least 6GB of VRAM for basic models
11
+ - Internet connection for model downloads
12
+
13
+ ## Quick Setup
14
+
15
+ ### Option 1: One-Click Installation (Recommended)
16
+ Use [Pinokio App](https://pinokio.computer/) for the easiest installation experience.
17
+
18
+ ### Option 2: Manual Installation
19
+ ```bash
20
+ git clone https://github.com/deepbeepmeep/Wan2GP.git
21
+ cd Wan2GP
22
+ conda create -n wan2gp python=3.10.9
23
+ conda activate wan2gp
24
+ pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
25
+ pip install -r requirements.txt
26
+ ```
27
+
28
+ For detailed installation instructions, see [INSTALLATION.md](INSTALLATION.md).
29
+
30
+ ## First Launch
31
+
32
+ ### Basic Launch
33
+ ```bash
34
+ python wgp.py
35
+ ```
36
+ This launches the WanGP generator with default settings. You will be able to pick the model you want to use from a drop-down menu.
37
+
38
+ ### Alternative Modes
39
+ ```bash
40
+ python wgp.py --i2v # Wan Image-to-video mode
41
+ python wgp.py --t2v-1-3B # Wan Smaller, faster model
42
+ ```
43
+
44
+ ## Understanding the Interface
45
+
46
+ When you launch WanGP, you'll see a web interface with several sections:
47
+
48
+ ### Main Generation Panel
49
+ - **Model Selection**: Dropdown to choose between different models
50
+ - **Prompt**: Text description of what you want to generate
51
+ - **Generate Button**: Start the video generation process
52
+
53
+ ### Advanced Settings (click checkbox to enable)
54
+ - **Generation Settings**: Steps, guidance, seeds
55
+ - **Loras**: Additional style customizations
56
+ - **Sliding Window**: For longer videos
57
+
58
+ ## Your First Video
59
+
60
+ Let's generate a simple text-to-video:
61
+
62
+ 1. **Launch WanGP**: `python wgp.py`
63
+ 2. **Open Browser**: Navigate to `http://localhost:7860`
64
+ 3. **Enter Prompt**: "A cat walking in a garden"
65
+ 4. **Click Generate**: Wait for the video to be created
66
+ 5. **View Result**: The video will appear in the output section
67
+
68
+ ### Recommended First Settings
69
+ - **Model**: Wan 2.1 text2video 1.3B (faster, lower VRAM)
70
+ - **Frames**: 49 (about 2 seconds)
71
+ - **Steps**: 20 (good balance of speed/quality)
72
+
73
+ ## Model Selection
74
+
75
+ ### Text-to-Video Models
76
+ - **Wan 2.1 T2V 1.3B**: Fastest, lowest VRAM (6GB), good quality
77
+ - **Wan 2.1 T2V 14B**: Best quality, requires more VRAM (12GB+)
78
+ - **Hunyuan Video**: Excellent quality, slower generation
79
+ - **LTX Video**: Good for longer videos
80
+
81
+ ### Image-to-Video Models
82
+ - **Wan Fun InP 1.3B**: Fast image animation
83
+ - **Wan Fun InP 14B**: Higher quality image animation
84
+ - **VACE**: Advanced control over video generation
85
+
86
+ ### Choosing the Right Model
87
+ - **Low VRAM (6-8GB)**: Use 1.3B models
88
+ - **Medium VRAM (10-12GB)**: Use 14B models or Hunyuan
89
+ - **High VRAM (16GB+)**: Any model, longer videos
90
+
91
+ ## Basic Settings Explained
92
+
93
+ ### Generation Settings
94
+ - **Frames**: Number of frames (more = longer video)
95
+ - 25 frames ≈ 1 second
96
+ - 49 frames ≈ 2 seconds
97
+ - 73 frames ≈ 3 seconds
98
+
99
+ - **Steps**: Quality vs Speed tradeoff
100
+ - 15 steps: Fast, lower quality
101
+ - 20 steps: Good balance
102
+ - 30+ steps: High quality, slower
103
+
104
+ - **Guidance Scale**: How closely to follow the prompt
105
+ - 3-5: More creative interpretation
106
+ - 7-10: Closer to prompt description
107
+ - 12+: Very literal interpretation
108
+
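The frame counts above suggest a roughly linear frames-to-duration mapping. As a minimal illustrative sketch (not part of WanGP itself, and assuming the ~24 fps implied by those examples; the actual frame rate depends on the model you select), clip length can be estimated like this:

```python
# Minimal sketch: estimate clip duration from a frame count.
# Assumes ~24 fps, which matches the examples above (25 -> ~1s, 49 -> ~2s, 73 -> ~3s);
# adjust `fps` for models that output a different frame rate.
def estimate_duration(frames: int, fps: float = 24.0) -> float:
    """Return the approximate clip length in seconds."""
    return (frames - 1) / fps

if __name__ == "__main__":
    for frames in (25, 49, 73):
        print(f"{frames} frames ≈ {estimate_duration(frames):.1f} s")
```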
109
+ ### Seeds
110
+ - **Random Seed**: Different result each time
111
+ - **Fixed Seed**: Reproducible results
112
+ - **Use same seed + prompt**: Generate variations
113
+
114
+ ## Common Beginner Issues
115
+
116
+ ### "Out of Memory" Errors
117
+ 1. Use smaller models (1.3B instead of 14B)
118
+ 2. Reduce frame count
119
+ 3. Lower resolution in advanced settings
120
+ 4. Enable quantization (usually on by default)
121
+
122
+ ### Slow Generation
123
+ 1. Use 1.3B models for speed
124
+ 2. Reduce number of steps
125
+ 3. Install Sage attention (see [INSTALLATION.md](INSTALLATION.md))
126
+ 4. Enable TeaCache: `python wgp.py --teacache 2.0`
127
+
128
+ ### Poor Quality Results
129
+ 1. Increase number of steps (25-30)
130
+ 2. Improve prompt description
131
+ 3. Use 14B models if you have enough VRAM
132
+ 4. Enable Skip Layer Guidance in advanced settings
133
+
134
+ ## Writing Good Prompts
135
+
136
+ ### Basic Structure
137
+ ```
138
+ [Subject] [Action] [Setting] [Style/Quality modifiers]
139
+ ```
140
+
141
+ ### Examples
142
+ ```
143
+ A red sports car driving through a mountain road at sunset, cinematic, high quality
144
+
145
+ A woman with long hair walking on a beach, waves in the background, realistic, detailed
146
+
147
+ A cat sitting on a windowsill watching rain, cozy atmosphere, soft lighting
148
+ ```
149
+
150
+ ### Tips
151
+ - Be specific about what you want
152
+ - Include style descriptions (cinematic, realistic, etc.)
153
+ - Mention lighting and atmosphere
154
+ - Describe the setting in detail
155
+ - Use quality modifiers (high quality, detailed, etc.)
156
+
157
+ ## Next Steps
158
+
159
+ Once you're comfortable with basic generation:
160
+
161
+ 1. **Explore Advanced Features**:
162
+ - [Loras Guide](LORAS.md) - Customize styles and characters
163
+ - [VACE ControlNet](VACE.md) - Advanced video control
164
+ - [Command Line Options](CLI.md) - Optimize performance
165
+
166
+ 2. **Improve Performance**:
167
+ - Install better attention mechanisms
168
+ - Optimize memory settings
169
+ - Use compilation for speed
170
+
171
+ 3. **Join the Community**:
172
+ - [Discord Server](https://discord.gg/g7efUW9jGV) - Get help and share videos
173
+ - Share your best results
174
+ - Learn from other users
175
+
176
+ ## Troubleshooting First Steps
177
+
178
+ ### Installation Issues
179
+ - Ensure Python 3.10.9 is used
180
+ - Check CUDA version compatibility
181
+ - See [INSTALLATION.md](INSTALLATION.md) for detailed steps
182
+
183
+ ### Generation Issues
184
+ - Check GPU compatibility
185
+ - Verify sufficient VRAM
186
+ - Try basic settings first
187
+ - See [TROUBLESHOOTING.md](TROUBLESHOOTING.md) for specific issues
188
+
189
+ ### Performance Issues
190
+ - Use appropriate model for your hardware
191
+ - Enable performance optimizations
192
+ - Check [CLI.md](CLI.md) for optimization flags
193
+
194
+ Remember: Start simple and gradually explore more advanced features as you become comfortable with the basics!
docs/INSTALLATION.md ADDED
@@ -0,0 +1,170 @@
1
+ # Installation Guide
2
+
3
+ This guide covers installation for different GPU generations and operating systems.
4
+
5
+ ## Requirements
6
+
7
+ - Python 3.10.9
8
+ - Conda or Python venv
9
+ - Compatible GPU (RTX 10XX or newer recommended)
10
+
11
+ ## Installation for RTX 10XX to RTX 40XX (Stable)
12
+
13
+ This installation uses PyTorch 2.6.0 which is well-tested and stable.
14
+
15
+ ### Step 1: Download and Setup Environment
16
+
17
+ ```shell
18
+ # Clone the repository
19
+ git clone https://github.com/deepbeepmeep/Wan2GP.git
20
+ cd Wan2GP
21
+
22
+ # Create Python 3.10.9 environment using conda
23
+ conda create -n wan2gp python=3.10.9
24
+ conda activate wan2gp
25
+ ```
26
+
27
+ ### Step 2: Install PyTorch
28
+
29
+ ```shell
30
+ # Install PyTorch 2.6.0 with CUDA 12.4
31
+ pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
32
+ ```
33
+
34
+ ### Step 3: Install Dependencies
35
+
36
+ ```shell
37
+ # Install core dependencies
38
+ pip install -r requirements.txt
39
+ ```
40
+
41
+ ### Step 4: Optional Performance Optimizations
42
+
43
+ #### Sage Attention (30% faster)
44
+
45
+ ```shell
46
+ # Windows only: Install Triton
47
+ pip install triton-windows
48
+
49
+ # For both Windows and Linux
50
+ pip install sageattention==1.0.6
51
+ ```
52
+
53
+ #### Sage 2 Attention (40% faster)
54
+
55
+ ```shell
56
+ # Windows
57
+ pip install triton-windows
58
+ pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu126torch2.6.0-cp310-cp310-win_amd64.whl
59
+
60
+ # Linux (manual compilation required)
61
+ git clone https://github.com/thu-ml/SageAttention
62
+ cd SageAttention
63
+ pip install -e .
64
+ ```
65
+
66
+ #### Flash Attention
67
+
68
+ ```shell
69
+ # May require CUDA kernel compilation on Windows
70
+ pip install flash-attn==2.7.2.post1
71
+ ```
72
+
73
+ ## Installation for RTX 50XX (Beta)
74
+
75
+ RTX 50XX GPUs require PyTorch 2.7.0 (beta). This version may be less stable.
76
+
77
+ ⚠️ **Important:** Use Python 3.10 for compatibility with pip wheels.
78
+
79
+ ### Step 1: Setup Environment
80
+
81
+ ```shell
82
+ # Clone and setup (same as above)
83
+ git clone https://github.com/deepbeepmeep/Wan2GP.git
84
+ cd Wan2GP
85
+ conda create -n wan2gp python=3.10.9
86
+ conda activate wan2gp
87
+ ```
88
+
89
+ ### Step 2: Install PyTorch Beta
90
+
91
+ ```shell
92
+ # Install PyTorch 2.7.0 with CUDA 12.8
93
+ pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
94
+ ```
95
+
96
+ ### Step 3: Install Dependencies
97
+
98
+ ```shell
99
+ pip install -r requirements.txt
100
+ ```
101
+
102
+ ### Step 4: Optional Optimizations for RTX 50XX
103
+
104
+ #### Sage Attention
105
+
106
+ ```shell
107
+ # Windows
108
+ pip install triton-windows
109
+ pip install sageattention==1.0.6
110
+
111
+ # Linux
112
+ pip install sageattention==1.0.6
113
+ ```
114
+
115
+ #### Sage 2 Attention
116
+
117
+ ```shell
118
+ # Windows
119
+ pip install triton-windows
120
+ pip install https://github.com/woct0rdho/SageAttention/releases/download/v2.1.1-windows/sageattention-2.1.1+cu128torch2.7.0-cp310-cp310-win_amd64.whl
121
+
122
+ # Linux (manual compilation)
123
+ git clone https://github.com/thu-ml/SageAttention
124
+ cd SageAttention
125
+ pip install -e .
126
+ ```
127
+
128
+ ## Attention Modes
129
+
130
+ WanGP supports several attention implementations:
131
+
132
+ - **SDPA** (default): Available by default with PyTorch
133
+ - **Sage**: 30% speed boost with small quality cost
134
+ - **Sage2**: 40% speed boost
135
+ - **Flash**: Good performance, may be complex to install on Windows
136
+
137
+ ## Performance Profiles
138
+
139
+ Choose a profile based on your hardware:
140
+
141
+ - **Profile 3 (LowRAM_HighVRAM)**: Loads entire model in VRAM, requires 24GB VRAM for 8-bit quantized 14B model
142
+ - **Profile 4 (LowRAM_LowVRAM)**: Default, loads model parts as needed, slower but lower VRAM requirement
143
+
144
+ ## Troubleshooting
145
+
146
+ ### Sage Attention Issues
147
+
148
+ If Sage attention doesn't work:
149
+
150
+ 1. Check if Triton is properly installed
151
+ 2. Clear Triton cache
152
+ 3. Fallback to SDPA attention:
153
+ ```bash
154
+ python wgp.py --attention sdpa
155
+ ```
156
+
157
+ ### Memory Issues
158
+
159
+ - Use lower resolution or shorter videos
160
+ - Enable quantization (default)
161
+ - Use Profile 4 for lower VRAM usage
162
+ - Consider using 1.3B models instead of 14B models
163
+
164
+ ### GPU Compatibility
165
+
166
+ - RTX 10XX, 20XX: Supported with SDPA attention
167
+ - RTX 30XX, 40XX: Full feature support
168
+ - RTX 50XX: Beta support with PyTorch 2.7.0
169
+
170
+ For more troubleshooting, see [TROUBLESHOOTING.md](TROUBLESHOOTING.md)
docs/LORAS.md ADDED
@@ -0,0 +1,224 @@
1
+ # Loras Guide
2
+
3
+ Loras (Low-Rank Adaptations) allow you to customize video generation models by adding specific styles, characters, or effects to your videos.
4
+
5
+ ## Directory Structure
6
+
7
+ Loras are organized in different folders based on the model they're designed for:
8
+
9
+ ### Text-to-Video Models
10
+ - `loras/` - General t2v loras
11
+ - `loras/1.3B/` - Loras specifically for 1.3B models
12
+ - `loras/14B/` - Loras specifically for 14B models
13
+
14
+ ### Image-to-Video Models
15
+ - `loras_i2v/` - Image-to-video loras
16
+
17
+ ### Other Models
18
+ - `loras_hunyuan/` - Hunyuan Video t2v loras
19
+ - `loras_hunyuan_i2v/` - Hunyuan Video i2v loras
20
+ - `loras_ltxv/` - LTX Video loras
21
+
22
+ ## Custom Lora Directory
23
+
24
+ You can specify custom lora directories when launching the app:
25
+
26
+ ```bash
27
+ # Use shared lora directory for both t2v and i2v
28
+ python wgp.py --lora-dir /path/to/shared/loras --lora-dir-i2v /path/to/shared/loras
29
+
30
+ # Specify different directories for different models
31
+ python wgp.py --lora-dir-hunyuan /path/to/hunyuan/loras --lora-dir-ltxv /path/to/ltx/loras
32
+ ```
33
+
34
+ ## Using Loras
35
+
36
+ ### Basic Usage
37
+
38
+ 1. Place your lora files in the appropriate directory
39
+ 2. Launch WanGP
40
+ 3. In the Advanced Tab, select the "Loras" section
41
+ 4. Check the loras you want to activate
42
+ 5. Set multipliers for each lora (default is 1.0)
43
+
44
+ ### Lora Multipliers
45
+
46
+ Multipliers control the strength of each lora's effect:
47
+
48
+ #### Simple Multipliers
49
+ ```
50
+ 1.2 0.8
51
+ ```
52
+ - First lora: 1.2 strength
53
+ - Second lora: 0.8 strength
54
+
55
+ #### Time-based Multipliers
56
+ For dynamic effects over generation steps, use comma-separated values:
57
+ ```
58
+ 0.9,0.8,0.7
59
+ 1.2,1.1,1.0
60
+ ```
61
+ - For 30 steps: steps 0-9 use first value, 10-19 use second, 20-29 use third
62
+ - First lora: 0.9 → 0.8 → 0.7
63
+ - Second lora: 1.2 → 1.1 → 1.0
64
+
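As a minimal sketch (not WanGP's actual implementation), the snippet below shows how a comma-separated multiplier string could be mapped to a per-step value by splitting the denoising steps into equal phases, matching the step ranges described above:

```python
# Minimal sketch, not WanGP's real code: map a comma-separated multiplier
# string (e.g. "0.9,0.8,0.7") to the value used at a given denoising step.
def multiplier_for_step(multipliers: str, step: int, total_steps: int) -> float:
    values = [float(v) for v in multipliers.split(",")]
    # Split the steps into len(values) equal phases; each phase uses one value.
    phase = min(step * len(values) // total_steps, len(values) - 1)
    return values[phase]

# With 30 steps and "0.9,0.8,0.7": steps 0-9 -> 0.9, 10-19 -> 0.8, 20-29 -> 0.7
assert multiplier_for_step("0.9,0.8,0.7", 9, 30) == 0.9
assert multiplier_for_step("0.9,0.8,0.7", 10, 30) == 0.8
assert multiplier_for_step("0.9,0.8,0.7", 29, 30) == 0.7
```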
65
+ ## Lora Presets
66
+
67
+ Presets are combinations of loras with predefined multipliers and prompts.
68
+
69
+ ### Creating Presets
70
+ 1. Configure your loras and multipliers
71
+ 2. Write a prompt with comments (lines starting with #)
72
+ 3. Save as a preset with `.lset` extension
73
+
74
+ ### Example Preset
75
+ ```
76
+ # Use the keyword "ohnvx" to trigger the lora
77
+ A ohnvx character is driving a car through the city
78
+ ```
79
+
80
+ ### Using Presets
81
+ ```bash
82
+ # Load preset on startup
83
+ python wgp.py --lora-preset mypreset.lset
84
+ ```
85
+
86
+ ### Managing Presets
87
+ - Edit, save, or delete presets directly from the web interface
88
+ - Presets include comments with usage instructions
89
+ - Share `.lset` files with other users
90
+
91
+ ## CausVid Lora (Video Generation Accelerator)
92
+
93
+ CausVid is a distilled Wan model that generates videos in 4-12 steps with 2x speed improvement.
94
+
95
+ ### Setup Instructions
96
+ 1. Download the CausVid Lora:
97
+ ```
98
+ https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors
99
+ ```
100
+ 2. Place in your `loras/` directory
101
+
102
+ ### Usage
103
+ 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 14B or Vace 14B)
104
+ 2. Enable Advanced Mode
105
+ 3. In Advanced Generation Tab:
106
+ - Set Guidance Scale = 1
107
+ - Set Shift Scale = 7
108
+ 4. In Advanced Lora Tab:
109
+ - Select CausVid Lora
110
+ - Set multiplier to 0.3
111
+ 5. Set generation steps to 12
112
+ 6. Generate!
113
+
114
+ ### CausVid Step/Multiplier Relationship
115
+ - **12 steps**: 0.3 multiplier (recommended)
116
+ - **8 steps**: 0.5-0.7 multiplier
117
+ - **4 steps**: 0.8-1.0 multiplier
118
+
119
+ *Note: Lower steps = lower quality (especially motion)*
120
+
121
+ ## Supported Formats
122
+
123
+ WanGP supports multiple lora formats:
124
+ - **Safetensors** (.safetensors)
125
+ - **Replicate** format
126
+ - **Standard PyTorch** (.pt, .pth)
127
+
128
+ ## AccVid Lora (Video Generation Accelerator)
129
+
130
+ AccVid is a distilled Wan model that generates videos with a 2x speed improvement, since classifier-free guidance is no longer needed (that is, cfg = 1).
131
+
132
+ ### Setup Instructions
133
+ 1. Download the AccVid Lora:
134
+
135
+ - for t2v models:
136
+ ```
137
+ https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_T2V_14B_lora_rank32_fp16.safetensors
138
+ ```
139
+
140
+ - for i2v models:
141
+ ```
142
+ https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_AccVid_I2V_480P_14B_lora_rank32_fp16.safetensors
143
+ ```
144
+
145
+ 2. Place in your `loras/` directory or `loras_i2v/` directory
146
+
147
+ ### Usage
148
+ 1. Select a Wan t2v model (e.g., Wan 2.1 text2video 14B or Vace 14B) or a Wan i2v model
149
+ 2. Enable Advanced Mode
150
+ 3. In Advanced Generation Tab:
151
+ - Set Guidance Scale = 1
152
+ - Set Shift Scale = 5
153
+ 4. The number of steps remains unchanged compared to what you would use with the original model, but generation will be two times faster since classifier-free guidance is not needed
154
+
155
+ ## Performance Tips
156
+
157
+ ### Fast Loading/Unloading
158
+ - Loras can be added/removed without restarting the app
159
+ - Use the "Refresh" button to detect new loras
160
+ - Enable `--check-loras` to filter incompatible loras (slower startup)
161
+
162
+ ### Memory Management
163
+ - Loras are loaded on-demand to save VRAM
164
+ - Multiple loras can be used simultaneously
165
+ - Time-based multipliers don't use extra memory
166
+
167
+ ## Finding Loras
168
+
169
+ ### Sources
170
+ - **[Civitai](https://civitai.com/)** - Large community collection
171
+ - **HuggingFace** - Official and community loras
172
+ - **Discord Server** - Community recommendations
173
+
174
+ ### Creating Loras
175
+ - **Kohya** - Popular training tool
176
+ - **OneTrainer** - Alternative training solution
177
+ - **Custom datasets** - Train on your own content
178
+
179
+ ## Macro System (Advanced)
180
+
181
+ Create multiple prompts from templates using macros:
182
+
183
+ ```
184
+ ! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its","her","his"
185
+ In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch.
186
+ ```
187
+
188
+ This generates:
189
+ 1. "In the video, a cat is presented. The cat is in a forest and looks at its watch."
190
+ 2. "In the video, a woman is presented. The woman is in a lake and looks at her watch."
191
+ 3. "In the video, a man is presented. The man is in a city and looks at his watch."
192
+
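A minimal sketch of how such a macro line could be expanded (a hypothetical helper, not the parser WanGP actually uses): the variable values are matched positionally, so the i-th value of each variable is combined into the i-th prompt.

```python
import re

# Hypothetical sketch of positional macro expansion; WanGP's real parser may differ.
def expand_macro(macro_line: str, template: str) -> list[str]:
    # Parse '{Name}="a","b","c"' assignments from the "! ..." macro line.
    assignments = {
        name: re.findall(r'"([^"]*)"', values)
        for name, values in re.findall(r'\{(\w+)\}=((?:"[^"]*",?)+)', macro_line)
    }
    prompts = []
    # Combine values positionally: the i-th value of every variable goes into prompt i.
    for row in zip(*assignments.values()):
        prompt = template
        for name, value in zip(assignments.keys(), row):
            prompt = prompt.replace("{" + name + "}", value)
        prompts.append(prompt)
    return prompts

macro = '! {Subject}="cat","woman","man", {Location}="forest","lake","city", {Possessive}="its","her","his"'
template = "In the video, a {Subject} is presented. The {Subject} is in a {Location} and looks at {Possessive} watch."
for p in expand_macro(macro, template):
    print(p)
```

Run as-is, this prints the three prompts listed above.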
193
+ ## Troubleshooting
194
+
195
+ ### Lora Not Working
196
+ 1. Check if lora is compatible with your model size (1.3B vs 14B)
197
+ 2. Verify lora format is supported
198
+ 3. Try different multiplier values
199
+ 4. Check the lora was trained for your model type (t2v vs i2v)
200
+
201
+ ### Performance Issues
202
+ 1. Reduce number of active loras
203
+ 2. Lower multiplier values
204
+ 3. Use `--check-loras` to filter incompatible files
205
+ 4. Clear lora cache if issues persist
206
+
207
+ ### Memory Errors
208
+ 1. Use fewer loras simultaneously
209
+ 2. Reduce model size (use 1.3B instead of 14B)
210
+ 3. Lower video resolution or frame count
211
+ 4. Enable quantization if not already active
212
+
213
+ ## Command Line Options
214
+
215
+ ```bash
216
+ # Lora-related command line options
217
+ --lora-dir path # Path to t2v loras directory
218
+ --lora-dir-i2v path # Path to i2v loras directory
219
+ --lora-dir-hunyuan path # Path to Hunyuan t2v loras
220
+ --lora-dir-hunyuan-i2v path # Path to Hunyuan i2v loras
221
+ --lora-dir-ltxv path # Path to LTX Video loras
222
+ --lora-preset preset # Load preset on startup
223
+ --check-loras # Filter incompatible loras
224
+ ```
docs/MODELS.md ADDED
@@ -0,0 +1,268 @@
1
+ # Models Overview
2
+
3
+ WanGP supports multiple video generation models, each optimized for different use cases and hardware configurations.
4
+
5
+
6
+ ## Wan 2.1 Text2Video Models
7
+ Please note that the term *Text2Video* refers to the underlying Wan architecture; as it has been greatly improved over time, many derived Text2Video models can now also generate videos from images.
8
+
9
+ #### Wan 2.1 Text2Video 1.3B
10
+ - **Size**: 1.3 billion parameters
11
+ - **VRAM**: 6GB minimum
12
+ - **Speed**: Fast generation
13
+ - **Quality**: Good quality for the size
14
+ - **Best for**: Quick iterations, lower-end hardware
15
+ - **Command**: `python wgp.py --t2v-1-3B`
16
+
17
+ #### Wan 2.1 Text2Video 14B
18
+ - **Size**: 14 billion parameters
19
+ - **VRAM**: 12GB+ recommended
20
+ - **Speed**: Slower but higher quality
21
+ - **Quality**: Excellent detail and coherence
22
+ - **Best for**: Final production videos
23
+ - **Command**: `python wgp.py --t2v-14B`
24
+
25
+ #### Wan Vace 1.3B
26
+ - **Type**: ControlNet for advanced video control
27
+ - **VRAM**: 6GB minimum
28
+ - **Features**: Motion transfer, object injection, inpainting
29
+ - **Best for**: Advanced video manipulation
30
+ - **Command**: `python wgp.py --vace-1-3B`
31
+
32
+ #### Wan Vace 14B
33
+ - **Type**: Large ControlNet model
34
+ - **VRAM**: 12GB+ recommended
35
+ - **Features**: All Vace features with higher quality
36
+ - **Best for**: Professional video editing workflows
37
+
38
+ #### MoviiGen (Experimental)
39
+ - **Resolution**: Claims 1080p capability
40
+ - **VRAM**: 20GB+ required
41
+ - **Speed**: Very slow generation
42
+ - **Features**: Should generate cinema-like video, specialized for 2.1:1 aspect ratios
43
+ - **Status**: Experimental, feedback welcome
44
+
45
+ <BR>
46
+
47
+ ## Wan 2.1 Image-to-Video Models
48
+
49
+ #### Wan 2.1 Image2Video 14B
50
+ - **Size**: 14 billion parameters
51
+ - **VRAM**: 12GB+ recommended
52
+ - **Speed**: Slower but higher quality
53
+ - **Quality**: Excellent detail and coherence
54
+ - **Best for**: General image-to-video work; most available Loras work with this model
55
+ - **Command**: `python wgp.py --i2v-14B`
56
+
57
+ #### FLF2V
58
+ - **Type**: Start/end frame specialist
59
+ - **Resolution**: Optimized for 720p
60
+ - **Official**: Wan team supported
61
+ - **Use case**: Image-to-video with specific endpoints
62
+
63
+
64
+ <BR>
65
+
66
+ ## Wan 2.1 Specialized Models
67
+
68
+ #### FantasySpeaking
69
+ - **Type**: Talking head animation
70
+ - **Input**: Voice track + image
71
+ - **Works on**: People and objects
72
+ - **Use case**: Lip-sync and voice-driven animation
73
+
74
+ #### Phantom
75
+ - **Type**: Person/object transfer
76
+ - **Resolution**: Works well at 720p
77
+ - **Requirements**: 30+ steps for good results
78
+ - **Best for**: Transferring subjects between videos
79
+
80
+ #### Recam Master
81
+ - **Type**: Viewpoint change
82
+ - **Requirements**: 81+ frame input videos, 15+ denoising steps
83
+ - **Use case**: View same scene from different angles
84
+
85
+ #### Sky Reels v2
86
+ - **Type**: Diffusion Forcing model
87
+ - **Specialty**: "Infinite length" videos
88
+ - **Features**: High quality continuous generation
89
+
90
+
91
+ <BR>
92
+
93
+ ## Wan Fun InP Models
94
+
95
+ #### Wan Fun InP 1.3B
96
+ - **Size**: 1.3 billion parameters
97
+ - **VRAM**: 6GB minimum
98
+ - **Quality**: Good for the size, accessible to lower hardware
99
+ - **Best for**: Entry-level image animation
100
+ - **Command**: `python wgp.py --i2v-1-3B`
101
+
102
+ #### Wan Fun InP 14B
103
+ - **Size**: 14 billion parameters
104
+ - **VRAM**: 12GB+ recommended
105
+ - **Quality**: Better end image support
106
+ - **Limitation**: Existing loras don't work as well
107
+
108
+ <BR>
109
+
110
+ ## Wan Special Loras
111
+ ### Causvid
112
+ - **Type**: Distilled model (Lora implementation)
113
+ - **Speed**: 4-12 steps generation, 2x faster
114
+ - **Compatible**: Works with Wan 14B models
115
+ - **Setup**: Requires CausVid Lora (see [LORAS.md](LORAS.md))
116
+
117
+
118
+ <BR>
119
+
120
+ ## Hunyuan Video Models
121
+
122
+ #### Hunyuan Video Text2Video
123
+ - **Quality**: Among the best open source t2v models
124
+ - **VRAM**: 12GB+ recommended
125
+ - **Speed**: Slower generation but excellent results
126
+ - **Features**: Superior text adherence and video quality, up to 10s of video
127
+ - **Best for**: High-quality text-to-video generation
128
+
129
+ #### Hunyuan Video Custom
130
+ - **Specialty**: Identity preservation
131
+ - **Use case**: Injecting specific people into videos
132
+ - **Quality**: Excellent for character consistency
133
+ - **Best for**: Character-focused video generation
134
+
135
+ #### Hunyuan Video Avatar
136
+ - **Specialty**: Generates up to 15s of high-quality speech- or song-driven video
137
+ - **Use case**: Injecting specific people into videos
138
+ - **Quality**: Excellent for character consistency
139
+ - **Best for**: Character-focused video generation synchronized with a voice track
140
+
141
+
142
+ <BR>
143
+
144
+ ## LTX Video Models
145
+
146
+ #### LTX Video 13B
147
+ - **Specialty**: Long video generation
148
+ - **Resolution**: Fast 720p generation
149
+ - **VRAM**: Optimized by WanGP (4x reduction in requirements)
150
+ - **Best for**: Longer duration videos
151
+
152
+ #### LTX Video 13B Distilled
153
+ - **Speed**: Generate in less than one minute
154
+ - **Quality**: Very high quality despite speed
155
+ - **Best for**: Rapid prototyping and quick results
156
+
157
+ <BR>
158
+
159
+ ## Model Selection Guide
160
+
161
+ ### By Hardware (VRAM)
162
+
163
+ #### 6-8GB VRAM
164
+ - Wan 2.1 T2V 1.3B
165
+ - Wan Fun InP 1.3B
166
+ - Wan Vace 1.3B
167
+
168
+ #### 10-12GB VRAM
169
+ - Wan 2.1 T2V 14B
170
+ - Wan Fun InP 14B
171
+ - Hunyuan Video (with optimizations)
172
+ - LTX Video 13B
173
+
174
+ #### 16GB+ VRAM
175
+ - All models supported
176
+ - Longer videos possible
177
+ - Higher resolutions
178
+ - Multiple simultaneous Loras
179
+
180
+ #### 20GB+ VRAM
181
+ - MoviiGen (experimental 1080p)
182
+ - Very long videos
183
+ - Maximum quality settings
184
+
185
+ ### By Use Case
186
+
187
+ #### Quick Prototyping
188
+ 1. **LTX Video 13B Distilled** - Fastest, high quality
189
+ 2. **Wan 2.1 T2V 1.3B** - Fast, good quality
190
+ 3. **CausVid Lora** - 4-12 steps, very fast
191
+
192
+ #### Best Quality
193
+ 1. **Hunyuan Video** - Overall best t2v quality
194
+ 2. **Wan 2.1 T2V 14B** - Excellent Wan quality
195
+ 3. **Wan Vace 14B** - Best for controlled generation
196
+
197
+ #### Advanced Control
198
+ 1. **Wan Vace 14B/1.3B** - Motion transfer, object injection
199
+ 2. **Phantom** - Person/object transfer
200
+ 3. **FantasySpeaking** - Voice-driven animation
201
+
202
+ #### Long Videos
203
+ 1. **LTX Video 13B** - Specialized for length
204
+ 2. **Sky Reels v2** - Infinite length videos
205
+ 3. **Wan Vace + Sliding Windows** - Up to 1 minute
206
+
207
+ #### Lower Hardware
208
+ 1. **Wan Fun InP 1.3B** - Image-to-video
209
+ 2. **Wan 2.1 T2V 1.3B** - Text-to-video
210
+ 3. **Wan Vace 1.3B** - Advanced control
211
+
212
+ <BR>
213
+
214
+ ## Performance Comparison
215
+
216
+ ### Speed (Relative)
217
+ 1. **CausVid Lora** (4-12 steps) - Fastest
218
+ 2. **LTX Video Distilled** - Very fast
219
+ 3. **Wan 1.3B models** - Fast
220
+ 4. **Wan 14B models** - Medium
221
+ 5. **Hunyuan Video** - Slower
222
+ 6. **MoviiGen** - Slowest
223
+
224
+ ### Quality (Subjective)
225
+ 1. **Hunyuan Video** - Highest overall
226
+ 2. **Wan 14B models** - Excellent
227
+ 3. **LTX Video models** - Very good
228
+ 4. **Wan 1.3B models** - Good
229
+ 5. **CausVid** - Good (varies with steps)
230
+
231
+ ### VRAM Efficiency
232
+ 1. **Wan 1.3B models** - Most efficient
233
+ 2. **LTX Video** (with WanGP optimizations)
234
+ 3. **Wan 14B models**
235
+ 4. **Hunyuan Video**
236
+ 5. **MoviiGen** - Least efficient
237
+
238
+ <BR>
239
+
240
+ ## Model Switching
241
+
242
+ WanGP allows switching between models without restarting:
243
+
244
+ 1. Use the dropdown menu in the web interface
245
+ 2. Models are loaded on-demand
246
+ 3. Previous model is unloaded to save VRAM
247
+ 4. Settings are preserved when possible
248
+
249
+ <BR>
250
+
251
+ ## Tips for Model Selection
252
+
253
+ ### First Time Users
254
+ Start with **Wan 2.1 T2V 1.3B** to learn the interface and test your hardware.
255
+
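+ For example, the launch command listed earlier in this guide starts that model directly:
+
+ ```bash
+ python wgp.py --t2v-1-3B
+ ```
+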
256
+ ### Production Work
257
+ Use **Hunyuan Video** or **Wan 14B** models for final output quality.
258
+
259
+ ### Experimentation
260
+ **CausVid Lora** or **LTX Distilled** for rapid iteration and testing.
261
+
262
+ ### Specialized Tasks
263
+ - **VACE** for advanced control
264
+ - **FantasySpeaking** for talking heads
265
+ - **LTX Video** for long sequences
266
+
267
+ ### Hardware Optimization
268
+ Always start with the largest model your VRAM can handle, then optimize settings for speed vs quality based on your needs.
docs/TROUBLESHOOTING.md ADDED
@@ -0,0 +1,338 @@
1
+ # Troubleshooting Guide
2
+
3
+ This guide covers common issues and their solutions when using WanGP.
4
+
5
+ ## Installation Issues
6
+
7
+ ### PyTorch Installation Problems
8
+
9
+ #### CUDA Version Mismatch
10
+ **Problem**: PyTorch can't detect GPU or CUDA errors
11
+ **Solution**:
12
+ ```bash
13
+ # Check your CUDA version
14
+ nvidia-smi
15
+
16
+ # Install matching PyTorch version
17
+ # For CUDA 12.4 (RTX 10XX-40XX)
18
+ pip install torch==2.6.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu124
19
+
20
+ # For CUDA 12.8 (RTX 50XX)
21
+ pip install torch==2.7.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/test/cu128
22
+ ```
23
+
24
+ #### Python Version Issues
25
+ **Problem**: Package compatibility errors
26
+ **Solution**: Ensure you're using Python 3.10.9
27
+ ```bash
28
+ python --version # Should show 3.10.9
29
+ conda create -n wan2gp python=3.10.9
30
+ ```
31
+
32
+ ### Dependency Installation Failures
33
+
34
+ #### Triton Installation (Windows)
35
+ **Problem**: `pip install triton-windows` fails
36
+ **Solution**:
37
+ 1. Update pip: `pip install --upgrade pip`
38
+ 2. Try pre-compiled wheel
39
+ 3. Fallback to SDPA attention: `python wgp.py --attention sdpa`
40
+
41
+ #### SageAttention Compilation Issues
42
+ **Problem**: SageAttention installation fails
43
+ **Solution**:
44
+ 1. Install Visual Studio Build Tools (Windows)
45
+ 2. Use pre-compiled wheels when available
46
+ 3. Fallback to basic attention modes
47
+
48
+ ## Memory Issues
49
+
50
+ ### CUDA Out of Memory
51
+
52
+ #### During Model Loading
53
+ **Problem**: "CUDA out of memory" when loading model
54
+ **Solutions**:
55
+ ```bash
56
+ # Use smaller model
57
+ python wgp.py --t2v-1-3B
58
+
59
+ # Enable quantization (usually default)
60
+ python wgp.py --quantize-transformer True
61
+
62
+ # Use memory-efficient profile
63
+ python wgp.py --profile 4
64
+
65
+ # Reduce preloaded model size
66
+ python wgp.py --preload 0
67
+ ```
68
+
69
+ #### During Video Generation
70
+ **Problem**: Memory error during generation
71
+ **Solutions**:
72
+ 1. Reduce frame count (shorter videos)
73
+ 2. Lower resolution in advanced settings
74
+ 3. Use lower batch size
75
+ 4. Clear GPU cache between generations
76
+
77
+ ### System RAM Issues
78
+
79
+ #### High RAM Usage
80
+ **Problem**: System runs out of RAM
81
+ **Solutions**:
82
+ ```bash
83
+ # Limit reserved memory
84
+ python wgp.py --perc-reserved-mem-max 0.3
85
+
86
+ # Use minimal RAM profile
87
+ python wgp.py --profile 5
88
+
89
+ # Enable swap file (OS level)
90
+ ```
91
+
92
+ ## Performance Issues
93
+
94
+ ### Slow Generation Speed
95
+
96
+ #### General Optimization
97
+ ```bash
98
+ # Enable compilation (requires Triton)
99
+ python wgp.py --compile
100
+
101
+ # Use faster attention
102
+ python wgp.py --attention sage2
103
+
104
+ # Enable TeaCache
105
+ python wgp.py --teacache 2.0
106
+
107
+ # Use high-performance profile
108
+ python wgp.py --profile 3
109
+ ```
110
+
111
+ #### GPU-Specific Optimizations
112
+
113
+ **RTX 10XX/20XX Series**:
114
+ ```bash
115
+ python wgp.py --attention sdpa --profile 4 --teacache 1.5
116
+ ```
117
+
118
+ **RTX 30XX/40XX Series**:
119
+ ```bash
120
+ python wgp.py --compile --attention sage --profile 3 --teacache 2.0
121
+ ```
122
+
123
+ **RTX 50XX Series**:
124
+ ```bash
125
+ python wgp.py --attention sage --profile 4 --fp16
126
+ ```
127
+
128
+ ### Attention Mechanism Issues
129
+
130
+ #### Sage Attention Not Working
131
+ **Problem**: Sage attention fails to compile or work
132
+ **Diagnostic Steps**:
133
+ 1. Check Triton installation:
134
+ ```python
135
+ import triton
136
+ print(triton.__version__)
137
+ ```
138
+ 2. Clear Triton cache:
139
+ ```bash
140
+ # Windows
141
+ rmdir /s %USERPROFILE%\.triton
142
+ # Linux
143
+ rm -rf ~/.triton
144
+ ```
145
+ 3. Fallback solution:
146
+ ```bash
147
+ python wgp.py --attention sdpa
148
+ ```
149
+
150
+ #### Flash Attention Issues
151
+ **Problem**: Flash attention compilation fails
152
+ **Solution**:
153
+ - Windows: Often requires manual CUDA kernel compilation
154
+ - Linux: Usually works with `pip install flash-attn`
155
+ - Fallback: Use Sage or SDPA attention
156
+
157
+ ## Model-Specific Issues
158
+
159
+ ### Lora Problems
160
+
161
+ #### Loras Not Loading
162
+ **Problem**: Loras don't appear in the interface
163
+ **Solutions**:
164
+ 1. Check file format (should be .safetensors, .pt, or .pth)
165
+ 2. Verify correct directory:
166
+ ```
167
+ loras/ # For t2v models
168
+ loras_i2v/ # For i2v models
169
+ loras_hunyuan/ # For Hunyuan models
170
+ ```
171
+ 3. Click "Refresh" button in interface
172
+ 4. Use `--check-loras` to filter incompatible files
173
+
174
+ #### Lora Compatibility Issues
175
+ **Problem**: Lora causes errors or poor results
176
+ **Solutions**:
177
+ 1. Check model size compatibility (1.3B vs 14B)
178
+ 2. Verify lora was trained for your model type
179
+ 3. Try different multiplier values
180
+ 4. Use `--check-loras` flag to auto-filter
181
+
182
+ ### VACE-Specific Issues
183
+
184
+ #### Poor VACE Results
185
+ **Problem**: VACE generates poor quality or unexpected results
186
+ **Solutions**:
187
+ 1. Enable Skip Layer Guidance
188
+ 2. Use detailed prompts describing all elements
189
+ 3. Ensure proper mask creation with Matanyone
190
+ 4. Check reference image quality
191
+ 5. Use at least 15 steps, preferably 30+
192
+
193
+ #### Matanyone Tool Issues
194
+ **Problem**: Mask creation difficulties
195
+ **Solutions**:
196
+ 1. Use negative point prompts to refine selection
197
+ 2. Create multiple sub-masks and combine them
198
+ 3. Try different background removal options
199
+ 4. Ensure sufficient contrast in source video
200
+
201
+ ## Network and Server Issues
202
+
203
+ ### Gradio Interface Problems
204
+
205
+ #### Port Already in Use
206
+ **Problem**: "Port 7860 is already in use"
207
+ **Solution**:
208
+ ```bash
209
+ # Use different port
210
+ python wgp.py --server-port 7861
211
+
212
+ # Or kill existing process
213
+ # Windows
214
+ netstat -ano | findstr :7860
215
+ taskkill /PID <PID> /F
216
+
217
+ # Linux
218
+ lsof -i :7860
219
+ kill <PID>
220
+ ```
221
+
222
+ #### Interface Not Loading
223
+ **Problem**: Browser shows "connection refused"
224
+ **Solutions**:
225
+ 1. Check if server started successfully
226
+ 2. Try `http://127.0.0.1:7860` instead of `localhost:7860`
227
+ 3. Disable firewall temporarily
228
+ 4. Use the `--listen` flag for network access (see the example below)
229
+
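+ A minimal example combining the relevant flags (the port number is only illustrative):
+
+ ```bash
+ # Listen on all network interfaces and use an alternative port
+ python wgp.py --listen --server-port 7861
+ ```
+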
230
+ ### Remote Access Issues
231
+
232
+ #### Sharing Not Working
233
+ **Problem**: `--share` flag doesn't create public URL
234
+ **Solutions**:
235
+ 1. Check internet connection
236
+ 2. Try different network
237
+ 3. Use `--listen` with port forwarding
238
+ 4. Check firewall settings
239
+
240
+ ## Quality Issues
241
+
242
+ ### Poor Video Quality
243
+
244
+ #### General Quality Improvements
245
+ 1. Increase number of steps (25-30+)
246
+ 2. Use larger models (14B instead of 1.3B)
247
+ 3. Enable Skip Layer Guidance
248
+ 4. Improve prompt descriptions
249
+ 5. Use higher resolution settings
250
+
251
+ #### Specific Quality Issues
252
+
253
+ **Blurry Videos**:
254
+ - Increase steps
255
+ - Check source image quality (i2v)
256
+ - Reduce TeaCache multiplier
257
+ - Use higher guidance scale
258
+
259
+ **Inconsistent Motion**:
260
+ - Use longer overlap in sliding windows
261
+ - Reduce window size
262
+ - Improve prompt consistency
263
+ - Check control video quality (VACE)
264
+
265
+ **Color Issues**:
266
+ - Check model compatibility
267
+ - Adjust guidance scale
268
+ - Verify input image color space
269
+ - Try different VAE settings
270
+
271
+ ## Advanced Debugging
272
+
273
+ ### Enable Verbose Output
274
+ ```bash
275
+ # Maximum verbosity
276
+ python wgp.py --verbose 2
277
+
278
+ # Check lora compatibility
279
+ python wgp.py --check-loras --verbose 2
280
+ ```
281
+
282
+ ### Memory Debugging
283
+ ```bash
284
+ # Monitor GPU memory
285
+ nvidia-smi -l 1
286
+
287
+ # Reduce memory usage
288
+ python wgp.py --profile 4 --perc-reserved-mem-max 0.2
289
+ ```
290
+
291
+ ### Performance Profiling
292
+ ```bash
293
+ # Test different configurations
294
+ python wgp.py --attention sdpa --profile 4 # Baseline
295
+ python wgp.py --attention sage --profile 3 # Performance
296
+ python wgp.py --compile --teacache 2.0 # Maximum speed
297
+ ```
298
+
299
+ ## Getting Help
300
+
301
+ ### Before Asking for Help
302
+ 1. Check this troubleshooting guide
303
+ 2. Read the relevant documentation:
304
+ - [Installation Guide](INSTALLATION.md)
305
+ - [Getting Started](GETTING_STARTED.md)
306
+ - [Command Line Reference](CLI.md)
307
+ 3. Try basic fallback configuration:
308
+ ```bash
309
+ python wgp.py --attention sdpa --profile 4
310
+ ```
311
+
312
+ ### Community Support
313
+ - **Discord Server**: https://discord.gg/g7efUW9jGV
314
+ - Provide relevant information:
315
+ - GPU model and VRAM amount
316
+ - Python and PyTorch versions
317
+ - Complete error messages
318
+ - Command used to launch WanGP
319
+ - Operating system
320
+
321
+ ### Reporting Bugs
322
+ When reporting issues:
323
+ 1. Include system specifications
324
+ 2. Provide complete error logs
325
+ 3. List the exact steps to reproduce
326
+ 4. Mention any modifications to default settings
327
+ 5. Include command line arguments used
328
+
329
+ ## Emergency Fallback
330
+
331
+ If nothing works, try this minimal configuration:
332
+ ```bash
333
+ # Absolute minimum setup
334
+ python wgp.py --t2v-1-3B --attention sdpa --profile 4 --teacache 0 --fp16
335
+
336
+ # If that fails, check basic PyTorch installation
337
+ python -c "import torch; print(torch.cuda.is_available())"
338
+ ```
docs/VACE.md ADDED
@@ -0,0 +1,190 @@
1
+ # VACE ControlNet Guide
2
+
3
+ VACE is a powerful ControlNet that enables Video-to-Video and Reference-to-Video generation. It allows you to inject your own images into output videos, animate characters, perform inpainting/outpainting, and continue videos.
4
+
5
+ ## Overview
6
+
7
+ VACE is probably one of the most powerful Wan models available. With it, you can:
8
+ - Inject people or objects into scenes
9
+ - Animate characters
10
+ - Perform video inpainting and outpainting
11
+ - Continue existing videos
12
+ - Transfer motion from one video to another
13
+ - Change the style of scenes while preserving depth
14
+
15
+ ## Getting Started
16
+
17
+ ### Model Selection
18
+ 1. Select either "Vace 1.3B" or "Vace 14B" from the dropdown menu
19
+ 2. Note: VACE works best with videos up to 7 seconds with the Riflex option enabled
20
+
21
+ ### Input Types
22
+
23
+ VACE accepts three types of visual hints (which can be combined):
24
+
25
+ #### 1. Control Video
26
+ - Transfer motion or depth to a new video
27
+ - Use only the first n frames and extrapolate the rest
28
+ - Perform inpainting with grey color (127) as mask areas
29
+ - Grey areas will be filled based on text prompt and reference images
30
+
31
+ #### 2. Reference Images
32
+ - Use as background/setting for the video
33
+ - Inject people or objects of your choice
34
+ - Select multiple reference images
35
+ - **Tip**: Replace complex backgrounds with white for better object integration
36
+ - Always describe injected objects/people explicitly in your text prompt
37
+
38
+ #### 3. Video Mask
39
+ - Stronger control over which parts to keep (black) or replace (white)
40
+ - Perfect for inpainting/outpainting
41
+ - Example: White mask except at beginning/end (black) keeps first/last frames while generating middle content
42
+
43
+ ## Common Use Cases
44
+
45
+ ### Motion Transfer
46
+ **Goal**: Animate a character of your choice using motion from another video
47
+ **Setup**:
48
+ - Reference Images: Your character
49
+ - Control Video: Person performing desired motion
50
+ - Text Prompt: Describe your character and the action
51
+
52
+ ### Object/Person Injection
53
+ **Goal**: Insert people or objects into a scene
54
+ **Setup**:
55
+ - Reference Images: The people/objects to inject
56
+ - Text Prompt: Describe the scene and explicitly mention the injected elements
57
+
58
+ ### Character Animation
59
+ **Goal**: Animate a character based on text description
60
+ **Setup**:
61
+ - Control Video: Video of person moving
62
+ - Text Prompt: Detailed description of your character
63
+
64
+ ### Style Transfer with Depth
65
+ **Goal**: Change scene style while preserving spatial relationships
66
+ **Setup**:
67
+ - Control Video: Original video (for depth information)
68
+ - Text Prompt: New style description
69
+
70
+ ## Integrated Matanyone Tool
71
+
72
+ WanGP includes the Matanyone tool, specifically tuned for VACE workflows. This helps create control videos and masks simultaneously.
73
+
74
+ ### Creating Face Replacement Masks
75
+ 1. Load your video in Matanyone
76
+ 2. Click on the face in the first frame
77
+ 3. Create a mask for the face
78
+ 4. Generate both control video and mask video with "Generate Video Matting"
79
+ 5. Export to VACE with "Export to current Video Input and Video Mask"
80
+ 6. Load replacement face image in Reference Images field
81
+
82
+ ### Advanced Matanyone Tips
83
+ - **Negative Point Prompts**: Remove parts from current selection
84
+ - **Sub Masks**: Create multiple independent masks, then combine them
85
+ - **Background Masks**: Select everything except the character (useful for background replacement)
86
+ - Enable/disable sub masks in Matanyone settings
87
+
88
+ ## Recommended Settings
89
+
90
+ ### Quality Settings
91
+ - **Skip Layer Guidance**: Turn ON with default configuration for better results
92
+ - **Long Prompts**: Use detailed descriptions, especially for background elements not in reference images
93
+ - **Steps**: Use at least 15 steps for good quality, 30+ for best results
94
+
95
+ ### Sliding Window Settings
96
+ For very long videos, configure sliding windows properly:
97
+
98
+ - **Window Size**: Set appropriate duration for your content
99
+ - **Overlap Frames**: Long enough for motion continuity, short enough to avoid blur propagation
100
+ - **Discard Last Frames**: Remove at least 4 frames from each window (VACE 1.3B tends to blur final frames)
101
+
102
+ ### Background Removal
103
+ VACE includes automatic background removal options:
104
+ - Use for reference images containing people/objects
105
+ - **Don't use** for landscape/setting reference images (first reference image)
106
+ - Multiple background removal types available
107
+
108
+ ## Window Sliding for Long Videos
109
+
110
+ Generate videos up to 1 minute by merging multiple windows:
111
+
112
+ ### How It Works
113
+ - Each window uses corresponding time segment from control video
114
+ - Example: 0-4s control video → first window, 4-8s → second window, etc.
115
+ - Automatic overlap management ensures smooth transitions
116
+
117
+ ### Settings
118
+ - **Window Size**: Duration of each generation window
119
+ - **Overlap Frames**: Frames shared between windows for continuity
120
+ - **Discard Last Frames**: Remove poor-quality ending frames
121
+ - **Add Overlapped Noise**: Reduce quality degradation over time
122
+
123
+ ### Formula
124
+ ```
125
+ Generated Frames = [Windows - 1] × [Window Size - Overlap - Discard] + Window Size
126
+ ```
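+
+ For example (with purely illustrative numbers): 3 windows, a window size of 81 frames, 16 overlap frames and 4 discarded frames give (3 - 1) × (81 - 16 - 4) + 81 = 2 × 61 + 81 = 203 generated frames.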
127
+
128
+ ### Multi-Line Prompts (Experimental)
129
+ - Each line of prompt used for different window
130
+ - If more windows than prompt lines, last line repeats
131
+ - Separate lines with a carriage return (see the example below)
132
+
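+ For instance, a two-line prompt (the wording is purely illustrative) drives two consecutive windows:
+
+ ```
+ A knight walks through a misty forest
+ The knight reaches a ruined castle at dusk
+ ```
+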
133
+ ## Advanced Features
134
+
135
+ ### Extend Video
136
+ Click "Extend the Video Sample, Please!" during generation to add more windows dynamically.
137
+
138
+ ### Noise Addition
139
+ Add noise to overlapped frames to hide accumulated errors and quality degradation.
140
+
141
+ ### Frame Truncation
142
+ Automatically remove lower-quality final frames from each window (recommended: 4 frames for VACE 1.3B).
143
+
144
+ ## External Resources
145
+
146
+ ### Official VACE Resources
147
+ - **GitHub**: https://github.com/ali-vilab/VACE/tree/main/vace/gradios
148
+ - **User Guide**: https://github.com/ali-vilab/VACE/blob/main/UserGuide.md
149
+ - **Preprocessors**: Gradio tools for preparing materials
150
+
151
+ ### Recommended External Tools
152
+ - **Annotation Tools**: For creating precise masks
153
+ - **Video Editors**: For preparing control videos
154
+ - **Background Removal**: For cleaning reference images
155
+
156
+ ## Troubleshooting
157
+
158
+ ### Poor Quality Results
159
+ 1. Use longer, more detailed prompts
160
+ 2. Enable Skip Layer Guidance
161
+ 3. Increase number of steps (30+)
162
+ 4. Check reference image quality
163
+ 5. Ensure proper mask creation
164
+
165
+ ### Inconsistent Windows
166
+ 1. Increase overlap frames
167
+ 2. Use consistent prompting across windows
168
+ 3. Add noise to overlapped frames
169
+ 4. Reduce discard frames if losing too much content
170
+
171
+ ### Memory Issues
172
+ 1. Use VACE 1.3B instead of 14B
173
+ 2. Reduce video length or resolution
174
+ 3. Decrease window size
175
+ 4. Enable quantization
176
+
177
+ ### Blurry Results
178
+ 1. Reduce overlap frames
179
+ 2. Increase discard last frames
180
+ 3. Use higher resolution reference images
181
+ 4. Check control video quality
182
+
183
+ ## Tips for Best Results
184
+
185
+ 1. **Detailed Prompts**: Describe everything in the scene, especially elements not in reference images
186
+ 2. **Quality Reference Images**: Use high-resolution, well-lit reference images
187
+ 3. **Proper Masking**: Take time to create precise masks with Matanyone
188
+ 4. **Iterative Approach**: Start with short videos, then extend successful results
189
+ 5. **Background Preparation**: Remove complex backgrounds from object/person reference images
190
+ 6. **Consistent Lighting**: Match lighting between reference images and intended scene
fantasytalking/infer.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright Alibaba Inc. All Rights Reserved.
2
+
3
+ from transformers import Wav2Vec2Model, Wav2Vec2Processor
4
+
5
+ from .model import FantasyTalkingAudioConditionModel
6
+ from .utils import get_audio_features
7
+ import gc, torch
8
+
9
+ def parse_audio(audio_path, num_frames, fps = 23, device = "cuda"):
10
+ fantasytalking = FantasyTalkingAudioConditionModel(None, 768, 2048).to(device)
11
+ from mmgp import offload
12
+ from accelerate import init_empty_weights
13
+ from fantasytalking.model import AudioProjModel
14
+
15
+ torch.set_grad_enabled(False)
16
+
17
+ with init_empty_weights():
18
+ proj_model = AudioProjModel( 768, 2048)
19
+ offload.load_model_data(proj_model, "ckpts/fantasy_proj_model.safetensors")
20
+ proj_model.to("cpu").eval().requires_grad_(False)
21
+
22
+ wav2vec_model_dir = "ckpts/wav2vec"
23
+ wav2vec_processor = Wav2Vec2Processor.from_pretrained(wav2vec_model_dir)
24
+ wav2vec = Wav2Vec2Model.from_pretrained(wav2vec_model_dir, device_map="cpu").eval().requires_grad_(False)
25
+ wav2vec.to(device)
26
+ proj_model.to(device)
27
+ audio_wav2vec_fea = get_audio_features( wav2vec, wav2vec_processor, audio_path, fps, num_frames )
28
+
29
+ audio_proj_fea = proj_model(audio_wav2vec_fea)
30
+ pos_idx_ranges = fantasytalking.split_audio_sequence( audio_proj_fea.size(1), num_frames=num_frames )
31
+ audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding( audio_proj_fea, pos_idx_ranges, expand_length=4 ) # [b,21,9+8,768]
32
+ wav2vec, proj_model= None, None
33
+ gc.collect()
34
+ torch.cuda.empty_cache()
35
+
36
+ return audio_proj_split, audio_context_lens
fantasytalking/model.py ADDED
@@ -0,0 +1,162 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from wan.modules.attention import pay_attention
5
+
6
+
7
+ class AudioProjModel(nn.Module):
8
+ def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
9
+ super().__init__()
10
+ self.cross_attention_dim = cross_attention_dim
11
+ self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
12
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
13
+
14
+ def forward(self, audio_embeds):
15
+ context_tokens = self.proj(audio_embeds)
16
+ context_tokens = self.norm(context_tokens)
17
+ return context_tokens # [B,L,C]
18
+
19
+ class WanCrossAttentionProcessor(nn.Module):
20
+ def __init__(self, context_dim, hidden_dim):
21
+ super().__init__()
22
+
23
+ self.context_dim = context_dim
24
+ self.hidden_dim = hidden_dim
25
+
26
+ self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
27
+ self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
28
+
29
+ nn.init.zeros_(self.k_proj.weight)
30
+ nn.init.zeros_(self.v_proj.weight)
31
+
32
+ def __call__(
33
+ self,
34
+ q: torch.Tensor,
35
+ audio_proj: torch.Tensor,
36
+ latents_num_frames: int = 21,
37
+ audio_context_lens = None
38
+ ) -> torch.Tensor:
39
+ """
40
+ audio_proj: [B, 21, L3, C]
41
+ audio_context_lens: [B*21].
42
+ """
43
+ b, l, n, d = q.shape
44
+
45
+ if len(audio_proj.shape) == 4:
46
+ audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
47
+ ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
48
+ ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
49
+ qkv_list = [audio_q, ip_key, ip_value]
50
+ del q, audio_q, ip_key, ip_value
51
+ audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
52
+ audio_x = audio_x.view(b, l, n, d)
53
+ audio_x = audio_x.flatten(2)
54
+ elif len(audio_proj.shape) == 3:
55
+ ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
56
+ ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
57
+ qkv_list = [q, ip_key, ip_value]
58
+ del q, ip_key, ip_value
59
+ audio_x = pay_attention(qkv_list, k_lens =audio_context_lens) #audio_context_lens
60
+ audio_x = audio_x.flatten(2)
61
+ return audio_x
62
+
63
+
64
+ class FantasyTalkingAudioConditionModel(nn.Module):
65
+ def __init__(self, wan_dit, audio_in_dim: int, audio_proj_dim: int):
66
+ super().__init__()
67
+
68
+ self.audio_in_dim = audio_in_dim
69
+ self.audio_proj_dim = audio_proj_dim
70
+
71
+ def split_audio_sequence(self, audio_proj_length, num_frames=81):
72
+ """
73
+ Map the audio feature sequence to corresponding latent frame slices.
74
+
75
+ Args:
76
+ audio_proj_length (int): The total length of the audio feature sequence
77
+ (e.g., 173 in audio_proj[1, 173, 768]).
78
+ num_frames (int): The number of video frames in the training data (default: 81).
79
+
80
+ Returns:
81
+ list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
82
+ (within the audio feature sequence) corresponding to a latent frame.
83
+ """
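+ # Worked example (illustrative): audio_proj_length=173 and num_frames=81 give
+ # tokens_per_frame ~= 2.14 and int((81 - 1) / 4) + 1 = 21 latent frame slices.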
84
+ # Average number of tokens per original video frame
85
+ tokens_per_frame = audio_proj_length / num_frames
86
+
87
+ # Each latent frame covers 4 video frames, and we want the center
88
+ tokens_per_latent_frame = tokens_per_frame * 4
89
+ half_tokens = int(tokens_per_latent_frame / 2)
90
+
91
+ pos_indices = []
92
+ for i in range(int((num_frames - 1) / 4) + 1):
93
+ if i == 0:
94
+ pos_indices.append(0)
95
+ else:
96
+ start_token = tokens_per_frame * ((i - 1) * 4 + 1)
97
+ end_token = tokens_per_frame * (i * 4 + 1)
98
+ center_token = int((start_token + end_token) / 2) - 1
99
+ pos_indices.append(center_token)
100
+
101
+ # Build index ranges centered around each position
102
+ pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
103
+
104
+ # Adjust the first range to avoid negative start index
105
+ pos_idx_ranges[0] = [
106
+ -(half_tokens * 2 - pos_idx_ranges[1][0]),
107
+ pos_idx_ranges[1][0],
108
+ ]
109
+
110
+ return pos_idx_ranges
111
+
112
+ def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
113
+ """
114
+ Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
115
+ if the range exceeds the input boundaries.
116
+
117
+ Args:
118
+ input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
119
+ pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
120
+ expand_length (int): Number of tokens to expand on both sides of each subsequence.
121
+
122
+ Returns:
123
+ sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
124
+ Each element is a padded subsequence.
125
+ k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
126
+ Useful for ignoring padding tokens in attention masks.
127
+ """
128
+ pos_idx_ranges = [
129
+ [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
130
+ ]
131
+ sub_sequences = []
132
+ seq_len = input_tensor.size(1) # 173
133
+ max_valid_idx = seq_len - 1 # 172
134
+ k_lens_list = []
135
+ for start, end in pos_idx_ranges:
136
+ # Calculate the fill amount
137
+ pad_front = max(-start, 0)
138
+ pad_back = max(end - max_valid_idx, 0)
139
+
140
+ # Calculate the start and end indices of the valid part
141
+ valid_start = max(start, 0)
142
+ valid_end = min(end, max_valid_idx)
143
+
144
+ # Extract the valid part
145
+ if valid_start <= valid_end:
146
+ valid_part = input_tensor[:, valid_start : valid_end + 1, :]
147
+ else:
148
+ valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
149
+
150
+ # In the sequence dimension (the 1st dimension) perform padding
151
+ padded_subseq = F.pad(
152
+ valid_part,
153
+ (0, 0, 0, pad_back + pad_front, 0, 0),
154
+ mode="constant",
155
+ value=0,
156
+ )
157
+ k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
158
+
159
+ sub_sequences.append(padded_subseq)
160
+ return torch.stack(sub_sequences, dim=1), torch.tensor(
161
+ k_lens_list, dtype=torch.long
162
+ )
fantasytalking/utils.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright Alibaba Inc. All Rights Reserved.
2
+
3
+ import imageio
4
+ import librosa
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ from tqdm import tqdm
9
+
10
+
11
+ def resize_image_by_longest_edge(image_path, target_size):
12
+ image = Image.open(image_path).convert("RGB")
13
+ width, height = image.size
14
+ scale = target_size / max(width, height)
15
+ new_size = (int(width * scale), int(height * scale))
16
+ return image.resize(new_size, Image.LANCZOS)
17
+
18
+
19
+ def save_video(frames, save_path, fps, quality=9, ffmpeg_params=None):
20
+ writer = imageio.get_writer(
21
+ save_path, fps=fps, quality=quality, ffmpeg_params=ffmpeg_params
22
+ )
23
+ for frame in tqdm(frames, desc="Saving video"):
24
+ frame = np.array(frame)
25
+ writer.append_data(frame)
26
+ writer.close()
27
+
28
+
29
+ def get_audio_features(wav2vec, audio_processor, audio_path, fps, num_frames):
30
+ sr = 16000
31
+ audio_input, sample_rate = librosa.load(audio_path, sr=sr) # sample rate is 16 kHz
32
+
33
+ start_time = 0
34
+ # end_time = (0 + (num_frames - 1) * 1) / fps
35
+ end_time = num_frames / fps
36
+
37
+ start_sample = int(start_time * sr)
38
+ end_sample = int(end_time * sr)
39
+
40
+ try:
41
+ audio_segment = audio_input[start_sample:end_sample]
42
+ except Exception:
43
+ audio_segment = audio_input
44
+
45
+ input_values = audio_processor(
46
+ audio_segment, sampling_rate=sample_rate, return_tensors="pt"
47
+ ).input_values.to("cuda")
48
+
49
+ with torch.no_grad():
50
+ fea = wav2vec(input_values).last_hidden_state
51
+
52
+ return fea
hyvideo/__init__.py ADDED
File without changes
hyvideo/config.py ADDED
@@ -0,0 +1,534 @@
1
+ import argparse
2
+ from .constants import *
3
+ import re
4
+ from .modules.models import HUNYUAN_VIDEO_CONFIG
5
+
6
+
7
+ def parse_args(namespace=None):
8
+ parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
9
+
10
+ parser = add_network_args(parser)
11
+ parser = add_extra_models_args(parser)
12
+ parser = add_denoise_schedule_args(parser)
13
+ parser = add_inference_args(parser)
14
+ parser = add_parallel_args(parser)
15
+
16
+ args = parser.parse_args(namespace=namespace)
17
+ args = sanity_check_args(args)
18
+
19
+ return args
20
+
21
+
22
+ def add_network_args(parser: argparse.ArgumentParser):
23
+ group = parser.add_argument_group(title="HunyuanVideo network args")
24
+
25
+
26
+ group.add_argument(
27
+ "--quantize-transformer",
28
+ action="store_true",
29
+ help="On the fly 'transformer' quantization"
30
+ )
31
+
32
+
33
+ group.add_argument(
34
+ "--lora-dir-i2v",
35
+ type=str,
36
+ default="loras_i2v",
37
+ help="Path to a directory that contains Loras for i2v"
38
+ )
39
+
40
+
41
+ group.add_argument(
42
+ "--lora-dir",
43
+ type=str,
44
+ default="",
45
+ help="Path to a directory that contains Loras"
46
+ )
47
+
48
+
49
+ group.add_argument(
50
+ "--lora-preset",
51
+ type=str,
52
+ default="",
53
+ help="Lora preset to preload"
54
+ )
55
+
56
+ # group.add_argument(
57
+ # "--lora-preset-i2v",
58
+ # type=str,
59
+ # default="",
60
+ # help="Lora preset to preload for i2v"
61
+ # )
62
+
63
+ group.add_argument(
64
+ "--profile",
65
+ type=str,
66
+ default=-1,
67
+ help="Profile No"
68
+ )
69
+
70
+ group.add_argument(
71
+ "--verbose",
72
+ type=str,
73
+ default=1,
74
+ help="Verbose level"
75
+ )
76
+
77
+ group.add_argument(
78
+ "--server-port",
79
+ type=str,
80
+ default=0,
81
+ help="Server port"
82
+ )
83
+
84
+ group.add_argument(
85
+ "--server-name",
86
+ type=str,
87
+ default="",
88
+ help="Server name"
89
+ )
90
+
91
+ group.add_argument(
92
+ "--open-browser",
93
+ action="store_true",
94
+ help="open browser"
95
+ )
96
+
97
+ group.add_argument(
98
+ "--t2v",
99
+ action="store_true",
100
+ help="text to video mode"
101
+ )
102
+
103
+ group.add_argument(
104
+ "--i2v",
105
+ action="store_true",
106
+ help="image to video mode"
107
+ )
108
+
109
+ group.add_argument(
110
+ "--compile",
111
+ action="store_true",
112
+ help="Enable pytorch compilation"
113
+ )
114
+
115
+ group.add_argument(
116
+ "--fast",
117
+ action="store_true",
118
+ help="use Fast HunyuanVideo model"
119
+ )
120
+
121
+ group.add_argument(
122
+ "--fastest",
123
+ action="store_true",
124
+ help="activate the best config"
125
+ )
126
+
127
+ group.add_argument(
128
+ "--attention",
129
+ type=str,
130
+ default="",
131
+ help="attention mode"
132
+ )
133
+
134
+ group.add_argument(
135
+ "--vae-config",
136
+ type=str,
137
+ default="",
138
+ help="vae config mode"
139
+ )
140
+
141
+ parser.add_argument(
142
+ "--share",
143
+ action="store_true",
144
+ help="Create a shared URL to access webserver remotely"
145
+ )
146
+
147
+ parser.add_argument(
148
+ "--lock-config",
149
+ action="store_true",
150
+ help="Prevent modifying the configuration from the web interface"
151
+ )
152
+
153
+ parser.add_argument(
154
+ "--preload",
155
+ type=str,
156
+ default="0",
157
+ help="Megabytes of the diffusion model to preload in VRAM"
158
+ )
159
+
160
+ parser.add_argument(
161
+ "--multiple-images",
162
+ action="store_true",
163
+ help="Allow inputting multiple images with image to video"
164
+ )
165
+
166
+
167
+ # Main model
168
+ group.add_argument(
169
+ "--model",
170
+ type=str,
171
+ choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
172
+ default="HYVideo-T/2-cfgdistill",
173
+ )
174
+ group.add_argument(
175
+ "--latent-channels",
176
+ type=int,
177
+ default=16,
178
+ help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
179
+ "it still needs to match the latent channels of the VAE model.",
180
+ )
181
+ group.add_argument(
182
+ "--precision",
183
+ type=str,
184
+ default="bf16",
185
+ choices=PRECISIONS,
186
+ help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
187
+ )
188
+
189
+ # RoPE
190
+ group.add_argument(
191
+ "--rope-theta", type=int, default=256, help="Theta used in RoPE."
192
+ )
193
+ return parser
194
+
195
+
196
+ def add_extra_models_args(parser: argparse.ArgumentParser):
197
+ group = parser.add_argument_group(
198
+ title="Extra models args (including vae, text encoders and tokenizers)"
199
+ )
200
+
201
+ # - VAE
202
+ group.add_argument(
203
+ "--vae",
204
+ type=str,
205
+ default="884-16c-hy",
206
+ choices=list(VAE_PATH),
207
+ help="Name of the VAE model.",
208
+ )
209
+ group.add_argument(
210
+ "--vae-precision",
211
+ type=str,
212
+ default="fp16",
213
+ choices=PRECISIONS,
214
+ help="Precision mode for the VAE model.",
215
+ )
216
+ group.add_argument(
217
+ "--vae-tiling",
218
+ action="store_true",
219
+ help="Enable tiling for the VAE model to save GPU memory.",
220
+ )
221
+ group.set_defaults(vae_tiling=True)
222
+
223
+ group.add_argument(
224
+ "--text-encoder",
225
+ type=str,
226
+ default="llm",
227
+ choices=list(TEXT_ENCODER_PATH),
228
+ help="Name of the text encoder model.",
229
+ )
230
+ group.add_argument(
231
+ "--text-encoder-precision",
232
+ type=str,
233
+ default="fp16",
234
+ choices=PRECISIONS,
235
+ help="Precision mode for the text encoder model.",
236
+ )
237
+ group.add_argument(
238
+ "--text-states-dim",
239
+ type=int,
240
+ default=4096,
241
+ help="Dimension of the text encoder hidden states.",
242
+ )
243
+ group.add_argument(
244
+ "--text-len", type=int, default=256, help="Maximum length of the text input."
245
+ )
246
+ group.add_argument(
247
+ "--tokenizer",
248
+ type=str,
249
+ default="llm",
250
+ choices=list(TOKENIZER_PATH),
251
+ help="Name of the tokenizer model.",
252
+ )
253
+ group.add_argument(
254
+ "--prompt-template",
255
+ type=str,
256
+ default="dit-llm-encode",
257
+ choices=PROMPT_TEMPLATE,
258
+ help="Image prompt template for the decoder-only text encoder model.",
259
+ )
260
+ group.add_argument(
261
+ "--prompt-template-video",
262
+ type=str,
263
+ default="dit-llm-encode-video",
264
+ choices=PROMPT_TEMPLATE,
265
+ help="Video prompt template for the decoder-only text encoder model.",
266
+ )
267
+ group.add_argument(
268
+ "--hidden-state-skip-layer",
269
+ type=int,
270
+ default=2,
271
+ help="Skip layer for hidden states.",
272
+ )
273
+ group.add_argument(
274
+ "--apply-final-norm",
275
+ action="store_true",
276
+ help="Apply final normalization to the used text encoder hidden states.",
277
+ )
278
+
279
+ # - CLIP
280
+ group.add_argument(
281
+ "--text-encoder-2",
282
+ type=str,
283
+ default="clipL",
284
+ choices=list(TEXT_ENCODER_PATH),
285
+ help="Name of the second text encoder model.",
286
+ )
287
+ group.add_argument(
288
+ "--text-encoder-precision-2",
289
+ type=str,
290
+ default="fp16",
291
+ choices=PRECISIONS,
292
+ help="Precision mode for the second text encoder model.",
293
+ )
294
+ group.add_argument(
295
+ "--text-states-dim-2",
296
+ type=int,
297
+ default=768,
298
+ help="Dimension of the second text encoder hidden states.",
299
+ )
300
+ group.add_argument(
301
+ "--tokenizer-2",
302
+ type=str,
303
+ default="clipL",
304
+ choices=list(TOKENIZER_PATH),
305
+ help="Name of the second tokenizer model.",
306
+ )
307
+ group.add_argument(
308
+ "--text-len-2",
309
+ type=int,
310
+ default=77,
311
+ help="Maximum length of the second text input.",
312
+ )
313
+
314
+ return parser
315
+
316
+
317
+ def add_denoise_schedule_args(parser: argparse.ArgumentParser):
318
+ group = parser.add_argument_group(title="Denoise schedule args")
319
+
320
+ group.add_argument(
321
+ "--denoise-type",
322
+ type=str,
323
+ default="flow",
324
+ help="Denoise type for noised inputs.",
325
+ )
326
+
327
+ # Flow Matching
328
+ group.add_argument(
329
+ "--flow-shift",
330
+ type=float,
331
+ default=7.0,
332
+ help="Shift factor for flow matching schedulers.",
333
+ )
334
+ group.add_argument(
335
+ "--flow-reverse",
336
+ action="store_true",
337
+ help="If reverse, learning/sampling from t=1 -> t=0.",
338
+ )
339
+ group.add_argument(
340
+ "--flow-solver",
341
+ type=str,
342
+ default="euler",
343
+ help="Solver for flow matching.",
344
+ )
345
+ group.add_argument(
346
+ "--use-linear-quadratic-schedule",
347
+ action="store_true",
348
+ help="Use linear quadratic schedule for flow matching."
349
+ "Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
350
+ )
351
+ group.add_argument(
352
+ "--linear-schedule-end",
353
+ type=int,
354
+ default=25,
355
+ help="End step for linear quadratic schedule for flow matching.",
356
+ )
357
+
358
+ return parser
359
+
360
+
361
+ def add_inference_args(parser: argparse.ArgumentParser):
362
+ group = parser.add_argument_group(title="Inference args")
363
+
364
+ # ======================== Model loads ========================
365
+ group.add_argument(
366
+ "--model-base",
367
+ type=str,
368
+ default="ckpts",
369
+ help="Root path of all the models, including t2v models and extra models.",
370
+ )
371
+ group.add_argument(
372
+ "--dit-weight",
373
+ type=str,
374
+ default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
375
+ help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
376
+ "1. If it is a file, load the model directly."
377
+ "2. If it is a directory, search the model in the directory. Support two types of models: "
378
+ "1) named `pytorch_model_*.pt`"
379
+ "2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
380
+ )
381
+ group.add_argument(
382
+ "--model-resolution",
383
+ type=str,
384
+ default="540p",
385
+ choices=["540p", "720p"],
386
+ help="Resolution variant of the model to use.",
387
+ )
388
+ group.add_argument(
389
+ "--load-key",
390
+ type=str,
391
+ default="module",
392
+ help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
393
+ )
394
+ group.add_argument(
395
+ "--use-cpu-offload",
396
+ action="store_true",
397
+ help="Use CPU offload for the model load.",
398
+ )
399
+
400
+ # ======================== Inference general setting ========================
401
+ group.add_argument(
402
+ "--batch-size",
403
+ type=int,
404
+ default=1,
405
+ help="Batch size for inference and evaluation.",
406
+ )
407
+ group.add_argument(
408
+ "--infer-steps",
409
+ type=int,
410
+ default=50,
411
+ help="Number of denoising steps for inference.",
412
+ )
413
+ group.add_argument(
414
+ "--disable-autocast",
415
+ action="store_true",
416
+ help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
417
+ )
418
+ group.add_argument(
419
+ "--save-path",
420
+ type=str,
421
+ default="./results",
422
+ help="Path to save the generated samples.",
423
+ )
424
+ group.add_argument(
425
+ "--save-path-suffix",
426
+ type=str,
427
+ default="",
428
+ help="Suffix for the directory of saved samples.",
429
+ )
430
+ group.add_argument(
431
+ "--name-suffix",
432
+ type=str,
433
+ default="",
434
+ help="Suffix for the names of saved samples.",
435
+ )
436
+ group.add_argument(
437
+ "--num-videos",
438
+ type=int,
439
+ default=1,
440
+ help="Number of videos to generate for each prompt.",
441
+ )
442
+ # ---sample size---
443
+ group.add_argument(
444
+ "--video-size",
445
+ type=int,
446
+ nargs="+",
447
+ default=(720, 1280),
448
+ help="Video size for training. If a single value is provided, it will be used for both height "
449
+ "and width. If two values are provided, they will be used for height and width "
450
+ "respectively.",
451
+ )
452
+ group.add_argument(
453
+ "--video-length",
454
+ type=int,
455
+ default=129,
456
+ help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
457
+ )
458
+ # --- prompt ---
459
+ group.add_argument(
460
+ "--prompt",
461
+ type=str,
462
+ default=None,
463
+ help="Prompt for sampling during evaluation.",
464
+ )
465
+ group.add_argument(
466
+ "--seed-type",
467
+ type=str,
468
+ default="auto",
469
+ choices=["file", "random", "fixed", "auto"],
470
+ help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
471
+ "random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
472
+ "seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
473
+ "fixed `seed` value.",
474
+ )
475
+ group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
476
+
477
+ # Classifier-Free Guidance
478
+ group.add_argument(
479
+ "--neg-prompt", type=str, default=None, help="Negative prompt for sampling."
480
+ )
481
+ group.add_argument(
482
+ "--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale."
483
+ )
484
+ group.add_argument(
485
+ "--embedded-cfg-scale",
486
+ type=float,
487
+ default=6.0,
488
+ help="Embedded classifier-free guidance scale.",
489
+ )
490
+
491
+ group.add_argument(
492
+ "--reproduce",
493
+ action="store_true",
494
+ help="Enable reproducibility by setting random seeds and deterministic algorithms.",
495
+ )
496
+
497
+ return parser
498
+
499
+
500
+ def add_parallel_args(parser: argparse.ArgumentParser):
501
+ group = parser.add_argument_group(title="Parallel args")
502
+
503
+ # ======================== Model loads ========================
504
+ group.add_argument(
505
+ "--ulysses-degree",
506
+ type=int,
507
+ default=1,
508
+ help="Ulysses degree.",
509
+ )
510
+ group.add_argument(
511
+ "--ring-degree",
512
+ type=int,
513
+ default=1,
514
+ help="Ulysses degree.",
515
+ )
516
+
517
+ return parser
518
+
519
+
520
+ def sanity_check_args(args):
521
+ # VAE channels
522
+ vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
523
+ if not re.match(vae_pattern, args.vae):
524
+ raise ValueError(
525
+ f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'."
526
+ )
527
+ vae_channels = int(args.vae.split("-")[1][:-1])
528
+ if args.latent_channels is None:
529
+ args.latent_channels = vae_channels
530
+ if vae_channels != args.latent_channels:
531
+ raise ValueError(
532
+ f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels})."
533
+ )
534
+ return args
hyvideo/constants.py ADDED
@@ -0,0 +1,164 @@
1
+ import os
2
+ import torch
3
+
4
+ __all__ = [
5
+ "C_SCALE",
6
+ "PROMPT_TEMPLATE",
7
+ "MODEL_BASE",
8
+ "PRECISIONS",
9
+ "NORMALIZATION_TYPE",
10
+ "ACTIVATION_TYPE",
11
+ "VAE_PATH",
12
+ "TEXT_ENCODER_PATH",
13
+ "TOKENIZER_PATH",
14
+ "TEXT_PROJECTION",
15
+ "DATA_TYPE",
16
+ "NEGATIVE_PROMPT",
17
+ "NEGATIVE_PROMPT_I2V",
18
+ "FLOW_PATH_TYPE",
19
+ "FLOW_PREDICT_TYPE",
20
+ "FLOW_LOSS_WEIGHT",
21
+ "FLOW_SNR_TYPE",
22
+ "FLOW_SOLVER",
23
+ ]
24
+
25
+ PRECISION_TO_TYPE = {
26
+ 'fp32': torch.float32,
27
+ 'fp16': torch.float16,
28
+ 'bf16': torch.bfloat16,
29
+ }
30
+
31
+ # =================== Constant Values =====================
32
+ # Computation scale factor, 1P = 1_000_000_000_000_000. Tensorboard will display the value in PetaFLOPS to avoid
33
+ # overflow error when tensorboard logging values.
34
+ C_SCALE = 1_000_000_000_000_000
35
+
36
+ # When using decoder-only models, we must provide a prompt template to instruct the text encoder
37
+ # on how to generate the text.
38
+ # --------------------------------------------------------------------
39
+ PROMPT_TEMPLATE_ENCODE = (
40
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
41
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
42
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
43
+ )
44
+ PROMPT_TEMPLATE_ENCODE_VIDEO = (
45
+ "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
46
+ "1. The main content and theme of the video."
47
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
48
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
49
+ "4. background environment, light, style and atmosphere."
50
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
51
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
52
+ )
53
+
54
+ PROMPT_TEMPLATE_ENCODE_I2V = (
55
+ "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the image by detailing the color, shape, size, texture, "
56
+ "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
57
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
58
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
59
+ )
60
+
61
+ PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
62
+ "<|start_header_id|>system<|end_header_id|>\n\n<image>\nDescribe the video by detailing the following aspects according to the reference image: "
63
+ "1. The main content and theme of the video."
64
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
65
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
66
+ "4. background environment, light, style and atmosphere."
67
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
68
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
69
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
70
+ )
71
+
72
+ NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
73
+ NEGATIVE_PROMPT_I2V = "deformation, a poor composition and deformed video, bad teeth, bad eyes, bad limbs"
74
+
75
+ PROMPT_TEMPLATE = {
76
+ "dit-llm-encode": {
77
+ "template": PROMPT_TEMPLATE_ENCODE,
78
+ "crop_start": 36,
79
+ },
80
+ "dit-llm-encode-video": {
81
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
82
+ "crop_start": 95,
83
+ },
84
+ "dit-llm-encode-i2v": {
85
+ "template": PROMPT_TEMPLATE_ENCODE_I2V,
86
+ "crop_start": 36,
87
+ "image_emb_start": 5,
88
+ "image_emb_end": 581,
89
+ "image_emb_len": 576,
90
+ "double_return_token_id": 271
91
+ },
92
+ "dit-llm-encode-video-i2v": {
93
+ "template": PROMPT_TEMPLATE_ENCODE_VIDEO_I2V,
94
+ "crop_start": 103,
95
+ "image_emb_start": 5,
96
+ "image_emb_end": 581,
97
+ "image_emb_len": 576,
98
+ "double_return_token_id": 271
99
+ },
100
+ }
101
+
102
+ # ======================= Model ======================
103
+ PRECISIONS = {"fp32", "fp16", "bf16"}
104
+ NORMALIZATION_TYPE = {"layer", "rms"}
105
+ ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
106
+
107
+ # =================== Model Path =====================
108
+ MODEL_BASE = os.getenv("MODEL_BASE", "./ckpts")
109
+
110
+ # =================== Data =======================
111
+ DATA_TYPE = {"image", "video", "image_video"}
112
+
113
+ # 3D VAE
114
+ VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
115
+
116
+ # Text Encoder
117
+ TEXT_ENCODER_PATH = {
118
+ "clipL": f"{MODEL_BASE}/clip_vit_large_patch14",
119
+ "llm": f"{MODEL_BASE}/llava-llama-3-8b",
120
+ "llm-i2v": f"{MODEL_BASE}/llava-llama-3-8b",
121
+ }
122
+
123
+ # Tokenizer
124
+ TOKENIZER_PATH = {
125
+ "clipL": f"{MODEL_BASE}/clip_vit_large_patch14",
126
+ "llm": f"{MODEL_BASE}/llava-llama-3-8b",
127
+ "llm-i2v": f"{MODEL_BASE}/llava-llama-3-8b",
128
+ }
129
+
130
+ TEXT_PROJECTION = {
131
+ "linear", # Default, an nn.Linear() layer
132
+ "single_refiner", # Single TokenRefiner. Refer to LI-DiT
133
+ }
134
+
135
+ # Flow Matching path type
136
+ FLOW_PATH_TYPE = {
137
+ "linear", # Linear trajectory between noise and data
138
+ "gvp", # Generalized variance-preserving SDE
139
+ "vp", # Variance-preserving SDE
140
+ }
141
+
142
+ # Flow Matching predict type
143
+ FLOW_PREDICT_TYPE = {
144
+ "velocity", # Predict velocity
145
+ "score", # Predict score
146
+ "noise", # Predict noise
147
+ }
148
+
149
+ # Flow Matching loss weight
150
+ FLOW_LOSS_WEIGHT = {
151
+ "velocity", # Weight loss by velocity
152
+ "likelihood", # Weight loss by likelihood
153
+ }
154
+
155
+ # Flow Matching SNR type
156
+ FLOW_SNR_TYPE = {
157
+ "lognorm", # Log-normal SNR
158
+ "uniform", # Uniform SNR
159
+ }
160
+
161
+ # Flow Matching solvers
162
+ FLOW_SOLVER = {
163
+ "euler", # Euler solver
164
+ }
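As a rough sketch of how these templates and their `crop_start` offsets are typically consumed (assumed Hugging Face-style encoder outputs; the 256-token budget and 4096 hidden size below are illustrative, not this module's API): format the user prompt into the template, encode it, then drop the first `crop_start` hidden states so only the user-prompt tokens condition the video model.

import torch

entry = PROMPT_TEMPLATE["dit-llm-encode-video"]
full_text = entry["template"].format("A cat walks on the grass, realistic style.")
crop = entry["crop_start"]                              # 95 template/system tokens for this entry

hidden_states = torch.randn(1, crop + 256, 4096)        # stand-in for the LLM's last hidden states
attention_mask = torch.ones(1, crop + 256, dtype=torch.long)

prompt_embeds = hidden_states[:, crop:]                 # keep only the user-prompt positions
prompt_mask = attention_mask[:, crop:]
print(prompt_embeds.shape)                              # torch.Size([1, 256, 4096])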
hyvideo/data_kits/audio_dataset.py ADDED
@@ -0,0 +1,170 @@
1
+ import os
2
+ import cv2
3
+ import math
4
+ import json
5
+ import torch
6
+ import random
7
+ import librosa
8
+ import traceback
9
+ import torchvision
10
+ import numpy as np
11
+ import pandas as pd
12
+ from PIL import Image
13
+ from einops import rearrange
14
+ from torch.utils.data import Dataset
15
+ from decord import VideoReader, cpu
16
+ from transformers import CLIPImageProcessor
17
+ import torchvision.transforms as transforms
18
+ from torchvision.transforms import ToPILImage
19
+
20
+
21
+
22
+ def get_audio_feature(feature_extractor, audio_path):
23
+ audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
24
+ assert sampling_rate == 16000
25
+
26
+ audio_features = []
27
+ window = 750*640
28
+ for i in range(0, len(audio_input), window):
29
+ audio_feature = feature_extractor(audio_input[i:i+window],
30
+ sampling_rate=sampling_rate,
31
+ return_tensors="pt",
32
+ ).input_features
33
+ audio_features.append(audio_feature)
34
+
35
+ audio_features = torch.cat(audio_features, dim=-1)
36
+ return audio_features, len(audio_input) // 640
37
+
38
+
39
+ class VideoAudioTextLoaderVal(Dataset):
40
+ def __init__(
41
+ self,
42
+ image_size: int,
43
+ meta_file: str,
44
+ **kwargs,
45
+ ):
46
+ super().__init__()
47
+ self.meta_file = meta_file
48
+ self.image_size = image_size
49
+ self.text_encoder = kwargs.get("text_encoder", None) # llava_text_encoder
50
+ self.text_encoder_2 = kwargs.get("text_encoder_2", None) # clipL_text_encoder
51
+ self.feature_extractor = kwargs.get("feature_extractor", None)
52
+ self.meta_files = []
53
+
54
+ csv_data = pd.read_csv(meta_file)
55
+ for idx in range(len(csv_data)):
56
+ self.meta_files.append(
57
+ {
58
+ "videoid": str(csv_data["videoid"][idx]),
59
+ "image_path": str(csv_data["image"][idx]),
60
+ "audio_path": str(csv_data["audio"][idx]),
61
+ "prompt": str(csv_data["prompt"][idx]),
62
+ "fps": float(csv_data["fps"][idx])
63
+ }
64
+ )
65
+
66
+ self.llava_transform = transforms.Compose(
67
+ [
68
+ transforms.Resize((336, 336), interpolation=transforms.InterpolationMode.BILINEAR),
69
+ transforms.ToTensor(),
70
+ transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
71
+ ]
72
+ )
73
+ self.clip_image_processor = CLIPImageProcessor()
74
+
75
+ self.device = torch.device("cuda")
76
+ self.weight_dtype = torch.float16
77
+
78
+
79
+ def __len__(self):
80
+ return len(self.meta_files)
81
+
82
+ @staticmethod
83
+ def get_text_tokens(text_encoder, description, dtype_encode="video"):
84
+ text_inputs = text_encoder.text2tokens(description, data_type=dtype_encode)
85
+ text_ids = text_inputs["input_ids"].squeeze(0)
86
+ text_mask = text_inputs["attention_mask"].squeeze(0)
87
+ return text_ids, text_mask
88
+
89
+ def get_batch_data(self, idx):
90
+ meta_file = self.meta_files[idx]
91
+ videoid = meta_file["videoid"]
92
+ image_path = meta_file["image_path"]
93
+ audio_path = meta_file["audio_path"]
94
+ prompt = "Authentic, Realistic, Natural, High-quality, Lens-Fixed, " + meta_file["prompt"]
95
+ fps = meta_file["fps"]
96
+
97
+ img_size = self.image_size
98
+ ref_image = Image.open(image_path).convert('RGB')
99
+
100
+ # Resize reference image
101
+ w, h = ref_image.size
102
+ scale = img_size / min(w, h)
103
+ new_w = round(w * scale / 64) * 64
104
+ new_h = round(h * scale / 64) * 64
105
+
106
+ if img_size == 704:
107
+ img_size_long = 1216
108
+ if new_w * new_h > img_size * img_size_long:
109
+ import math
110
+ scale = math.sqrt(img_size * img_size_long / w / h)
111
+ new_w = round(w * scale / 64) * 64
112
+ new_h = round(h * scale / 64) * 64
113
+
114
+ ref_image = ref_image.resize((new_w, new_h), Image.LANCZOS)
115
+
116
+ ref_image = np.array(ref_image)
117
+ ref_image = torch.from_numpy(ref_image)
118
+
119
+ audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_path)
120
+ audio_prompts = audio_input[0]
121
+
122
+ motion_bucket_id_heads = np.array([25] * 4)
123
+ motion_bucket_id_exps = np.array([30] * 4)
124
+ motion_bucket_id_heads = torch.from_numpy(motion_bucket_id_heads)
125
+ motion_bucket_id_exps = torch.from_numpy(motion_bucket_id_exps)
126
+ fps = torch.from_numpy(np.array(fps))
127
+
128
+ to_pil = ToPILImage()
129
+ pixel_value_ref = rearrange(ref_image.clone().unsqueeze(0), "b h w c -> b c h w") # (b c h w)
130
+
131
+ pixel_value_ref_llava = [self.llava_transform(to_pil(image)) for image in pixel_value_ref]
132
+ pixel_value_ref_llava = torch.stack(pixel_value_ref_llava, dim=0)
133
+ pixel_value_ref_clip = self.clip_image_processor(
134
+ images=Image.fromarray((pixel_value_ref[0].permute(1,2,0)).data.cpu().numpy().astype(np.uint8)),
135
+ return_tensors="pt"
136
+ ).pixel_values[0]
137
+ pixel_value_ref_clip = pixel_value_ref_clip.unsqueeze(0)
138
+
139
+ # Encode text prompts
140
+
141
+ text_ids, text_mask = self.get_text_tokens(self.text_encoder, prompt)
142
+ text_ids_2, text_mask_2 = self.get_text_tokens(self.text_encoder_2, prompt)
143
+
144
+ # Output batch
145
+ batch = {
146
+ "text_prompt": prompt, #
147
+ "videoid": videoid,
148
+ "pixel_value_ref": pixel_value_ref.to(dtype=torch.float16), # 参考图,用于vae提特征 (1, 3, h, w), 取值范围(0, 255)
149
+ "pixel_value_ref_llava": pixel_value_ref_llava.to(dtype=torch.float16), # 参考图,用于llava提特征 (1, 3, 336, 336), 取值范围 = CLIP取值范围
150
+ "pixel_value_ref_clip": pixel_value_ref_clip.to(dtype=torch.float16), # 参考图,用于clip_image_encoder提特征 (1, 3, 244, 244), 取值范围 = CLIP取值范围
151
+ "audio_prompts": audio_prompts.to(dtype=torch.float16),
152
+ "motion_bucket_id_heads": motion_bucket_id_heads.to(dtype=text_ids.dtype),
153
+ "motion_bucket_id_exps": motion_bucket_id_exps.to(dtype=text_ids.dtype),
154
+ "fps": fps.to(dtype=torch.float16),
155
+ "text_ids": text_ids.clone(), # 对应llava_text_encoder
156
+ "text_mask": text_mask.clone(), # 对应llava_text_encoder
157
+ "text_ids_2": text_ids_2.clone(), # 对应clip_text_encoder
158
+ "text_mask_2": text_mask_2.clone(), # 对应clip_text_encoder
159
+ "audio_len": audio_len,
160
+ "image_path": image_path,
161
+ "audio_path": audio_path,
162
+ }
163
+ return batch
164
+
165
+ def __getitem__(self, idx):
166
+ return self.get_batch_data(idx)
167
+
168
+
169
+
170
+
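A quick sketch of the framing arithmetic behind `get_audio_feature` above (constants taken from the code; the clip length is a made-up example): at 16 kHz, 640 samples is one 40 ms video frame at 25 fps, and 750 of those frames make one 30 s feature-extractor window.

sr = 16000                                  # enforced by librosa.load(..., sr=16000)
samples_per_frame = 640                     # 640 / 16000 = 0.04 s  -> 25 fps
window = 750 * samples_per_frame            # 480000 samples = 30 s per chunk fed to the extractor

num_samples = int(12.8 * sr)                # hypothetical 12.8 s clip
audio_len = num_samples // samples_per_frame
print(window / sr, audio_len)               # 30.0, 320 -> 320 video frames, matching the function's return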
hyvideo/data_kits/audio_preprocessor.py ADDED
@@ -0,0 +1,72 @@
1
+
2
+ import os
3
+ import cv2
4
+ import json
5
+ import time
6
+ import decord
7
+ import einops
8
+ import librosa
9
+ import torch
10
+ import random
11
+ import argparse
12
+ import traceback
13
+ import numpy as np
14
+ from tqdm import tqdm
15
+ from PIL import Image
16
+ from einops import rearrange
17
+
18
+
19
+
20
+ def get_facemask(ref_image, align_instance, area=1.25):
21
+ # ref_image: (b f c h w)
22
+ bsz, f, c, h, w = ref_image.shape
23
+ images = rearrange(ref_image, "b f c h w -> (b f) h w c").data.cpu().numpy().astype(np.uint8)
24
+ face_masks = []
25
+ for image in images:
26
+ image_pil = Image.fromarray(image).convert("RGB")
27
+ _, _, bboxes_list = align_instance(np.array(image_pil)[:,:,[2,1,0]], maxface=True)
28
+ try:
29
+ bboxSrc = bboxes_list[0]
30
+ except:
31
+ bboxSrc = [0, 0, w, h]
32
+ x1, y1, ww, hh = bboxSrc
33
+ x2, y2 = x1 + ww, y1 + hh
34
+ ww, hh = (x2-x1) * area, (y2-y1) * area
35
+ center = [(x2+x1)//2, (y2+y1)//2]
36
+ x1 = max(center[0] - ww//2, 0)
37
+ y1 = max(center[1] - hh//2, 0)
38
+ x2 = min(center[0] + ww//2, w)
39
+ y2 = min(center[1] + hh//2, h)
40
+
41
+ face_mask = np.zeros_like(np.array(image_pil))
42
+ face_mask[int(y1):int(y2), int(x1):int(x2)] = 1.0
43
+ face_masks.append(torch.from_numpy(face_mask[...,:1]))
44
+ face_masks = torch.stack(face_masks, dim=0) # (b*f, h, w, c)
45
+ face_masks = rearrange(face_masks, "(b f) h w c -> b c f h w", b=bsz, f=f)
46
+ face_masks = face_masks.to(device=ref_image.device, dtype=ref_image.dtype)
47
+ return face_masks
48
+
49
+
50
+ def encode_audio(wav2vec, audio_feats, fps, num_frames=129):
51
+ if fps == 25:
52
+ start_ts = [0]
53
+ step_ts = [1]
54
+ elif fps == 12.5:
55
+ start_ts = [0]
56
+ step_ts = [2]
57
+ num_frames = min(num_frames, 400)
58
+ audio_feats = wav2vec.encoder(audio_feats.unsqueeze(0)[:, :, :3000], output_hidden_states=True).hidden_states
59
+ audio_feats = torch.stack(audio_feats, dim=2)
60
+ audio_feats = torch.cat([torch.zeros_like(audio_feats[:,:4]), audio_feats], 1)
61
+
62
+ audio_prompts = []
63
+ for bb in range(1):
64
+ audio_feats_list = []
65
+ for f in range(num_frames):
66
+ cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
67
+ audio_clip = audio_feats[bb:bb+1, cur_t: cur_t+10]
68
+ audio_feats_list.append(audio_clip)
69
+ audio_feats_list = torch.stack(audio_feats_list, 1)
70
+ audio_prompts.append(audio_feats_list)
71
+ audio_prompts = torch.cat(audio_prompts)
72
+ return audio_prompts
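A small sketch of the per-frame indexing done by `encode_audio` above, using dummy tensors (the layer count and width are assumptions, not the real encoder's): in the 25 fps branch, video frame `f` reads hidden-state steps `[2f, 2f + 10)`, a short rolling context window per frame, after four zero steps are prepended.

import torch

layers, dim = 5, 384                                                 # hypothetical encoder depth / width
steps = torch.randn(1, 1500, layers, dim)                            # stand-in for the stacked hidden states
steps = torch.cat([torch.zeros_like(steps[:, :4]), steps], dim=1)    # 4-step zero pad, as in the code

num_frames, start, step = 129, 0, 1                                  # fps == 25 branch
clips = torch.stack(
    [steps[:, (start + f * step) * 2 : (start + f * step) * 2 + 10] for f in range(num_frames)],
    dim=1,
)
print(clips.shape)                                                   # torch.Size([1, 129, 10, 5, 384])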
hyvideo/data_kits/data_tools.py ADDED
@@ -0,0 +1,41 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import imageio
6
+ import torchvision
7
+ from einops import rearrange
8
+
9
+
10
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8, quality=8):
11
+ videos = rearrange(videos, "b c t h w -> t b c h w")
12
+ outputs = []
13
+ for x in videos:
14
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
15
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
16
+ if rescale:
17
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
18
+ x = torch.clamp(x,0,1)
19
+ x = (x * 255).numpy().astype(np.uint8)
20
+ outputs.append(x)
21
+
22
+ os.makedirs(os.path.dirname(path), exist_ok=True)
23
+ imageio.mimsave(path, outputs, fps=fps, quality=quality)
24
+
25
+ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
26
+ crop_h, crop_w = crop_img.shape[:2]
27
+ target_w, target_h = size
28
+ scale_h, scale_w = target_h / crop_h, target_w / crop_w
29
+ if scale_w > scale_h:
30
+ resize_h = int(target_h*resize_ratio)
31
+ resize_w = int(crop_w / crop_h * resize_h)
32
+ else:
33
+ resize_w = int(target_w*resize_ratio)
34
+ resize_h = int(crop_h / crop_w * resize_w)
35
+ crop_img = cv2.resize(crop_img, (resize_w, resize_h))
36
+ pad_left = (target_w - resize_w) // 2
37
+ pad_top = (target_h - resize_h) // 2
38
+ pad_right = target_w - resize_w - pad_left
39
+ pad_bottom = target_h - resize_h - pad_top
40
+ crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
41
+ return crop_img
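A quick usage sketch for `pad_image` above (dummy input; the 704x704 target is just an example): it resizes the crop to fit inside the target canvas while keeping aspect ratio, then letterboxes the remainder with the border color.

import numpy as np

crop = np.zeros((480, 640, 3), dtype=np.uint8)     # h=480, w=640 dummy crop
canvas = pad_image(crop, size=(704, 704))          # size is (target_w, target_h)
print(canvas.shape)                                # (704, 704, 3): resized to 704x528, padded 88 px top and bottom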
hyvideo/data_kits/face_align/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .align import AlignImage
hyvideo/data_kits/face_align/align.py ADDED
@@ -0,0 +1,34 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ from .detface import DetFace
5
+
6
+ class AlignImage(object):
7
+ def __init__(self, device='cuda', det_path=''):
8
+ self.facedet = DetFace(pt_path=det_path, confThreshold=0.5, nmsThreshold=0.45, device=device)
9
+
10
+ @torch.no_grad()
11
+ def __call__(self, im, maxface=False):
12
+ bboxes, kpss, scores = self.facedet.detect(im)
13
+ face_num = bboxes.shape[0]
14
+
15
+ five_pts_list = []
16
+ scores_list = []
17
+ bboxes_list = []
18
+ for i in range(face_num):
19
+ five_pts_list.append(kpss[i].reshape(5,2))
20
+ scores_list.append(scores[i])
21
+ bboxes_list.append(bboxes[i])
22
+
23
+ if maxface and face_num>1:
24
+ max_idx = 0
25
+ max_area = (bboxes[0, 2])*(bboxes[0, 3])
26
+ for i in range(1, face_num):
27
+ area = (bboxes[i,2])*(bboxes[i,3])
28
+ if area>max_area:
29
+ max_idx = i
30
+ five_pts_list = [five_pts_list[max_idx]]
31
+ scores_list = [scores_list[max_idx]]
32
+ bboxes_list = [bboxes_list[max_idx]]
33
+
34
+ return five_pts_list, scores_list, bboxes_list
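A hedged usage sketch for `AlignImage` (the checkpoint path and image file are hypothetical; the detector is fed a BGR image, matching the channel flip used by `get_facemask` in audio_preprocessor.py):

import cv2

align = AlignImage(device="cuda", det_path="./ckpts/det_align/detface.pt")   # hypothetical weights path
img_bgr = cv2.imread("face.jpg")                                             # BGR frame
five_pts, scores, bboxes = align(img_bgr, maxface=True)                      # keep only the largest face
if bboxes:
    x, y, w, h = bboxes[0]                                                   # boxes are returned as (x, y, w, h)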
hyvideo/data_kits/face_align/detface.py ADDED
@@ -0,0 +1,283 @@
1
+ # -*- coding: UTF-8 -*-
2
+ import os
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ import torchvision
7
+
8
+
9
+ def xyxy2xywh(x):
10
+ # Convert nx4 boxes from [x1, y1, x2, y2] to [x, y, w, h] where xy1=top-left, xy2=bottom-right
11
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
12
+ y[:, 0] = (x[:, 0] + x[:, 2]) / 2 # x center
13
+ y[:, 1] = (x[:, 1] + x[:, 3]) / 2 # y center
14
+ y[:, 2] = x[:, 2] - x[:, 0] # width
15
+ y[:, 3] = x[:, 3] - x[:, 1] # height
16
+ return y
17
+
18
+
19
+ def xywh2xyxy(x):
20
+ # Convert nx4 boxes from [x, y, w, h] to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
21
+ y = x.clone() if isinstance(x, torch.Tensor) else np.copy(x)
22
+ y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
23
+ y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
24
+ y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
25
+ y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
26
+ return y
27
+
28
+
29
+ def box_iou(box1, box2):
30
+ # https://github.com/pytorch/vision/blob/master/torchvision/ops/boxes.py
31
+ """
32
+ Return intersection-over-union (Jaccard index) of boxes.
33
+ Both sets of boxes are expected to be in (x1, y1, x2, y2) format.
34
+ Arguments:
35
+ box1 (Tensor[N, 4])
36
+ box2 (Tensor[M, 4])
37
+ Returns:
38
+ iou (Tensor[N, M]): the NxM matrix containing the pairwise
39
+ IoU values for every element in boxes1 and boxes2
40
+ """
41
+
42
+ def box_area(box):
43
+ # box = 4xn
44
+ return (box[2] - box[0]) * (box[3] - box[1])
45
+
46
+ area1 = box_area(box1.T)
47
+ area2 = box_area(box2.T)
48
+
49
+ # inter(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
50
+ inter = (torch.min(box1[:, None, 2:], box2[:, 2:]) -
51
+ torch.max(box1[:, None, :2], box2[:, :2])).clamp(0).prod(2)
52
+ # iou = inter / (area1 + area2 - inter)
53
+ return inter / (area1[:, None] + area2 - inter)
54
+
55
+
56
+ def scale_coords(img1_shape, coords, img0_shape, ratio_pad=None):
57
+ # Rescale coords (xyxy) from img1_shape to img0_shape
58
+ if ratio_pad is None: # calculate from img0_shape
59
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
60
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
61
+ else:
62
+ gain = ratio_pad[0][0]
63
+ pad = ratio_pad[1]
64
+
65
+ coords[:, [0, 2]] -= pad[0] # x padding
66
+ coords[:, [1, 3]] -= pad[1] # y padding
67
+ coords[:, :4] /= gain
68
+ clip_coords(coords, img0_shape)
69
+ return coords
70
+
71
+
72
+ def clip_coords(boxes, img_shape):
73
+ # Clip bounding xyxy bounding boxes to image shape (height, width)
74
+ boxes[:, 0].clamp_(0, img_shape[1]) # x1
75
+ boxes[:, 1].clamp_(0, img_shape[0]) # y1
76
+ boxes[:, 2].clamp_(0, img_shape[1]) # x2
77
+ boxes[:, 3].clamp_(0, img_shape[0]) # y2
78
+
79
+
80
+ def scale_coords_landmarks(img1_shape, coords, img0_shape, ratio_pad=None):
81
+ # Rescale coords (xyxy) from img1_shape to img0_shape
82
+ if ratio_pad is None: # calculate from img0_shape
83
+ gain = min(img1_shape[0] / img0_shape[0], img1_shape[1] / img0_shape[1]) # gain = old / new
84
+ pad = (img1_shape[1] - img0_shape[1] * gain) / 2, (img1_shape[0] - img0_shape[0] * gain) / 2 # wh padding
85
+ else:
86
+ gain = ratio_pad[0][0]
87
+ pad = ratio_pad[1]
88
+
89
+ coords[:, [0, 2, 4, 6, 8]] -= pad[0] # x padding
90
+ coords[:, [1, 3, 5, 7, 9]] -= pad[1] # y padding
91
+ coords[:, :10] /= gain
92
+ #clip_coords(coords, img0_shape)
93
+ coords[:, 0].clamp_(0, img0_shape[1]) # x1
94
+ coords[:, 1].clamp_(0, img0_shape[0]) # y1
95
+ coords[:, 2].clamp_(0, img0_shape[1]) # x2
96
+ coords[:, 3].clamp_(0, img0_shape[0]) # y2
97
+ coords[:, 4].clamp_(0, img0_shape[1]) # x3
98
+ coords[:, 5].clamp_(0, img0_shape[0]) # y3
99
+ coords[:, 6].clamp_(0, img0_shape[1]) # x4
100
+ coords[:, 7].clamp_(0, img0_shape[0]) # y4
101
+ coords[:, 8].clamp_(0, img0_shape[1]) # x5
102
+ coords[:, 9].clamp_(0, img0_shape[0]) # y5
103
+ return coords
104
+
105
+
106
+ def show_results(img, xywh, conf, landmarks, class_num):
107
+ h,w,c = img.shape
108
+ tl = 1 or round(0.002 * (h + w) / 2) + 1 # line/font thickness
109
+ x1 = int(xywh[0] * w - 0.5 * xywh[2] * w)
110
+ y1 = int(xywh[1] * h - 0.5 * xywh[3] * h)
111
+ x2 = int(xywh[0] * w + 0.5 * xywh[2] * w)
112
+ y2 = int(xywh[1] * h + 0.5 * xywh[3] * h)
113
+ cv2.rectangle(img, (x1,y1), (x2, y2), (0,255,0), thickness=tl, lineType=cv2.LINE_AA)
114
+
115
+ clors = [(255,0,0),(0,255,0),(0,0,255),(255,255,0),(0,255,255)]
116
+
117
+ for i in range(5):
118
+ point_x = int(landmarks[2 * i] * w)
119
+ point_y = int(landmarks[2 * i + 1] * h)
120
+ cv2.circle(img, (point_x, point_y), tl+1, clors[i], -1)
121
+
122
+ tf = max(tl - 1, 1) # font thickness
123
+ label = str(conf)[:5]
124
+ cv2.putText(img, label, (x1, y1 - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
125
+ return img
126
+
127
+
128
+ def make_divisible(x, divisor):
129
+ # Returns x evenly divisible by divisor
130
+ return (x // divisor) * divisor
131
+
132
+
133
+ def non_max_suppression_face(prediction, conf_thres=0.5, iou_thres=0.45, classes=None, agnostic=False, labels=()):
134
+ """Performs Non-Maximum Suppression (NMS) on inference results
135
+ Returns:
136
+ detections with shape: nx6 (x1, y1, x2, y2, conf, cls)
137
+ """
138
+
139
+ nc = prediction.shape[2] - 15 # number of classes
140
+ xc = prediction[..., 4] > conf_thres # candidates
141
+
142
+ # Settings
143
+ min_wh, max_wh = 2, 4096 # (pixels) minimum and maximum box width and height
144
+ # time_limit = 10.0 # seconds to quit after
145
+ redundant = True # require redundant detections
146
+ multi_label = nc > 1 # multiple labels per box (adds 0.5ms/img)
147
+ merge = False # use merge-NMS
148
+
149
+ # t = time.time()
150
+ output = [torch.zeros((0, 16), device=prediction.device)] * prediction.shape[0]
151
+ for xi, x in enumerate(prediction): # image index, image inference
152
+ # Apply constraints
153
+ # x[((x[..., 2:4] < min_wh) | (x[..., 2:4] > max_wh)).any(1), 4] = 0 # width-height
154
+ x = x[xc[xi]] # confidence
155
+
156
+ # Cat apriori labels if autolabelling
157
+ if labels and len(labels[xi]):
158
+ l = labels[xi]
159
+ v = torch.zeros((len(l), nc + 15), device=x.device)
160
+ v[:, :4] = l[:, 1:5] # box
161
+ v[:, 4] = 1.0 # conf
162
+ v[range(len(l)), l[:, 0].long() + 15] = 1.0 # cls
163
+ x = torch.cat((x, v), 0)
164
+
165
+ # If none remain process next image
166
+ if not x.shape[0]:
167
+ continue
168
+
169
+ # Compute conf
170
+ x[:, 15:] *= x[:, 4:5] # conf = obj_conf * cls_conf
171
+
172
+ # Box (center x, center y, width, height) to (x1, y1, x2, y2)
173
+ box = xywh2xyxy(x[:, :4])
174
+
175
+ # Detections matrix nx6 (xyxy, conf, landmarks, cls)
176
+ if multi_label:
177
+ i, j = (x[:, 15:] > conf_thres).nonzero(as_tuple=False).T
178
+ x = torch.cat((box[i], x[i, j + 15, None], x[i, 5:15] ,j[:, None].float()), 1)
179
+ else: # best class only
180
+ conf, j = x[:, 15:].max(1, keepdim=True)
181
+ x = torch.cat((box, conf, x[:, 5:15], j.float()), 1)[conf.view(-1) > conf_thres]
182
+
183
+ # Filter by class
184
+ if classes is not None:
185
+ x = x[(x[:, 5:6] == torch.tensor(classes, device=x.device)).any(1)]
186
+
187
+ # If none remain process next image
188
+ n = x.shape[0] # number of boxes
189
+ if not n:
190
+ continue
191
+
192
+ # Batched NMS
193
+ c = x[:, 15:16] * (0 if agnostic else max_wh) # classes
194
+ boxes, scores = x[:, :4] + c, x[:, 4] # boxes (offset by class), scores
195
+ i = torchvision.ops.nms(boxes, scores, iou_thres) # NMS
196
+ #if i.shape[0] > max_det: # limit detections
197
+ # i = i[:max_det]
198
+ if merge and (1 < n < 3E3): # Merge NMS (boxes merged using weighted mean)
199
+ # update boxes as boxes(i,4) = weights(i,n) * boxes(n,4)
200
+ iou = box_iou(boxes[i], boxes) > iou_thres # iou matrix
201
+ weights = iou * scores[None] # box weights
202
+ x[i, :4] = torch.mm(weights, x[:, :4]).float() / weights.sum(1, keepdim=True) # merged boxes
203
+ if redundant:
204
+ i = i[iou.sum(1) > 1] # require redundancy
205
+
206
+ output[xi] = x[i]
207
+ # if (time.time() - t) > time_limit:
208
+ # break # time limit exceeded
209
+
210
+ return output
211
+
212
+
213
+ class DetFace():
214
+ def __init__(self, pt_path, confThreshold=0.5, nmsThreshold=0.45, device='cuda'):
215
+ assert os.path.exists(pt_path)
216
+
217
+ self.inpSize = 416
218
+ self.conf_thres = confThreshold
219
+ self.iou_thres = nmsThreshold
220
+ self.test_device = torch.device(device if torch.cuda.is_available() else "cpu")
221
+ self.model = torch.jit.load(pt_path).to(self.test_device)
222
+ self.last_w = 416
223
+ self.last_h = 416
224
+ self.grids = None
225
+
226
+ @torch.no_grad()
227
+ def detect(self, srcimg):
228
+ # t0=time.time()
229
+
230
+ h0, w0 = srcimg.shape[:2] # orig hw
231
+ r = self.inpSize / min(h0, w0) # resize image to img_size
232
+ h1 = int(h0*r+31)//32*32
233
+ w1 = int(w0*r+31)//32*32
234
+
235
+ img = cv2.resize(srcimg, (w1,h1), interpolation=cv2.INTER_LINEAR)
236
+
237
+ # Convert
238
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # BGR to RGB
239
+
240
+ # Run inference
241
+ img = torch.from_numpy(img).to(self.test_device).permute(2,0,1)
242
+ img = img.float()/255 # uint8 to fp16/32 0-1
243
+ if img.ndimension() == 3:
244
+ img = img.unsqueeze(0)
245
+
246
+ # Inference
247
+ if h1 != self.last_h or w1 != self.last_w or self.grids is None:
248
+ grids = []
249
+ for scale in [8,16,32]:
250
+ ny = h1//scale
251
+ nx = w1//scale
252
+ yv, xv = torch.meshgrid([torch.arange(ny), torch.arange(nx)])
253
+ grid = torch.stack((xv, yv), 2).view((1,1,ny, nx, 2)).float()
254
+ grids.append(grid.to(self.test_device))
255
+ self.grids = grids
256
+ self.last_w = w1
257
+ self.last_h = h1
258
+
259
+ pred = self.model(img, self.grids).cpu()
260
+
261
+ # Apply NMS
262
+ det = non_max_suppression_face(pred, self.conf_thres, self.iou_thres)[0]
263
+ # Process detections
264
+ # det = pred[0]
265
+ bboxes = np.zeros((det.shape[0], 4))
266
+ kpss = np.zeros((det.shape[0], 5, 2))
267
+ scores = np.zeros((det.shape[0]))
268
+ # gn = torch.tensor([w0, h0, w0, h0]).to(pred) # normalization gain whwh
269
+ # gn_lks = torch.tensor([w0, h0, w0, h0, w0, h0, w0, h0, w0, h0]).to(pred) # normalization gain landmarks
270
+ det = det.cpu().numpy()
271
+
272
+ for j in range(det.shape[0]):
273
+ # xywh = (xyxy2xywh(det[j, :4].view(1, 4)) / gn).view(4).cpu().numpy()
274
+ bboxes[j, 0] = det[j, 0] * w0/w1
275
+ bboxes[j, 1] = det[j, 1] * h0/h1
276
+ bboxes[j, 2] = det[j, 2] * w0/w1 - bboxes[j, 0]
277
+ bboxes[j, 3] = det[j, 3] * h0/h1 - bboxes[j, 1]
278
+ scores[j] = det[j, 4]
279
+ # landmarks = (det[j, 5:15].view(1, 10) / gn_lks).view(5,2).cpu().numpy()
280
+ kpss[j, :, :] = det[j, 5:15].reshape(5, 2) * np.array([[w0/w1,h0/h1]])
281
+ # class_num = det[j, 15].cpu().numpy()
282
+ # orgimg = show_results(orgimg, xywh, conf, landmarks, class_num)
283
+ return bboxes, kpss, scores
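For orientation, the detection rows that `non_max_suppression_face` filters above follow a 16-column layout (read off the indexing in the code, not a documented interface):

# Per-detection row, as sliced in non_max_suppression_face:
#   [0:4]   x1, y1, x2, y2   box corners in the resized image
#   [4]     confidence        obj_conf * cls_conf
#   [5:15]  x, y pairs for the 5 facial landmarks
#   [15]    class index
det_row_width = 4 + 1 + 10 + 1   # = 16, matching torch.zeros((0, 16), ...) in the function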
hyvideo/diffusion/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pipelines import HunyuanVideoPipeline
2
+ from .schedulers import FlowMatchDiscreteScheduler
hyvideo/diffusion/pipelines/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pipeline_hunyuan_video import HunyuanVideoPipeline
2
+ from .pipeline_hunyuan_video_audio import HunyuanVideoAudioPipeline
hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1421 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.utils import BaseOutput
31
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
32
+ from diffusers.models import AutoencoderKL
33
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
34
+ from diffusers.schedulers import KarrasDiffusionSchedulers
35
+ from diffusers.utils import (
36
+ USE_PEFT_BACKEND,
37
+ deprecate,
38
+ logging,
39
+ replace_example_docstring,
40
+ scale_lora_layers,
41
+ unscale_lora_layers,
42
+ )
43
+ from diffusers.utils.torch_utils import randn_tensor
44
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
45
+ from diffusers.utils import BaseOutput
46
+
47
+ from ...constants import PRECISION_TO_TYPE
48
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
49
+ from ...text_encoder import TextEncoder
50
+ from ...modules import HYVideoDiffusionTransformer
51
+ from mmgp import offload
52
+ from ...utils.data_utils import black_image
53
+ from einops import rearrange
54
+
55
+ EXAMPLE_DOC_STRING = """"""
56
+
57
+
58
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
59
+ """
60
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
61
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
62
+ """
63
+ std_text = noise_pred_text.std(
64
+ dim=list(range(1, noise_pred_text.ndim)), keepdim=True
65
+ )
66
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
67
+ # rescale the results from guidance (fixes overexposure)
68
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
69
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
70
+ noise_cfg = (
71
+ guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
72
+ )
73
+ return noise_cfg
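A minimal numeric sketch of `rescale_noise_cfg` above (toy tensors): with `guidance_rescale=1.0` the CFG output's per-sample std is matched to the text branch, while `0.0` leaves it untouched.

import torch

noise_pred_text = torch.randn(2, 4, 8, 8)
noise_cfg = 3.0 * noise_pred_text                          # exaggerated CFG output: std is 3x too large
fixed = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0)
print(noise_cfg.std().item(), fixed.std().item())          # roughly 3.0 vs. 1.0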
74
+
75
+
76
+ def retrieve_timesteps(
77
+ scheduler,
78
+ num_inference_steps: Optional[int] = None,
79
+ device: Optional[Union[str, torch.device]] = None,
80
+ timesteps: Optional[List[int]] = None,
81
+ sigmas: Optional[List[float]] = None,
82
+ **kwargs,
83
+ ):
84
+ """
85
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
86
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
87
+
88
+ Args:
89
+ scheduler (`SchedulerMixin`):
90
+ The scheduler to get timesteps from.
91
+ num_inference_steps (`int`):
92
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
93
+ must be `None`.
94
+ device (`str` or `torch.device`, *optional*):
95
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
96
+ timesteps (`List[int]`, *optional*):
97
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
98
+ `num_inference_steps` and `sigmas` must be `None`.
99
+ sigmas (`List[float]`, *optional*):
100
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
101
+ `num_inference_steps` and `timesteps` must be `None`.
102
+
103
+ Returns:
104
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
105
+ second element is the number of inference steps.
106
+ """
107
+ if timesteps is not None and sigmas is not None:
108
+ raise ValueError(
109
+ "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
110
+ )
111
+ if timesteps is not None:
112
+ accepts_timesteps = "timesteps" in set(
113
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
114
+ )
115
+ if not accepts_timesteps:
116
+ raise ValueError(
117
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
118
+ f" timestep schedules. Please check whether you are using the correct scheduler."
119
+ )
120
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
121
+ timesteps = scheduler.timesteps
122
+ num_inference_steps = len(timesteps)
123
+ elif sigmas is not None:
124
+ accept_sigmas = "sigmas" in set(
125
+ inspect.signature(scheduler.set_timesteps).parameters.keys()
126
+ )
127
+ if not accept_sigmas:
128
+ raise ValueError(
129
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
130
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
131
+ )
132
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
133
+ timesteps = scheduler.timesteps
134
+ num_inference_steps = len(timesteps)
135
+ else:
136
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
137
+ timesteps = scheduler.timesteps
138
+ return timesteps, num_inference_steps
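A hedged usage sketch for `retrieve_timesteps` (a stock diffusers scheduler stands in for the pipeline's own; only one of `timesteps`/`sigmas` may be passed, as enforced above):

from diffusers import EulerDiscreteScheduler

sched = EulerDiscreteScheduler(num_train_timesteps=1000)
timesteps, n = retrieve_timesteps(sched, num_inference_steps=30, device="cpu")
print(n, timesteps[:3])                                    # 30 steps; custom timesteps/sigmas take the other branches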
139
+
140
+
141
+ @dataclass
142
+ class HunyuanVideoPipelineOutput(BaseOutput):
143
+ videos: Union[torch.Tensor, np.ndarray]
144
+
145
+
146
+ class HunyuanVideoPipeline(DiffusionPipeline):
147
+ r"""
148
+ Pipeline for text-to-video generation using HunyuanVideo.
149
+
150
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
151
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
152
+
153
+ Args:
154
+ vae ([`AutoencoderKL`]):
155
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
156
+ text_encoder ([`TextEncoder`]):
157
+ Frozen text-encoder.
158
+ text_encoder_2 ([`TextEncoder`]):
159
+ Frozen text-encoder_2.
160
+ transformer ([`HYVideoDiffusionTransformer`]):
161
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
162
+ scheduler ([`SchedulerMixin`]):
163
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
164
+ """
165
+
166
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
167
+ _optional_components = ["text_encoder_2"]
168
+ _exclude_from_cpu_offload = ["transformer"]
169
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
170
+
171
+ def __init__(
172
+ self,
173
+ vae: AutoencoderKL,
174
+ text_encoder: TextEncoder,
175
+ transformer: HYVideoDiffusionTransformer,
176
+ scheduler: KarrasDiffusionSchedulers,
177
+ text_encoder_2: Optional[TextEncoder] = None,
178
+ progress_bar_config: Dict[str, Any] = None,
179
+ args=None,
180
+ ):
181
+ super().__init__()
182
+
183
+ # ==========================================================================================
184
+ if progress_bar_config is None:
185
+ progress_bar_config = {}
186
+ if not hasattr(self, "_progress_bar_config"):
187
+ self._progress_bar_config = {}
188
+ self._progress_bar_config.update(progress_bar_config)
189
+
190
+ self.args = args
191
+ # ==========================================================================================
192
+
193
+ if (
194
+ hasattr(scheduler.config, "steps_offset")
195
+ and scheduler.config.steps_offset != 1
196
+ ):
197
+ deprecation_message = (
198
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
199
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
200
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
201
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
202
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
203
+ " file"
204
+ )
205
+ deprecate(
206
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
207
+ )
208
+ new_config = dict(scheduler.config)
209
+ new_config["steps_offset"] = 1
210
+ scheduler._internal_dict = FrozenDict(new_config)
211
+
212
+ if (
213
+ hasattr(scheduler.config, "clip_sample")
214
+ and scheduler.config.clip_sample is True
215
+ ):
216
+ deprecation_message = (
217
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
218
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
219
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
220
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
221
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
222
+ )
223
+ deprecate(
224
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
225
+ )
226
+ new_config = dict(scheduler.config)
227
+ new_config["clip_sample"] = False
228
+ scheduler._internal_dict = FrozenDict(new_config)
229
+
230
+ self.register_modules(
231
+ vae=vae,
232
+ text_encoder=text_encoder,
233
+ transformer=transformer,
234
+ scheduler=scheduler,
235
+ text_encoder_2=text_encoder_2,
236
+ )
237
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
238
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
239
+ self.noise_pertub = 0
240
+
241
+ def encode_prompt(
242
+ self,
243
+ prompt,
244
+ name,
245
+ device,
246
+ num_videos_per_prompt,
247
+ do_classifier_free_guidance,
248
+ negative_prompt=None,
249
+ pixel_value_llava: Optional[torch.Tensor] = None,
250
+ uncond_pixel_value_llava: Optional[torch.Tensor] = None,
251
+ prompt_embeds: Optional[torch.Tensor] = None,
252
+ attention_mask: Optional[torch.Tensor] = None,
253
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
254
+ negative_attention_mask: Optional[torch.Tensor] = None,
255
+ lora_scale: Optional[float] = None,
256
+ clip_skip: Optional[int] = None,
257
+ text_encoder: Optional[TextEncoder] = None,
258
+ data_type: Optional[str] = "image",
259
+ semantic_images=None
260
+ ):
261
+ r"""
262
+ Encodes the prompt into text encoder hidden states.
263
+
264
+ Args:
265
+ prompt (`str` or `List[str]`, *optional*):
266
+ prompt to be encoded
267
+ device: (`torch.device`):
268
+ torch device
269
+ num_videos_per_prompt (`int`):
270
+ number of videos that should be generated per prompt
271
+ do_classifier_free_guidance (`bool`):
272
+ whether to use classifier free guidance or not
273
+ negative_prompt (`str` or `List[str]`, *optional*):
274
+ The prompt or prompts not to guide the video generation. If not defined, one has to pass
275
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
276
+ less than `1`).
277
+ pixel_value_llava (`torch.Tensor`, *optional*):
278
+ The image tensor for llava.
279
+ uncond_pixel_value_llava (`torch.Tensor`, *optional*):
280
+ The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
281
+ less than `1`).
282
+ prompt_embeds (`torch.Tensor`, *optional*):
283
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
284
+ provided, text embeddings will be generated from `prompt` input argument.
285
+ attention_mask (`torch.Tensor`, *optional*):
286
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
287
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
288
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
289
+ argument.
290
+ negative_attention_mask (`torch.Tensor`, *optional*):
291
+ lora_scale (`float`, *optional*):
292
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
293
+ clip_skip (`int`, *optional*):
294
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
295
+ the output of the pre-final layer will be used for computing the prompt embeddings.
296
+ text_encoder (TextEncoder, *optional*):
297
+ data_type (`str`, *optional*):
298
+ """
299
+ if text_encoder is None:
300
+ text_encoder = self.text_encoder
301
+
302
+ # set lora scale so that monkey patched LoRA
303
+ # function of text encoder can correctly access it
304
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
305
+ self._lora_scale = lora_scale
306
+
307
+ # dynamically adjust the LoRA scale
308
+ if not USE_PEFT_BACKEND:
309
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
310
+ else:
311
+ scale_lora_layers(text_encoder.model, lora_scale)
312
+
313
+ if prompt is not None and isinstance(prompt, str):
314
+ batch_size = 1
315
+ elif prompt is not None and isinstance(prompt, list):
316
+ batch_size = len(prompt)
317
+ else:
318
+ batch_size = prompt_embeds.shape[0]
319
+
320
+ if prompt_embeds is None:
321
+ # textual inversion: process multi-vector tokens if necessary
322
+ if isinstance(self, TextualInversionLoaderMixin):
323
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
324
+
325
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name = name)
326
+
327
+ if pixel_value_llava is not None:
328
+ text_inputs['pixel_value_llava'] = pixel_value_llava
329
+ text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1)
330
+
331
+ if clip_skip is None:
332
+ prompt_outputs = text_encoder.encode(
333
+ text_inputs, data_type=data_type, semantic_images=semantic_images, device=device
334
+ )
335
+ prompt_embeds = prompt_outputs.hidden_state
336
+ else:
337
+ prompt_outputs = text_encoder.encode(
338
+ text_inputs,
339
+ output_hidden_states=True,
340
+ data_type=data_type,
341
+ semantic_images=semantic_images,
342
+ device=device,
343
+ )
344
+ # Access the `hidden_states` first, that contains a tuple of
345
+ # all the hidden states from the encoder layers. Then index into
346
+ # the tuple to access the hidden states from the desired layer.
347
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
348
+ # We also need to apply the final LayerNorm here to not mess with the
349
+ # representations. The `last_hidden_states` that we typically use for
350
+ # obtaining the final prompt representations passes through the LayerNorm
351
+ # layer.
352
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(
353
+ prompt_embeds
354
+ )
355
+
356
+ attention_mask = prompt_outputs.attention_mask
357
+ if attention_mask is not None:
358
+ attention_mask = attention_mask.to(device)
359
+ bs_embed, seq_len = attention_mask.shape
360
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
361
+ attention_mask = attention_mask.view(
362
+ bs_embed * num_videos_per_prompt, seq_len
363
+ )
364
+
365
+ if text_encoder is not None:
366
+ prompt_embeds_dtype = text_encoder.dtype
367
+ elif self.transformer is not None:
368
+ prompt_embeds_dtype = self.transformer.dtype
369
+ else:
370
+ prompt_embeds_dtype = prompt_embeds.dtype
371
+
372
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
373
+
374
+ if prompt_embeds.ndim == 2:
375
+ bs_embed, _ = prompt_embeds.shape
376
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
377
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
378
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
379
+ else:
380
+ bs_embed, seq_len, _ = prompt_embeds.shape
381
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
382
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
383
+ prompt_embeds = prompt_embeds.view(
384
+ bs_embed * num_videos_per_prompt, seq_len, -1
385
+ )
386
+
387
+ # get unconditional embeddings for classifier free guidance
388
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
389
+ uncond_tokens: List[str]
390
+ if negative_prompt is None:
391
+ uncond_tokens = [""] * batch_size
392
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
393
+ raise TypeError(
394
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
395
+ f" {type(prompt)}."
396
+ )
397
+ elif isinstance(negative_prompt, str):
398
+ uncond_tokens = [negative_prompt]
399
+ elif batch_size != len(negative_prompt):
400
+ raise ValueError(
401
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
402
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
403
+ " the batch size of `prompt`."
404
+ )
405
+ else:
406
+ uncond_tokens = negative_prompt
407
+
408
+ # textual inversion: process multi-vector tokens if necessary
409
+ if isinstance(self, TextualInversionLoaderMixin):
410
+ uncond_tokens = self.maybe_convert_prompt(
411
+ uncond_tokens, text_encoder.tokenizer
412
+ )
413
+
414
+ # max_length = prompt_embeds.shape[1]
415
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type, name = name)
416
+
417
+ if semantic_images is not None:
418
+ uncond_image = [black_image(img.size[0], img.size[1]) for img in semantic_images]
419
+ else:
420
+ uncond_image = None
421
+
422
+ if uncond_pixel_value_llava is not None:
423
+ uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
424
+ uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1)
425
+
426
+ negative_prompt_outputs = text_encoder.encode(
427
+ uncond_input, data_type=data_type, semantic_images=uncond_image, device=device
428
+ )
429
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
430
+
431
+ negative_attention_mask = negative_prompt_outputs.attention_mask
432
+ if negative_attention_mask is not None:
433
+ negative_attention_mask = negative_attention_mask.to(device)
434
+ _, seq_len = negative_attention_mask.shape
435
+ negative_attention_mask = negative_attention_mask.repeat(
436
+ 1, num_videos_per_prompt
437
+ )
438
+ negative_attention_mask = negative_attention_mask.view(
439
+ batch_size * num_videos_per_prompt, seq_len
440
+ )
441
+
442
+ if do_classifier_free_guidance:
443
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
444
+ seq_len = negative_prompt_embeds.shape[1]
445
+
446
+ negative_prompt_embeds = negative_prompt_embeds.to(
447
+ dtype=prompt_embeds_dtype, device=device
448
+ )
449
+
450
+ if negative_prompt_embeds.ndim == 2:
451
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
452
+ 1, num_videos_per_prompt
453
+ )
454
+ negative_prompt_embeds = negative_prompt_embeds.view(
455
+ batch_size * num_videos_per_prompt, -1
456
+ )
457
+ else:
458
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
459
+ 1, num_videos_per_prompt, 1
460
+ )
461
+ negative_prompt_embeds = negative_prompt_embeds.view(
462
+ batch_size * num_videos_per_prompt, seq_len, -1
463
+ )
464
+
465
+ if text_encoder is not None:
466
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
467
+ # Retrieve the original scale by scaling back the LoRA layers
468
+ unscale_lora_layers(text_encoder.model, lora_scale)
469
+
470
+ return (
471
+ prompt_embeds,
472
+ negative_prompt_embeds,
473
+ attention_mask,
474
+ negative_attention_mask,
475
+ )
476
+
477
+ def decode_latents(self, latents, enable_tiling=True):
478
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
479
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
480
+
481
+ latents = 1 / self.vae.config.scaling_factor * latents
482
+ if enable_tiling:
483
+ self.vae.enable_tiling()
484
+ image = self.vae.decode(latents, return_dict=False)[0]
485
+ else:
486
+ image = self.vae.decode(latents, return_dict=False)[0]
487
+ image = (image / 2 + 0.5).clamp(0, 1)
488
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
489
+ if image.ndim == 4:
490
+ image = image.cpu().permute(0, 2, 3, 1).float()
491
+ else:
492
+ image = image.cpu().float()
493
+ return image
494
+
495
+ def prepare_extra_func_kwargs(self, func, kwargs):
496
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
497
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
498
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
499
+ # and should be between [0, 1]
500
+ extra_step_kwargs = {}
501
+
502
+ for k, v in kwargs.items():
503
+ accepts = k in set(inspect.signature(func).parameters.keys())
504
+ if accepts:
505
+ extra_step_kwargs[k] = v
506
+ return extra_step_kwargs
507
+
508
+ def check_inputs(
509
+ self,
510
+ prompt,
511
+ height,
512
+ width,
513
+ video_length,
514
+ callback_steps,
515
+ pixel_value_llava=None,
516
+ uncond_pixel_value_llava=None,
517
+ negative_prompt=None,
518
+ prompt_embeds=None,
519
+ negative_prompt_embeds=None,
520
+ callback_on_step_end_tensor_inputs=None,
521
+ vae_ver="88-4c-sd",
522
+ ):
523
+ if height % 8 != 0 or width % 8 != 0:
524
+ raise ValueError(
525
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
526
+ )
527
+
528
+ if video_length is not None:
529
+ if "884" in vae_ver:
530
+ if video_length != 1 and (video_length - 1) % 4 != 0:
531
+ raise ValueError(
532
+ f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
533
+ )
534
+ elif "888" in vae_ver:
535
+ if video_length != 1 and (video_length - 1) % 8 != 0:
536
+ raise ValueError(
537
+ f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
538
+ )
539
+
540
+ if callback_steps is not None and (
541
+ not isinstance(callback_steps, int) or callback_steps <= 0
542
+ ):
543
+ raise ValueError(
544
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
545
+ f" {type(callback_steps)}."
546
+ )
547
+ if callback_on_step_end_tensor_inputs is not None and not all(
548
+ k in self._callback_tensor_inputs
549
+ for k in callback_on_step_end_tensor_inputs
550
+ ):
551
+ raise ValueError(
552
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
553
+ )
554
+
555
+ if prompt is not None and prompt_embeds is not None:
556
+ raise ValueError(
557
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
558
+ " only forward one of the two."
559
+ )
560
+ elif prompt is None and prompt_embeds is None:
561
+ raise ValueError(
562
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
563
+ )
564
+ elif prompt is not None and (
565
+ not isinstance(prompt, str) and not isinstance(prompt, list)
566
+ ):
567
+ raise ValueError(
568
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
569
+ )
570
+
571
+ if negative_prompt is not None and negative_prompt_embeds is not None:
572
+ raise ValueError(
573
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
574
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
575
+ )
576
+
577
+
578
+ if pixel_value_llava is not None and uncond_pixel_value_llava is not None:
579
+ if len(pixel_value_llava) != len(uncond_pixel_value_llava):
580
+ raise ValueError(
581
+ "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but"
582
+ f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`"
583
+ f" {len(uncond_pixel_value_llava)}."
584
+ )
585
+
586
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
587
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
588
+ raise ValueError(
589
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
590
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
591
+ f" {negative_prompt_embeds.shape}."
592
+ )
593
+
594
+ def get_timesteps(self, num_inference_steps, strength, device):
595
+ # get the original timestep using init_timestep
596
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
597
+
598
+ t_start = max(num_inference_steps - init_timestep, 0)
599
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
600
+ if hasattr(self.scheduler, "set_begin_index"):
601
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
602
+
603
+ return timesteps.to(device), num_inference_steps - t_start
604
+
605
+
606
+ def prepare_latents(
607
+ self,
608
+ batch_size,
609
+ num_channels_latents,
610
+ num_inference_steps,
611
+ height,
612
+ width,
613
+ video_length,
614
+ dtype,
615
+ device,
616
+ timesteps,
617
+ generator,
618
+ latents=None,
619
+ denoise_strength=1.0,
620
+ img_latents=None,
621
+ i2v_mode=False,
622
+ i2v_condition_type=None,
623
+ i2v_stability=True,
624
+ ):
625
+ if i2v_mode and i2v_condition_type == "latent_concat":
626
+ num_channels_latents = (num_channels_latents - 1) // 2
627
+ shape = (
628
+ batch_size,
629
+ num_channels_latents,
630
+ video_length,
631
+ int(height) // self.vae_scale_factor,
632
+ int(width) // self.vae_scale_factor,
633
+ )
634
+ if isinstance(generator, list) and len(generator) != batch_size:
635
+ raise ValueError(
636
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
637
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
638
+ )
639
+
640
+ if i2v_mode and i2v_stability:
641
+ if img_latents.shape[2] == 1:
642
+ img_latents = img_latents.repeat(1, 1, video_length, 1, 1)
643
+ x0 = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
644
+ x1 = img_latents
645
+
646
+ t = torch.tensor([0.999]).to(device=device)
647
+ latents = x0 * t + x1 * (1 - t)
648
+ latents = latents.to(dtype=dtype)
649
+
650
+ if denoise_strength == 0:
651
+ if latents is None:
652
+ latents = randn_tensor(
653
+ shape, generator=generator, device=device, dtype=dtype
654
+ )
655
+ else:
656
+ latents = latents.to(device)
657
+ original_latents = None
658
+ noise = None
659
+ timesteps = timesteps
660
+ else:
661
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
662
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, denoise_strength, device)
663
+
664
+ if latents is None:
665
+ latents = noise
666
+ original_latents = None
667
+ else:
668
+ latents = latents.to(device)
669
+ latent_timestep = timesteps[:1]
670
+ frames_needed = noise.shape[2]
671
+ current_frames = latents.shape[2]
672
+
673
+ if frames_needed > current_frames:
674
+ repeat_factor = frames_needed - current_frames
675
+ additional_frame = torch.randn((latents.size(0), latents.size(1),repeat_factor, latents.size(3), latents.size(4)), dtype=latents.dtype, device=latents.device)
676
+ latents = torch.cat((additional_frame, latents), dim=2)
677
+ self.additional_frames = repeat_factor
678
+ elif frames_needed < current_frames:
679
+ latents = latents[:, :, :frames_needed, :, :]
680
+
681
+ original_latents = latents.clone()
682
+ latents = latents * (1 - latent_timestep / 1000) + latent_timestep / 1000 * noise
683
+ print(f'debug:latent_timestep={latent_timestep}, latents-size={latents.shape}')
684
+
685
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
686
+ if hasattr(self.scheduler, "init_noise_sigma"):
687
+ # scale the initial noise by the standard deviation required by the scheduler
688
+ latents = latents * self.scheduler.init_noise_sigma
689
+ return latents, original_latents, noise, timesteps
690
+
691
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
692
+ def get_guidance_scale_embedding(
693
+ self,
694
+ w: torch.Tensor,
695
+ embedding_dim: int = 512,
696
+ dtype: torch.dtype = torch.float32,
697
+ ) -> torch.Tensor:
698
+ """
699
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
700
+
701
+ Args:
702
+ w (`torch.Tensor`):
703
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
704
+ embedding_dim (`int`, *optional*, defaults to 512):
705
+ Dimension of the embeddings to generate.
706
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
707
+ Data type of the generated embeddings.
708
+
709
+ Returns:
710
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
711
+ """
712
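+ # e.g. w = torch.tensor([7.5]) with embedding_dim=256 yields a (1, 256) tensor of concatenated sin/cos features of w * 1000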
+ assert len(w.shape) == 1
713
+ w = w * 1000.0
714
+
715
+ half_dim = embedding_dim // 2
716
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
717
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
718
+ emb = w.to(dtype)[:, None] * emb[None, :]
719
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
720
+ if embedding_dim % 2 == 1: # zero pad
721
+ emb = torch.nn.functional.pad(emb, (0, 1))
722
+ assert emb.shape == (w.shape[0], embedding_dim)
723
+ return emb
724
+
725
+ @property
726
+ def guidance_scale(self):
727
+ return self._guidance_scale
728
+
729
+ @property
730
+ def guidance_rescale(self):
731
+ return self._guidance_rescale
732
+
733
+ @property
734
+ def clip_skip(self):
735
+ return self._clip_skip
736
+
737
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
738
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
739
+ # corresponds to doing no classifier free guidance.
740
+ @property
741
+ def do_classifier_free_guidance(self):
742
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
743
+ return self._guidance_scale > 1
744
+
745
+ @property
746
+ def cross_attention_kwargs(self):
747
+ return self._cross_attention_kwargs
748
+
749
+ @property
750
+ def num_timesteps(self):
751
+ return self._num_timesteps
752
+
753
+ @property
754
+ def interrupt(self):
755
+ return self._interrupt
756
+
757
+ @torch.no_grad()
758
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
759
+ def __call__(
760
+ self,
761
+ prompt: Union[str, List[str]],
762
+ height: int,
763
+ width: int,
764
+ video_length: int,
765
+ name: Union[str, List[str]] = None,
766
+ data_type: str = "video",
767
+ num_inference_steps: int = 50,
768
+ timesteps: List[int] = None,
769
+ sigmas: List[float] = None,
770
+ guidance_scale: float = 7.5,
771
+ negative_prompt: Optional[Union[str, List[str]]] = None,
772
+ pixel_value_ref=None,
773
+ # ref_latents: Optional[torch.Tensor] = None,
774
+ # uncond_ref_latents: Optional[torch.Tensor] = None,
775
+ pixel_value_llava: Optional[torch.Tensor] = None,
776
+ uncond_pixel_value_llava: Optional[torch.Tensor] = None,
777
+ ip_cfg_scale: float = 0.0,
778
+ use_deepcache: int = 1,
779
+ num_videos_per_prompt: Optional[int] = 1,
780
+ eta: float = 0.0,
781
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
782
+ latents: Optional[torch.Tensor] = None,
783
+ prompt_embeds: Optional[torch.Tensor] = None,
784
+ attention_mask: Optional[torch.Tensor] = None,
785
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
786
+ negative_attention_mask: Optional[torch.Tensor] = None,
787
+ output_type: Optional[str] = "pil",
788
+ return_dict: bool = True,
789
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
790
+ guidance_rescale: float = 0.0,
791
+ clip_skip: Optional[int] = None,
792
+ callback_on_step_end: Optional[
793
+ Union[
794
+ Callable[[int, int, Dict], None],
795
+ PipelineCallback,
796
+ MultiPipelineCallbacks,
797
+ ]
798
+ ] = None,
799
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
800
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
801
+ vae_ver: str = "88-4c-sd",
802
+ enable_tiling: bool = False,
803
+ n_tokens: Optional[int] = None,
804
+ video_val_flag: bool=False,
805
+ denoise_strength: float = 1.0,
806
+ mask = None,
807
+ embedded_guidance_scale: Optional[float] = None,
808
+ i2v_mode: bool = False,
809
+ i2v_condition_type: str = None,
810
+ i2v_stability: bool = True,
811
+ img_latents: Optional[torch.Tensor] = None,
812
+ semantic_images=None,
813
+ joint_pass = False,
814
+ cfg_star_rescale = False,
815
+ callback = None,
816
+ **kwargs,
817
+ ):
818
+ r"""
819
+ The call function to the pipeline for generation.
820
+
821
+ Args:
822
+ prompt (`str` or `List[str]`):
823
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
824
+ height (`int`):
825
+ The height in pixels of the generated image.
826
+ width (`int`):
827
+ The width in pixels of the generated image.
828
+ video_length (`int`):
829
+ The number of frames in the generated video.
830
+ num_inference_steps (`int`, *optional*, defaults to 50):
831
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
832
+ expense of slower inference.
833
+ timesteps (`List[int]`, *optional*):
834
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
835
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
836
+ passed will be used. Must be in descending order.
837
+ sigmas (`List[float]`, *optional*):
838
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
839
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
840
+ will be used.
841
+ guidance_scale (`float`, *optional*, defaults to 7.5):
842
+ A higher guidance scale value encourages the model to generate images closely linked to the text
843
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
844
+ negative_prompt (`str` or `List[str]`, *optional*):
845
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
846
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
847
+ ref_latents (`torch.Tensor`, *optional*):
848
+ The image tensor for time-concat.
849
+ uncond_ref_latents (`torch.Tensor`, *optional*):
850
+ The image tensor for time-concat. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
851
+ less than `1`).
852
+ pixel_value_llava (`torch.Tensor`, *optional*):
853
+ The image tensor for llava.
854
+ uncond_pixel_value_llava (`torch.Tensor`, *optional*):
855
+ The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
856
+ less than `1`).
857
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
858
+ The number of images to generate per prompt.
859
+ eta (`float`, *optional*, defaults to 0.0):
860
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
861
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
862
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
863
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
864
+ generation deterministic.
865
+ latents (`torch.Tensor`, *optional*):
866
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
867
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
868
+ tensor is generated by sampling using the supplied random `generator`.
869
+ prompt_embeds (`torch.Tensor`, *optional*):
870
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
871
+ provided, text embeddings are generated from the `prompt` input argument.
872
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
873
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
874
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
875
+
876
+ output_type (`str`, *optional*, defaults to `"pil"`):
877
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
878
+ return_dict (`bool`, *optional*, defaults to `True`):
879
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
880
+ plain tuple.
881
+ cross_attention_kwargs (`dict`, *optional*):
882
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
883
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
884
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
885
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
886
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
887
+ using zero terminal SNR.
888
+ clip_skip (`int`, *optional*):
889
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
890
+ the output of the pre-final layer will be used for computing the prompt embeddings.
891
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
892
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
893
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
894
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
895
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
896
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
897
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
898
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
899
+ `._callback_tensor_inputs` attribute of your pipeline class.
900
+
901
+ Examples:
902
+
903
+ Returns:
904
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
905
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
906
+ otherwise the generated video frames are returned directly as a tensor.
909
+ """
910
+ callback_steps = kwargs.pop("callback_steps", None)
911
+
912
+ # if callback is not None:
913
+ # deprecate(
914
+ # "callback",
915
+ # "1.0.0",
916
+ # "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
917
+ # )
918
+ # if callback_steps is not None:
919
+ # deprecate(
920
+ # "callback_steps",
921
+ # "1.0.0",
922
+ # "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
923
+ # )
924
+
925
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
926
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
927
+
928
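+ # encode the reference image into VAE latents; an all-ones image supplies the matching unconditional reference latents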
+ if pixel_value_ref is not None:
929
+ pixel_value_ref = pixel_value_ref * 2 - 1.
930
+ pixel_value_ref_for_vae = rearrange(pixel_value_ref,"b c h w -> b c 1 h w")
931
+
932
+ ref_latents = self.vae.encode(pixel_value_ref_for_vae.clone()).latent_dist.sample()
933
+ uncond_ref_latents = self.vae.encode(torch.ones_like(pixel_value_ref_for_vae)).latent_dist.sample()
934
+ ref_latents.mul_(self.vae.config.scaling_factor)
935
+ uncond_ref_latents.mul_(self.vae.config.scaling_factor)
936
+ else:
937
+ ref_latents = None
938
+ uncond_ref_latents = None
939
+
940
+
941
+ # 0. Default height and width to unet
942
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
943
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
944
+ # to deal with lora scaling and other possible forward hooks
945
+ trans = self.transformer
946
+ if trans.enable_teacache:
947
+ teacache_multiplier = trans.teacache_multiplier
948
+ trans.accumulated_rel_l1_distance = 0
949
+ trans.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
950
+ # trans.teacache_start_step = int(tea_cache_start_step_perc*num_inference_steps/100)
951
+ # 1. Check inputs. Raise error if not correct
952
+ self.check_inputs(
953
+ prompt,
954
+ height,
955
+ width,
956
+ video_length,
957
+ callback_steps,
958
+ negative_prompt,
959
+ pixel_value_llava,
960
+ uncond_pixel_value_llava,
961
+ prompt_embeds,
962
+ negative_prompt_embeds,
963
+ callback_on_step_end_tensor_inputs,
964
+ vae_ver=vae_ver,
965
+ )
966
+
967
+ self._guidance_scale = guidance_scale
968
+ self._guidance_rescale = guidance_rescale
969
+ self._clip_skip = clip_skip
970
+ self._cross_attention_kwargs = cross_attention_kwargs
971
+ self._interrupt = False
972
+
973
+ # 2. Define call parameters
974
+ if prompt is not None and isinstance(prompt, str):
975
+ batch_size = 1
976
+ elif prompt is not None and isinstance(prompt, list):
977
+ batch_size = len(prompt)
978
+ else:
979
+ batch_size = prompt_embeds.shape[0]
980
+
981
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
982
+
983
+ # 3. Encode input prompt
984
+ lora_scale = (
985
+ self.cross_attention_kwargs.get("scale", None)
986
+ if self.cross_attention_kwargs is not None
987
+ else None
988
+ )
989
+
990
+ (
991
+ prompt_embeds,
992
+ negative_prompt_embeds,
993
+ prompt_mask,
994
+ negative_prompt_mask,
995
+ ) = self.encode_prompt(
996
+ prompt,
997
+ name,
998
+ device,
999
+ num_videos_per_prompt,
1000
+ self.do_classifier_free_guidance,
1001
+ negative_prompt,
1002
+ pixel_value_llava=pixel_value_llava,
1003
+ uncond_pixel_value_llava=uncond_pixel_value_llava,
1004
+ prompt_embeds=prompt_embeds,
1005
+ attention_mask=attention_mask,
1006
+ negative_prompt_embeds=negative_prompt_embeds,
1007
+ negative_attention_mask=negative_attention_mask,
1008
+ lora_scale=lora_scale,
1009
+ clip_skip=self.clip_skip,
1010
+ data_type=data_type,
1011
+ semantic_images=semantic_images
1012
+ )
1013
+ if self.text_encoder_2 is not None:
1014
+ (
1015
+ prompt_embeds_2,
1016
+ negative_prompt_embeds_2,
1017
+ prompt_mask_2,
1018
+ negative_prompt_mask_2,
1019
+ ) = self.encode_prompt(
1020
+ prompt,
1021
+ name,
1022
+ device,
1023
+ num_videos_per_prompt,
1024
+ self.do_classifier_free_guidance,
1025
+ negative_prompt,
1026
+ prompt_embeds=None,
1027
+ attention_mask=None,
1028
+ negative_prompt_embeds=None,
1029
+ negative_attention_mask=None,
1030
+ lora_scale=lora_scale,
1031
+ clip_skip=self.clip_skip,
1032
+ text_encoder=self.text_encoder_2,
1033
+ data_type=data_type,
1034
+ )
1035
+ else:
1036
+ prompt_embeds_2 = None
1037
+ negative_prompt_embeds_2 = None
1038
+ prompt_mask_2 = None
1039
+ negative_prompt_mask_2 = None
1040
+
1041
+ # For classifier free guidance, we need to do two forward passes.
1042
+ # Here we concatenate the unconditional and text embeddings into a single batch
1043
+ # to avoid doing two forward passes
1044
+ if self.do_classifier_free_guidance:
1045
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1046
+ if prompt_mask is not None:
1047
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
1048
+ if prompt_embeds_2 is not None:
1049
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1050
+ if prompt_mask_2 is not None:
1051
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
1052
+
1053
+ if self.do_classifier_free_guidance:
1054
+ if ref_latents is not None:
1055
+ ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
1056
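+ # 575 matches the per-image token count added for LLaVA inputs in encode_prompt; this presumably strips those image tokens from the unconditional prompt mask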
+ if prompt_mask[0].sum() > 575:
1057
+ prompt_mask[0] = torch.cat([torch.ones((1, prompt_mask[0].sum() - 575)).to(prompt_mask),
1058
+ torch.zeros((1, prompt_mask.shape[1] - prompt_mask[0].sum() + 575)).to(prompt_mask)], dim=1)
1059
+
1060
+ if ip_cfg_scale>0:
1061
+ prompt_embeds = torch.cat([prompt_embeds, prompt_embeds[1:]])
1062
+ prompt_embeds_2 = torch.cat([prompt_embeds_2, prompt_embeds_2[1:]])
1063
+ prompt_mask = torch.cat([prompt_mask, prompt_mask[1:]], dim=0)
1064
+ ref_latents = torch.cat([uncond_ref_latents, uncond_ref_latents, ref_latents[1:]], dim=0)
1065
+
1066
+
1067
+ # 4. Prepare timesteps
1068
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
1069
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
1070
+ )
1071
+ timesteps, num_inference_steps = retrieve_timesteps(
1072
+ self.scheduler,
1073
+ num_inference_steps,
1074
+ device,
1075
+ timesteps,
1076
+ sigmas,
1077
+ **extra_set_timesteps_kwargs,
1078
+ )
1079
+
1080
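+ # map pixel frames to latent frames: (n - 1) // 4 + 1 for "884" VAEs, (n - 1) // 8 + 1 for "888" VAEs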
+ if "884" in vae_ver:
1081
+ video_length = (video_length - 1) // 4 + 1
1082
+ elif "888" in vae_ver:
1083
+ video_length = (video_length - 1) // 8 + 1
1084
+ else:
1085
+ video_length = video_length
1086
+
1087
+ if self.transformer.mixed_precision:
1088
+ latent_dtype = torch.float32
1089
+ else:
1090
+ latent_dtype = torch.bfloat16
1091
+ if prompt_embeds is not None:
1092
+ prompt_embeds = prompt_embeds.to(torch.bfloat16)
1093
+ if prompt_embeds_2 is not None:
1094
+ prompt_embeds_2 = prompt_embeds_2.to(torch.bfloat16)
1095
+ # if prompt_mask != None:
1096
+ # prompt_mask = prompt_mask.to(torch.bfloat16)
1097
+ # 5. Prepare latent variables
1098
+ num_channels_latents = self.transformer.config.in_channels
1099
+ latents, original_latents, noise, timesteps = self.prepare_latents(
1100
+ batch_size * num_videos_per_prompt,
1101
+ num_channels_latents,
1102
+ num_inference_steps,
1103
+ height,
1104
+ width,
1105
+ video_length,
1106
+ latent_dtype, #prompt_embeds.dtype,
1107
+ device,
1108
+ timesteps,
1109
+ generator,
1110
+ latents,
1111
+ denoise_strength,
1112
+ img_latents=img_latents,
1113
+ i2v_mode=i2v_mode,
1114
+ i2v_condition_type=i2v_condition_type,
1115
+ i2v_stability=i2v_stability
1116
+ )
1117
+
1118
+ if i2v_mode and i2v_condition_type == "latent_concat":
1119
+ if img_latents.shape[2] == 1:
1120
+ img_latents_concat = img_latents.repeat(1, 1, video_length, 1, 1)
1121
+ else:
1122
+ img_latents_concat = img_latents
1123
+ img_latents_concat[:, :, 1:, ...] = 0
1124
+
1125
+ i2v_mask = torch.zeros(video_length)
1126
+ i2v_mask[0] = 1
1127
+
1128
+ mask_concat = torch.ones(img_latents_concat.shape[0], 1, img_latents_concat.shape[2], img_latents_concat.shape[3],
1129
+ img_latents_concat.shape[4]).to(device=img_latents.device)
1130
+ mask_concat[:, :, 1:, ...] = 0
1131
+
1132
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1133
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
1134
+ self.scheduler.step,
1135
+ {"generator": generator, "eta": eta},
1136
+ )
1137
+
1138
+ vae_precision = "fp16" # torch.float16
1139
+ precision = "bf16" # torch.bfloat16
1140
+
1141
+ disable_autocast = True
1142
+
1143
+ target_dtype = PRECISION_TO_TYPE[precision]
1144
+ autocast_enabled = target_dtype != torch.float32 and not disable_autocast
1145
+ vae_dtype = self.vae._model_dtype # PRECISION_TO_TYPE[vae_precision]
1146
+ vae_autocast_enabled = vae_dtype != torch.float32 and not disable_autocast
1147
+
1148
+ # 7. Denoising loop
1149
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1150
+ self._num_timesteps = len(timesteps)
1151
+ start_scale = ip_cfg_scale # 3.0
1152
+ end_scale = 1.0
1153
+ step_scale = (start_scale - end_scale) / (self._num_timesteps - 1 + 1e-3)
1154
+
1155
+ # print('sigmas used in generation:', self.scheduler.sigmas)
1156
+ # print('inference timesteps used in generation:', timesteps)
1157
+
1158
+
1159
+ # 8. Mask latents
1160
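+ # optional mask for latent blending: resized to the latent grid ((t - 1) // 4 + 1, h // 8, w // 8) and later used to mix denoised and original latents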
+ mask_latents = None
1161
+ if mask is not None:
1162
+ target_video_length = mask.shape[0]
1163
+ target_height = mask.shape[1]
1164
+ target_width = mask.shape[2]
1165
+
1166
+ mask_length = (target_video_length - 1) // 4 + 1
1167
+ mask_height = target_height // 8
1168
+ mask_width = target_width // 8
1169
+
1170
+ mask = mask[...,0:1]
1171
+ mask = mask.unsqueeze(0)
1172
+ mask = rearrange(mask, "b t h w c -> b c t h w")
1173
+
1174
+ mask_latents = torch.nn.functional.interpolate(mask, size=(mask_length, mask_height, mask_width))
1175
+ mask_latents = mask_latents.to(device)
1176
+
1177
+ if mask_latents is not None:
1178
+ mask_latents_model_input = (
1179
+ torch.cat([mask_latents] * 2)
1180
+ if self.do_classifier_free_guidance
1181
+ else mask_latents
1182
+ )
1183
+ print(f'maskinfo, mask={mask.shape}, mask_latents_model_input={mask_latents_model_input.shape} ')
1184
+
1185
+
1186
+ if callback is not None:
1187
+ callback(-1, None, True)
1188
+
1189
+ # debug switch: when True the denoising loop is skipped so a previously saved latent can be reused (see the commented torch.load/torch.save below)
+ load_latent = False
1191
+
1192
+ multi_passes_free_guidance = not joint_pass
1193
+ if load_latent:
1194
+ timesteps = []
1195
+
1196
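+ # latent batches run per step: 1 without CFG, 2 with CFG (uncond + cond), 3 when image-prompt guidance adds its own branch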
+ latent_items = 2 if self.do_classifier_free_guidance else 1
1197
+ if ip_cfg_scale>0:
1198
+ latent_items += 1
1199
+
1200
+ if self.transformer.enable_teacache:
1201
+ self.transformer.previous_residual = [None] * latent_items
1202
+
1203
+ # if is_progress_bar:
1204
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1205
+ for i, t in enumerate(timesteps):
1206
+ offload.set_step_no_for_lora(self.transformer, i)
1207
+ if self.interrupt:
1208
+ continue
1209
+ if i2v_mode and i2v_condition_type == "token_replace":
1210
+ latents = torch.concat([img_latents, latents[:, :, 1:, :, :]], dim=2)
1211
+
1212
+ # expand the latents if we are doing classifier free guidance
1213
+ if i2v_mode and i2v_condition_type == "latent_concat":
1214
+ latent_model_input = torch.concat([latents, img_latents_concat, mask_concat], dim=1)
1215
+ else:
1216
+ latent_model_input = latents
1217
+
1218
+ latent_model_input = torch.cat([latent_model_input] * latent_items) if latent_items > 1 else latent_model_input
1219
+
1220
+ latent_model_input = self.scheduler.scale_model_input(
1221
+ latent_model_input, t
1222
+ )
1223
+
1224
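+ # masked-region blending: outside the mask, the model input is replaced by the original latents re-noised to the current timestep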
+ if mask_latents is not None:
1225
+ original_latents_noise = original_latents * (1 - t / 1000.0) + t / 1000.0 * noise
1226
+ original_latent_noise_model_input = (
1227
+ torch.cat([original_latents_noise] * 2)
1228
+ if self.do_classifier_free_guidance
1229
+ else original_latents_noise
1230
+ )
1231
+ original_latent_noise_model_input = self.scheduler.scale_model_input(original_latent_noise_model_input, t)
1232
+ latent_model_input = mask_latents_model_input * latent_model_input + (1 - mask_latents_model_input) * original_latent_noise_model_input
1233
+
1234
+ t_expand = t.repeat(latent_model_input.shape[0])
1235
+ guidance_expand = (
1236
+ torch.tensor(
1237
+ [embedded_guidance_scale] * latent_model_input.shape[0],
1238
+ dtype=torch.float32,
1239
+ device=device,
1240
+ ).to(latent_dtype)
1241
+ * 1000.0
1242
+ if embedded_guidance_scale is not None
1243
+ else None
1244
+ )
1245
+
1246
+ # predict the noise residual
1247
+ with torch.autocast(
1248
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
1249
+ ):
1250
+
1251
+ if self.do_classifier_free_guidance and multi_passes_free_guidance:
1252
+ for j in range(len(latent_model_input)):
1253
+ ret = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
1254
+ latent_model_input[j].unsqueeze(0), # [2, 16, 33, 24, 42]
1255
+ t_expand[j].unsqueeze(0), # [2]
1256
+ text_states=prompt_embeds[j].unsqueeze(0), # [2, 256, 4096]
1257
+ text_mask=prompt_mask[j].unsqueeze(0), # [2, 256]
1258
+ text_states_2=prompt_embeds_2[j].unsqueeze(0), # [2, 768]
1259
+ ref_latents=ref_latents[j].unsqueeze(0),
1260
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
1261
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
1262
+ guidance=guidance_expand,
1263
+ pipeline=self,
1264
+ x_id=j,
1265
+ step_no=i,
1266
+ callback = callback,
1267
+ )
1268
+ if self._interrupt:
1269
+ return [None]
1270
+ if j==0:
1271
+ noise_pred_uncond= ret[0]
1272
+ elif j==1:
1273
+ noise_pred_text= ret[0]
1274
+ else:
1275
+ noise_pred_ip = ret[0]
1276
+ ret = None
1277
+ else:
1278
+ # if self.do_classifier_free_guidance:
1279
+ # noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds[:1], text_mask=prompt_mask[:1], text_states_2=prompt_embeds_2[:1], freqs_cos=freqs_cis[0],freqs_sin=freqs_cis[1], guidance=guidance_expand,return_dict=True)['x']
1280
+ # noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds[1:], text_mask=prompt_mask[1:], text_states_2=prompt_embeds_2[1:], freqs_cos=freqs_cis[0],freqs_sin=freqs_cis[1], guidance=guidance_expand,return_dict=True)['x']
1281
+ # noise_pred = torch.cat([noise_pred_uncond, noise_pred_text], dim=0)
1282
+ # else:
1283
+ ret = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
1284
+ latent_model_input, # [2, 16, 33, 24, 42]
1285
+ t_expand, # [2]
1286
+ text_states=prompt_embeds, # [2, 256, 4096]
1287
+ text_mask=prompt_mask, # [2, 256]
1288
+ text_states_2=prompt_embeds_2, # [2, 768]
1289
+ ref_latents=ref_latents,
1290
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
1291
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
1292
+ guidance=guidance_expand,
1293
+ pipeline=self,
1294
+ step_no=i,
1295
+ callback = callback,
1296
+ )
1297
+ if self._interrupt:
1298
+ return [None]
1299
+ if self.do_classifier_free_guidance:
1300
+ if ip_cfg_scale > 0:
1301
+ noise_pred_uncond, noise_pred_text, noise_pred_ip = ret
1302
+ else:
1303
+ noise_pred_uncond, noise_pred_text = noise_pred = ret
1304
+ else:
1305
+ noise_pred = ret[0]
1306
+
1307
+ # perform guidance
1308
+ if self.do_classifier_free_guidance:
1309
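+ # CFG-star rescale: scale the unconditional prediction by the least-squares projection coefficient of the conditional prediction onto it before combining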
+ if cfg_star_rescale:
1310
+ batch_size = 1
1311
+ positive_flat = noise_pred_text.view(batch_size, -1)
1312
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
1313
+ dot_product = torch.sum(
1314
+ positive_flat * negative_flat, dim=1, keepdim=True
1315
+ )
1316
+ squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
1317
+ positive_flat, negative_flat = None, None
1318
+ alpha = dot_product / squared_norm
1319
+ noise_pred_uncond *= alpha
1320
+
1321
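+ # with image-prompt guidance, a third term pushes towards the IP branch; its weight decays linearly from ip_cfg_scale down to 1 over the denoising steps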
+ if ip_cfg_scale > 0:
1322
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + start_scale * (noise_pred_ip-noise_pred_text)
1323
+ start_scale -= step_scale
1324
+ if i==0:
1325
+ print(f'i={i}, noise_pred shape={noise_pred.shape}')
1326
+ else:
1327
+ noise_pred = noise_pred_uncond + self.guidance_scale * ( noise_pred_text - noise_pred_uncond)
1328
+
1329
+
1330
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1331
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1332
+ noise_pred = rescale_noise_cfg( noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale, )
1333
+
1334
+ # compute the previous noisy sample x_t -> x_t-1
1335
+ if i2v_mode and i2v_condition_type == "token_replace":
1336
+ noise_pred = noise_pred.unsqueeze(0)
1337
+ latents = self.scheduler.step(
1338
+ noise_pred[:, :, 1:, :, :], t, latents[:, :, 1:, :, :], **extra_step_kwargs, return_dict=False
1339
+ )[0]
1340
+ latents = torch.concat(
1341
+ [img_latents, latents], dim=2
1342
+ )
1343
+ else:
1344
+ latents = self.scheduler.step(
1345
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1346
+ )[0]
1347
+
1348
+
1349
+ noise_pred_uncond, noise_pred_text, noise_pred, noise_pred_ip, ret = None, None, None, None, None
1350
+
1351
+ if callback is not None:
1352
+ callback(i, latents.squeeze(0), False)
1353
+
1354
+ if self.interrupt:
1355
+ return [None]
1356
+
1357
+ # if load_latent:
1358
+ # latents = torch.load("latent.pt")
1359
+ # else:
1360
+ # torch.save(latents, "latent.pt")
1361
+
1362
+
1363
+ if mask_latents is not None:
1364
+ latents = mask_latents * latents + (1 - mask_latents) * original_latents
1365
+
1366
+ if not output_type == "latent":
1367
+ expand_temporal_dim = False
1368
+ if len(latents.shape) == 4:
1369
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1370
+ latents = latents.unsqueeze(2)
1371
+ expand_temporal_dim = True
1372
+ elif len(latents.shape) == 5:
1373
+ pass
1374
+ else:
1375
+ raise ValueError(
1376
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1377
+ )
1378
+
1379
+ if (
1380
+ hasattr(self.vae.config, "shift_factor")
1381
+ and self.vae.config.shift_factor
1382
+ ):
1383
+ latents = (
1384
+ latents / self.vae.config.scaling_factor
1385
+ + self.vae.config.shift_factor
1386
+ )
1387
+ else:
1388
+ latents = latents / self.vae.config.scaling_factor
1389
+
1390
+ with torch.autocast(
1391
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1392
+ ):
1393
+ if enable_tiling:
1394
+ self.vae.enable_tiling()
1395
+ image = self.vae.decode(
1396
+ latents, return_dict=False, generator=generator
1397
+ )[0]
1398
+ else:
1399
+ image = self.vae.decode(
1400
+ latents, return_dict=False, generator=generator
1401
+ )[0]
1402
+
1403
+ if expand_temporal_dim or image.shape[2] == 1:
1404
+ image = image.squeeze(2)
1405
+
1406
+ else:
1407
+ image = latents
1408
+
1409
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1410
+ image = image.cpu().float()
1411
+
1412
+ if i2v_mode and i2v_condition_type == "latent_concat":
1413
+ image = image[:, :, 4:, :, :]
1414
+
1415
+ # Offload all models
1416
+ self.maybe_free_model_hooks()
1417
+
1418
+ if not return_dict:
1419
+ return image
1420
+
1421
+ return HunyuanVideoPipelineOutput(videos=image)
hyvideo/diffusion/pipelines/pipeline_hunyuan_video_audio.py ADDED
@@ -0,0 +1,1359 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import numpy as np
22
+ import torch
23
+ from packaging import version
24
+ from diffusers.utils import BaseOutput
25
+ from dataclasses import dataclass
26
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
27
+ from diffusers.configuration_utils import FrozenDict
28
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
29
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL, ImageProjection
31
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
32
+ from diffusers.schedulers import KarrasDiffusionSchedulers
33
+ from diffusers.utils import (
34
+ USE_PEFT_BACKEND,
35
+ deprecate,
36
+ logging,
37
+ replace_example_docstring,
38
+ scale_lora_layers,
39
+ unscale_lora_layers,
40
+ )
41
+ from diffusers.utils.torch_utils import randn_tensor
42
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
43
+
44
+ from hyvideo.constants import PRECISION_TO_TYPE
45
+ from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
46
+ from hyvideo.text_encoder import TextEncoder
47
+ from einops import rearrange
48
+ from ...modules import HYVideoDiffusionTransformer
49
+
50
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51
+
52
+ EXAMPLE_DOC_STRING = """"""
53
+
54
+
55
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
56
+ """
57
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
58
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
59
+ """
60
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
61
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
62
+ # rescale the results from guidance (fixes overexposure)
63
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
64
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
65
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
66
+ return noise_cfg
67
+
68
+
69
+ def retrieve_timesteps(
70
+ scheduler,
71
+ num_inference_steps: Optional[int] = None,
72
+ device: Optional[Union[str, torch.device]] = None,
73
+ timesteps: Optional[List[int]] = None,
74
+ sigmas: Optional[List[float]] = None,
75
+ **kwargs,
76
+ ):
77
+ """
78
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
79
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
80
+
81
+ Args:
82
+ scheduler (`SchedulerMixin`):
83
+ The scheduler to get timesteps from.
84
+ num_inference_steps (`int`):
85
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
86
+ must be `None`.
87
+ device (`str` or `torch.device`, *optional*):
88
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
89
+ timesteps (`List[int]`, *optional*):
90
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
91
+ `num_inference_steps` and `sigmas` must be `None`.
92
+ sigmas (`List[float]`, *optional*):
93
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
94
+ `num_inference_steps` and `timesteps` must be `None`.
95
+
96
+ Returns:
97
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
98
+ second element is the number of inference steps.
99
+ """
100
+ if timesteps is not None and sigmas is not None:
101
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
102
+ if timesteps is not None:
103
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
104
+ if not accepts_timesteps:
105
+ raise ValueError(
106
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
107
+ f" timestep schedules. Please check whether you are using the correct scheduler."
108
+ )
109
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
110
+ timesteps = scheduler.timesteps
111
+ num_inference_steps = len(timesteps)
112
+ elif sigmas is not None:
113
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
114
+ if not accept_sigmas:
115
+ raise ValueError(
116
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
117
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
118
+ )
119
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
120
+ timesteps = scheduler.timesteps
121
+ num_inference_steps = len(timesteps)
122
+ else:
123
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
124
+ timesteps = scheduler.timesteps
125
+ return timesteps, num_inference_steps
126
+
127
+ @dataclass
128
+ class HunyuanVideoPipelineOutput(BaseOutput):
129
+ videos: Union[torch.Tensor, np.ndarray]
130
+
131
+
132
+ class HunyuanVideoAudioPipeline(DiffusionPipeline):
133
+ r"""
134
+ Pipeline for text-to-video generation using HunyuanVideo.
135
+
136
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
137
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
138
+
139
+ Args:
140
+ vae ([`AutoencoderKL`]):
141
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
142
+ text_encoder ([`TextEncoder`]):
143
+ Frozen text-encoder.
144
+ text_encoder_2 ([`TextEncoder`]):
145
+ Frozen text-encoder_2.
146
+ transformer ([`HYVideoDiffusionTransformer`]):
147
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
148
+ scheduler ([`SchedulerMixin`]):
149
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
150
+ """
151
+
152
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
153
+ _optional_components = ["text_encoder_2"]
154
+ _exclude_from_cpu_offload = ["transformer"]
155
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
156
+
157
+ def __init__(
158
+ self,
159
+ vae: AutoencoderKL,
160
+ text_encoder: TextEncoder,
161
+ transformer: HYVideoDiffusionTransformer,
162
+ scheduler: KarrasDiffusionSchedulers,
163
+ text_encoder_2: Optional[TextEncoder] = None,
164
+ progress_bar_config: Dict[str, Any] = None,
165
+ args=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ # ==========================================================================================
170
+ if progress_bar_config is None:
171
+ progress_bar_config = {}
172
+ if not hasattr(self, '_progress_bar_config'):
173
+ self._progress_bar_config = {}
174
+ self._progress_bar_config.update(progress_bar_config)
175
+
176
+ self.args = args
177
+ # ==========================================================================================
178
+
179
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
180
+ deprecation_message = (
181
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
182
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
183
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
184
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
185
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
186
+ " file"
187
+ )
188
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
189
+ new_config = dict(scheduler.config)
190
+ new_config["steps_offset"] = 1
191
+ scheduler._internal_dict = FrozenDict(new_config)
192
+
193
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
194
+ deprecation_message = (
195
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
196
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
197
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
198
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
199
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
200
+ )
201
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
202
+ new_config = dict(scheduler.config)
203
+ new_config["clip_sample"] = False
204
+ scheduler._internal_dict = FrozenDict(new_config)
205
+
206
+ self.register_modules(
207
+ vae=vae,
208
+ text_encoder=text_encoder,
209
+ transformer=transformer,
210
+ scheduler=scheduler,
211
+ text_encoder_2=text_encoder_2
212
+ )
213
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
214
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
215
+
216
+ def encode_prompt(
217
+ self,
218
+ prompt,
219
+ name,
220
+ device,
221
+ num_videos_per_prompt,
222
+ do_classifier_free_guidance,
223
+ negative_prompt=None,
224
+ pixel_value_llava: Optional[torch.Tensor] = None,
225
+ uncond_pixel_value_llava: Optional[torch.Tensor] = None,
226
+ prompt_embeds: Optional[torch.Tensor] = None,
227
+ attention_mask: Optional[torch.Tensor] = None,
228
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
229
+ negative_attention_mask: Optional[torch.Tensor] = None,
230
+ lora_scale: Optional[float] = None,
231
+ clip_skip: Optional[int] = None,
232
+ text_encoder: Optional[TextEncoder] = None,
233
+ data_type: Optional[str] = "image",
234
+ ):
235
+ r"""
236
+ Encodes the prompt into text encoder hidden states.
237
+
238
+ Args:
239
+ prompt (`str` or `List[str]`, *optional*):
240
+ prompt to be encoded
241
+ device: (`torch.device`):
242
+ torch device
243
+ num_videos_per_prompt (`int`):
244
+ number of images that should be generated per prompt
245
+ do_classifier_free_guidance (`bool`):
246
+ whether to use classifier free guidance or not
247
+ negative_prompt (`str` or `List[str]`, *optional*):
248
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
249
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
250
+ less than `1`).
251
+ pixel_value_llava (`torch.Tensor`, *optional*):
252
+ The image tensor for llava.
253
+ uncond_pixel_value_llava (`torch.Tensor`, *optional*):
254
+ The image tensor for llava. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
255
+ less than `1`).
256
+ prompt_embeds (`torch.Tensor`, *optional*):
257
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
258
+ provided, text embeddings will be generated from `prompt` input argument.
259
+ attention_mask (`torch.Tensor`, *optional*):
260
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
261
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
262
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
263
+ argument.
264
+ negative_attention_mask (`torch.Tensor`, *optional*):
265
+ lora_scale (`float`, *optional*):
266
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
267
+ clip_skip (`int`, *optional*):
268
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
269
+ the output of the pre-final layer will be used for computing the prompt embeddings.
270
+ text_encoder (TextEncoder, *optional*):
271
+ """
272
+ if text_encoder is None:
273
+ text_encoder = self.text_encoder
274
+
275
+ # set lora scale so that monkey patched LoRA
276
+ # function of text encoder can correctly access it
277
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
278
+ self._lora_scale = lora_scale
279
+
280
+ # dynamically adjust the LoRA scale
281
+ if not USE_PEFT_BACKEND:
282
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
283
+ else:
284
+ scale_lora_layers(text_encoder.model, lora_scale)
285
+
286
+ if prompt is not None and isinstance(prompt, str):
287
+ batch_size = 1
288
+ elif prompt is not None and isinstance(prompt, list):
289
+ batch_size = len(prompt)
290
+ else:
291
+ batch_size = prompt_embeds.shape[0]
292
+
293
+ if prompt_embeds is None:
294
+ # textual inversion: process multi-vector tokens if necessary
295
+ if isinstance(self, TextualInversionLoaderMixin):
296
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
297
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name)
298
+
299
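+ # each LLaVA image contributes 575 extra tokens, so the attention mask is extended accordingly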
+ if pixel_value_llava is not None:
300
+ text_inputs['pixel_value_llava'] = pixel_value_llava
301
+ text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575 * len(pixel_value_llava))).to(text_inputs['attention_mask'])], dim=1)
302
+
303
+ if clip_skip is None:
304
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
305
+ prompt_embeds = prompt_outputs.hidden_state
306
+ else:
307
+ prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
308
+ # Access the `hidden_states` first, that contains a tuple of
309
+ # all the hidden states from the encoder layers. Then index into
310
+ # the tuple to access the hidden states from the desired layer.
311
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
312
+ # We also need to apply the final LayerNorm here to not mess with the
313
+ # representations. The `last_hidden_states` that we typically use for
314
+ # obtaining the final prompt representations passes through the LayerNorm
315
+ # layer.
316
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
317
+
318
+ attention_mask = prompt_outputs.attention_mask
319
+ if attention_mask is not None:
320
+ attention_mask = attention_mask.to(device)
321
+ bs_embed, seq_len = attention_mask.shape
322
+ attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
323
+ attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
324
+
325
+ if text_encoder is not None:
326
+ prompt_embeds_dtype = text_encoder.dtype
327
+ elif self.transformer is not None:
328
+ prompt_embeds_dtype = self.transformer.dtype
329
+ else:
330
+ prompt_embeds_dtype = prompt_embeds.dtype
331
+
332
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
333
+
334
+ if prompt_embeds.ndim == 2:
335
+ bs_embed, _ = prompt_embeds.shape
336
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
337
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
338
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
339
+ else:
340
+ bs_embed, seq_len, _ = prompt_embeds.shape
341
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
342
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
343
+ prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
344
+
345
+ # get unconditional embeddings for classifier free guidance
346
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
347
+ uncond_tokens: List[str]
348
+ if negative_prompt is None:
349
+ uncond_tokens = [""] * batch_size
350
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
351
+ raise TypeError(
352
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
353
+ f" {type(prompt)}."
354
+ )
355
+ elif isinstance(negative_prompt, str):
356
+ uncond_tokens = [negative_prompt]
357
+ elif batch_size != len(negative_prompt):
358
+ raise ValueError(
359
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
360
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
361
+ " the batch size of `prompt`."
362
+ )
363
+ else:
364
+ uncond_tokens = negative_prompt
365
+
366
+ # textual inversion: process multi-vector tokens if necessary
367
+ if isinstance(self, TextualInversionLoaderMixin):
368
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
369
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)
370
+ if uncond_pixel_value_llava is not None:
371
+ uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
372
+ uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575 * len(uncond_pixel_value_llava))).to(uncond_input['attention_mask'])], dim=1)
373
+
374
+ negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
375
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
376
+
377
+ negative_attention_mask = negative_prompt_outputs.attention_mask
378
+ if negative_attention_mask is not None:
379
+ negative_attention_mask = negative_attention_mask.to(device)
380
+ _, seq_len = negative_attention_mask.shape
381
+ negative_attention_mask = negative_attention_mask.repeat(1, num_videos_per_prompt)
382
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_videos_per_prompt, seq_len)
383
+
384
+ if do_classifier_free_guidance:
385
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
386
+ seq_len = negative_prompt_embeds.shape[1]
387
+
388
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
389
+
390
+ if negative_prompt_embeds.ndim == 2:
391
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt)
392
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
393
+ else:
394
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
395
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
396
+
397
+ if text_encoder is not None:
398
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
399
+ # Retrieve the original scale by scaling back the LoRA layers
400
+ unscale_lora_layers(text_encoder.model, lora_scale)
401
+
402
+ return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
403
+
404
+ def encode_prompt_audio_text_base(
405
+ self,
406
+ prompt,
407
+ uncond_prompt,
408
+ pixel_value_llava,
409
+ uncond_pixel_value_llava,
410
+ device,
411
+ num_images_per_prompt,
412
+ do_classifier_free_guidance,
413
+ negative_prompt=None,
414
+ prompt_embeds: Optional[torch.Tensor] = None,
415
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
416
+ lora_scale: Optional[float] = None,
417
+ clip_skip: Optional[int] = None,
418
+ text_encoder: Optional[TextEncoder] = None,
419
+ data_type: Optional[str] = "image",
420
+ name = "person"
421
+ ):
422
+ if text_encoder is None:
423
+ text_encoder = self.text_encoder
424
+
425
+ # set lora scale so that monkey patched LoRA
426
+ # function of text encoder can correctly access it
427
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
428
+ self._lora_scale = lora_scale
429
+
430
+ # dynamically adjust the LoRA scale
431
+ if not USE_PEFT_BACKEND:
432
+ adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
433
+ else:
434
+ scale_lora_layers(text_encoder.model, lora_scale)
435
+
436
+ if prompt is not None and isinstance(prompt, str):
437
+ batch_size = 1
438
+ elif prompt is not None and isinstance(prompt, list):
439
+ batch_size = len(prompt)
440
+ else:
441
+ batch_size = prompt_embeds.shape[0]
442
+
443
+ prompt_embeds = None
444
+
445
+ if prompt_embeds is None:
446
+ # textual inversion: process multi-vector tokens if necessary
447
+ if isinstance(self, TextualInversionLoaderMixin):
448
+ prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
449
+ text_inputs = text_encoder.text2tokens(prompt, data_type=data_type, name=name) # data_type: video, text_inputs: {'input_ids', 'attention_mask'}
450
+
451
+ text_keys = ['input_ids', 'attention_mask']
452
+
453
+ if pixel_value_llava is not None:
454
+ text_inputs['pixel_value_llava'] = pixel_value_llava
455
+ text_inputs['attention_mask'] = torch.cat([text_inputs['attention_mask'], torch.ones((1, 575)).to(text_inputs['attention_mask'])], dim=1)
456
+
457
+
458
+ if clip_skip is None:
459
+ prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)
460
+ prompt_embeds = prompt_outputs.hidden_state
461
+ else:
462
+ prompt_outputs = text_encoder.encode(text_inputs, output_hidden_states=True, data_type=data_type)
463
+ # Access the `hidden_states` first, that contains a tuple of
464
+ # all the hidden states from the encoder layers. Then index into
465
+ # the tuple to access the hidden states from the desired layer.
466
+ prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
467
+ # We also need to apply the final LayerNorm here to not mess with the
468
+ # representations. The `last_hidden_states` that we typically use for
469
+ # obtaining the final prompt representations passes through the LayerNorm
470
+ # layer.
471
+ prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
472
+
473
+ attention_mask = prompt_outputs.attention_mask
474
+ if attention_mask is not None:
475
+ attention_mask = attention_mask.to(device)
476
+ bs_embed, seq_len = attention_mask.shape
477
+ attention_mask = attention_mask.repeat(1, num_images_per_prompt)
478
+ attention_mask = attention_mask.view(bs_embed * num_images_per_prompt, seq_len)
479
+
480
+ if text_encoder is not None:
481
+ prompt_embeds_dtype = text_encoder.dtype
482
+ elif self.transformer is not None:
483
+ prompt_embeds_dtype = self.transformer.dtype
484
+ else:
485
+ prompt_embeds_dtype = prompt_embeds.dtype
486
+
487
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
488
+
489
+ if prompt_embeds.ndim == 2:
490
+ bs_embed, _ = prompt_embeds.shape
491
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
492
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
493
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, -1)
494
+ else:
495
+ bs_embed, seq_len, _ = prompt_embeds.shape
496
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
497
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
498
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
499
+
500
+ # get unconditional embeddings for classifier free guidance
501
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
502
+ uncond_tokens: List[str]
503
+ if negative_prompt is None:
504
+ uncond_tokens = [""] * batch_size
505
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
506
+ raise TypeError(
507
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
508
+ f" {type(prompt)}."
509
+ )
510
+ elif isinstance(negative_prompt, str):
511
+ uncond_tokens = [negative_prompt]
512
+ elif batch_size != len(negative_prompt):
513
+ raise ValueError(
514
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
515
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
516
+ " the batch size of `prompt`."
517
+ )
518
+ else:
519
+ uncond_tokens = negative_prompt
520
+
521
+ # textual inversion: process multi-vector tokens if necessary
522
+ if isinstance(self, TextualInversionLoaderMixin):
523
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, text_encoder.tokenizer)
524
+ # max_length = prompt_embeds.shape[1]
525
+ uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type, name=name)
526
+
527
+ # if hasattr(text_encoder.model.config, "use_attention_mask") and text_encoder.model.config.use_attention_mask:
528
+ # attention_mask = uncond_input.attention_mask.to(device)
529
+ # else:
530
+ # attention_mask = None
531
+ if uncond_pixel_value_llava is not None:
532
+ uncond_input['pixel_value_llava'] = uncond_pixel_value_llava
533
+ uncond_input['attention_mask'] = torch.cat([uncond_input['attention_mask'], torch.ones((1, 575)).to(uncond_input['attention_mask'])], dim=1)
534
+
535
+ negative_prompt_outputs = text_encoder.encode(uncond_input, data_type=data_type)
536
+ negative_prompt_embeds = negative_prompt_outputs.hidden_state
537
+
538
+ negative_attention_mask = negative_prompt_outputs.attention_mask
539
+ if negative_attention_mask is not None:
540
+ negative_attention_mask = negative_attention_mask.to(device)
541
+ _, seq_len = negative_attention_mask.shape
542
+ negative_attention_mask = negative_attention_mask.repeat(1, num_images_per_prompt)
543
+ negative_attention_mask = negative_attention_mask.view(batch_size * num_images_per_prompt, seq_len)
544
+
545
+ if do_classifier_free_guidance:
546
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
547
+ seq_len = negative_prompt_embeds.shape[1]
548
+
549
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
550
+
551
+ if negative_prompt_embeds.ndim == 2:
552
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt)
553
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
554
+ else:
555
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
556
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
557
+
558
+ if text_encoder is not None:
559
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
560
+ # Retrieve the original scale by scaling back the LoRA layers
561
+ unscale_lora_layers(text_encoder.model, lora_scale)
562
+
563
+ return prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask
564
+
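Illustrative sketch (not part of the imported source): when `pixel_value_llava` is present, the attention mask above is extended by 575 positions. The 575 is assumed to match the LLaVA image expansion used later in `_merge_input_ids_with_image_features` (hyvideo/hunyuan.py in this same commit), where each `<image>` token grows by `num_image_patches - 1` positions, i.e. 24*24 - 1 = 575 for a 336x336 input.

import torch
attention_mask = torch.ones(1, 32)                 # text tokens, including one <image> placeholder
extra = torch.ones(1, 575).to(attention_mask)      # match dtype/device of the original mask
attention_mask = torch.cat([attention_mask, extra], dim=1)
assert attention_mask.shape == (1, 32 + 575)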
565
+ def decode_latents(self, latents, enable_tiling=True):
566
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
567
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
568
+
569
+ latents = 1 / self.vae.config.scaling_factor * latents
570
+ if enable_tiling:
571
+ self.vae.enable_tiling()
572
+ image = self.vae.decode(latents, return_dict=False)[0]
573
+ self.vae.disable_tiling()
574
+ else:
575
+ image = self.vae.decode(latents, return_dict=False)[0]
576
+ image = (image / 2 + 0.5).clamp(0, 1)
577
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
578
+ if image.ndim==4: image = image.cpu().permute(0, 2, 3, 1).float()
579
+ else: image = image.cpu().float()
580
+ return image
581
+
582
+ def prepare_extra_func_kwargs(self, func, kwargs):
583
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
584
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
585
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
586
+ # and should be between [0, 1]
587
+ extra_step_kwargs = {}
588
+
589
+ for k, v in kwargs.items():
590
+ accepts = k in set(inspect.signature(func).parameters.keys())
591
+ if accepts:
592
+ extra_step_kwargs[k] = v
593
+ return extra_step_kwargs
594
+
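Illustrative sketch (not part of the imported source): `prepare_extra_func_kwargs` forwards only the kwargs that the target callable actually accepts, so scheduler-specific options such as `eta` are silently dropped for schedulers whose `step()` does not take them. The function name below is hypothetical.

import inspect

def step_without_eta(model_output, timestep, sample, generator=None):   # hypothetical scheduler step
    ...

kwargs = {"generator": None, "eta": 0.0}
accepted = set(inspect.signature(step_without_eta).parameters)
filtered = {k: v for k, v in kwargs.items() if k in accepted}
assert filtered == {"generator": None}    # `eta` was filtered out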
595
+ def check_inputs(
596
+ self,
597
+ prompt,
598
+ height,
599
+ width,
600
+ frame,
601
+ callback_steps,
602
+ pixel_value_llava=None,
603
+ uncond_pixel_value_llava=None,
604
+ negative_prompt=None,
605
+ prompt_embeds=None,
606
+ negative_prompt_embeds=None,
607
+ callback_on_step_end_tensor_inputs=None,
608
+ vae_ver='88-4c-sd'
609
+ ):
610
+ if height % 8 != 0 or width % 8 != 0:
611
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
612
+
613
+ if frame is not None:
614
+ if '884' in vae_ver:
615
+ if frame!=1 and (frame-1)%4!=0:
616
+ raise ValueError(f'`frame` has to be 1 or a multiple of 4 but is {frame}.')
617
+ elif '888' in vae_ver:
618
+ if frame!=1 and (frame-1)%8!=0:
619
+ raise ValueError(f'`frame` has to be 1 or a multiple of 8 but is {frame}.')
620
+
621
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
622
+ raise ValueError(
623
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
624
+ f" {type(callback_steps)}."
625
+ )
626
+ if callback_on_step_end_tensor_inputs is not None and not all(
627
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
628
+ ):
629
+ raise ValueError(
630
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
631
+ )
632
+
633
+ if prompt is not None and prompt_embeds is not None:
634
+ raise ValueError(
635
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
636
+ " only forward one of the two."
637
+ )
638
+ elif prompt is None and prompt_embeds is None:
639
+ raise ValueError(
640
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
641
+ )
642
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
643
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
644
+
645
+ if negative_prompt is not None and negative_prompt_embeds is not None:
646
+ raise ValueError(
647
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
648
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
649
+ )
650
+
651
+ if pixel_value_llava is not None and uncond_pixel_value_llava is not None:
652
+ if len(pixel_value_llava) != len(uncond_pixel_value_llava):
653
+ raise ValueError(
654
+ "`pixel_value_llava` and `uncond_pixel_value_llava` must have the same length when passed directly, but"
655
+ f" got: `pixel_value_llava` {len(pixel_value_llava)} != `uncond_pixel_value_llava`"
656
+ f" {len(uncond_pixel_value_llava)}."
657
+ )
658
+
659
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
660
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
661
+ raise ValueError(
662
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
663
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
664
+ f" {negative_prompt_embeds.shape}."
665
+ )
666
+
667
+ def get_timesteps(self, num_inference_steps, strength, device):
668
+ # get the original timestep using init_timestep
669
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
670
+
671
+ t_start = max(num_inference_steps - init_timestep, 0)
672
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
673
+ if hasattr(self.scheduler, "set_begin_index"):
674
+ self.scheduler.set_begin_index(t_start * self.scheduler.order)
675
+
676
+ return timesteps.to(device), num_inference_steps - t_start
677
+
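Illustrative sketch (not part of the imported source): `get_timesteps` keeps only the tail of the schedule when a denoising strength below 1 is used, e.g. 50 inference steps at strength 0.6 skip the first 20 steps and run the remaining 30.

num_inference_steps, strength = 50, 0.6
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)   # 30
t_start = max(num_inference_steps - init_timestep, 0)                           # 20
assert (t_start, num_inference_steps - t_start) == (20, 30)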
678
+ def prepare_latents(self, batch_size, num_channels_latents, height, width, frame, dtype, device, generator, latents=None, ref_latents=None, timestep=None):
679
+ shape = (
680
+ batch_size,
681
+ num_channels_latents,
682
+ frame,
683
+ int(height) // self.vae_scale_factor,
684
+ int(width) // self.vae_scale_factor,
685
+ )
686
+ if isinstance(generator, list) and len(generator) != batch_size:
687
+ raise ValueError(
688
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
689
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
690
+ )
691
+
692
+ if latents is None:
693
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
694
+ else:
695
+ latents = latents.to(device)
696
+
697
+
698
+ if timestep is not None:
699
+ init_latents = ref_latents.clone().repeat(1,1,frame,1,1).to(device).to(dtype)
700
+ latents = latents
701
+
702
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
703
+ if hasattr(self.scheduler, "init_noise_sigma"):
704
+ latents = latents * self.scheduler.init_noise_sigma
705
+
706
+ return latents
707
+
708
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
709
+ def get_guidance_scale_embedding(
710
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
711
+ ) -> torch.Tensor:
712
+ """
713
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
714
+
715
+ Args:
716
+ w (`torch.Tensor`):
717
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
718
+ embedding_dim (`int`, *optional*, defaults to 512):
719
+ Dimension of the embeddings to generate.
720
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
721
+ Data type of the generated embeddings.
722
+
723
+ Returns:
724
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
725
+ """
726
+ assert len(w.shape) == 1
727
+ w = w * 1000.0
728
+
729
+ half_dim = embedding_dim // 2
730
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
731
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
732
+ emb = w.to(dtype)[:, None] * emb[None, :]
733
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
734
+ if embedding_dim % 2 == 1: # zero pad
735
+ emb = torch.nn.functional.pad(emb, (0, 1))
736
+ assert emb.shape == (w.shape[0], embedding_dim)
737
+ return emb
738
+
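Illustrative sketch (not part of the imported source): the guidance-scale embedding above is a plain sinusoidal encoding of `w * 1000`, so one scale value with `embedding_dim=512` yields a (1, 512) tensor whose first 256 columns are sine terms and last 256 are cosine terms.

import torch
w = torch.tensor([7.5]) * 1000.0
half_dim = 512 // 2
freq = torch.exp(torch.arange(half_dim) * -(torch.log(torch.tensor(10000.0)) / (half_dim - 1)))
emb = torch.cat([torch.sin(w[:, None] * freq[None, :]), torch.cos(w[:, None] * freq[None, :])], dim=1)
assert emb.shape == (1, 512)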
739
+ @property
740
+ def guidance_scale(self):
741
+ return self._guidance_scale
742
+
743
+ @property
744
+ def guidance_rescale(self):
745
+ return self._guidance_rescale
746
+
747
+ @property
748
+ def clip_skip(self):
749
+ return self._clip_skip
750
+
751
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
752
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
753
+ # corresponds to doing no classifier free guidance.
754
+ @property
755
+ def do_classifier_free_guidance(self):
756
+ # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
757
+ return self._guidance_scale > 1
758
+
759
+ @property
760
+ def cross_attention_kwargs(self):
761
+ return self._cross_attention_kwargs
762
+
763
+ @property
764
+ def num_timesteps(self):
765
+ return self._num_timesteps
766
+
767
+ @property
768
+ def interrupt(self):
769
+ return self._interrupt
770
+
771
+ @torch.no_grad()
772
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
773
+ def __call__(
774
+ self,
775
+ prompt: Union[str, List[str]],
776
+
777
+ ref_latents: Union[torch.Tensor], # [1, 16, 1, h//8, w//8]
778
+ # uncond_ref_latents: Union[torch.Tensor],
779
+ pixel_value_llava: Union[torch.Tensor], # [1, 3, 336, 336]
780
+ uncond_pixel_value_llava: Union[torch.Tensor],
781
+ pixel_value_ref: Union[torch.Tensor],
782
+ face_masks: Union[torch.Tensor], # [b f h w]
783
+ audio_prompts: Union[torch.Tensor],
784
+ uncond_audio_prompts: Union[torch.Tensor],
785
+ motion_exp: Union[torch.Tensor],
786
+ motion_pose: Union[torch.Tensor],
787
+ fps: Union[torch.Tensor],
788
+
789
+ height: int,
790
+ width: int,
791
+ video_length: int,
792
+ data_type: str = "video",
793
+ num_inference_steps: int = 50,
794
+ timesteps: List[int] = None,
795
+ sigmas: List[float] = None,
796
+ guidance_scale: float = 7.5,
797
+ negative_prompt: Optional[Union[str, List[str]]] = None,
798
+ num_videos_per_prompt: Optional[int] = 1,
799
+ eta: float = 0.0,
800
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
801
+ latents: Optional[torch.Tensor] = None,
802
+ prompt_embeds: Optional[torch.Tensor] = None,
803
+ attention_mask: Optional[torch.Tensor] = None,
804
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
805
+ negative_attention_mask: Optional[torch.Tensor] = None,
806
+ output_type: Optional[str] = "pil",
807
+ return_dict: bool = True,
808
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
809
+ guidance_rescale: float = 0.0,
810
+ clip_skip: Optional[int] = None,
811
+ callback_on_step_end: Optional[
812
+ Union[
813
+ Callable[[int, int, Dict], None],
814
+ PipelineCallback,
815
+ MultiPipelineCallbacks,
816
+ ]
817
+ ] = None,
818
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
819
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
820
+ vae_ver: str = "88-4c-sd",
821
+ enable_tiling: bool = False,
822
+ n_tokens: Optional[int] = None,
823
+ embedded_guidance_scale: Optional[float] = None,
824
+ joint_pass = False,
825
+ cfg_star_rescale = False,
826
+ name = None,
827
+ **kwargs,
828
+ ):
829
+ r"""
830
+ The call function to the pipeline for generation.
831
+
832
+ Args:
833
+ prompt (`str` or `List[str]`):
834
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
835
+ height (`int`):
836
+ The height in pixels of the generated image.
837
+ width (`int`):
838
+ The width in pixels of the generated image.
839
+ video_length (`int`):
840
+ The number of frames in the generated video.
841
+ num_inference_steps (`int`, *optional*, defaults to 50):
842
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
843
+ expense of slower inference.
844
+ timesteps (`List[int]`, *optional*):
845
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
846
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
847
+ passed will be used. Must be in descending order.
848
+ sigmas (`List[float]`, *optional*):
849
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
850
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
851
+ will be used.
852
+ guidance_scale (`float`, *optional*, defaults to 7.5):
853
+ A higher guidance scale value encourages the model to generate images closely linked to the text
854
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
855
+ negative_prompt (`str` or `List[str]`, *optional*):
856
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
857
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
858
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
859
+ The number of images to generate per prompt.
860
+ eta (`float`, *optional*, defaults to 0.0):
861
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
862
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
863
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
864
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
865
+ generation deterministic.
866
+ latents (`torch.Tensor`, *optional*):
867
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
868
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
869
+ tensor is generated by sampling using the supplied random `generator`.
870
+ prompt_embeds (`torch.Tensor`, *optional*):
871
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
872
+ provided, text embeddings are generated from the `prompt` input argument.
873
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
874
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
875
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
876
+
877
+ output_type (`str`, *optional*, defaults to `"pil"`):
878
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
879
+ return_dict (`bool`, *optional*, defaults to `True`):
880
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
881
+ plain tuple.
882
+ cross_attention_kwargs (`dict`, *optional*):
883
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
884
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
885
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
886
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
887
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
888
+ using zero terminal SNR.
889
+ clip_skip (`int`, *optional*):
890
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
891
+ the output of the pre-final layer will be used for computing the prompt embeddings.
892
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
893
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
894
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
895
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
896
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
897
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
898
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
899
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
900
+ `._callback_tensor_inputs` attribute of your pipeline class.
901
+
902
+ Examples:
903
+
904
+ Returns:
905
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
906
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
907
+ otherwise the generated video frames are returned directly. This pipeline does not
908
+ run a safety checker, so the output does not include the per-image
909
+ "not-safe-for-work" (nsfw) flags mentioned in the original diffusers docstring.
910
+ """
911
+ callback = kwargs.pop("callback", None)
912
+ callback_steps = kwargs.pop("callback_steps", None)
913
+ if callback_steps is not None:
914
+ deprecate(
915
+ "callback_steps",
916
+ "1.0.0",
917
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
918
+ )
919
+
920
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
921
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
922
+
923
+
924
+ # num_inference_steps = 50
925
+
926
+ # 0. Default height and width to transformer
927
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
928
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
929
+ # to deal with lora scaling and other possible forward hooks
930
+
931
+ transformer = self.transformer
932
+
933
+ if transformer.enable_teacache:
934
+ teacache_multiplier = transformer.teacache_multiplier
935
+ transformer.accumulated_rel_l1_distance = 0
936
+ transformer.rel_l1_thresh = 0.1 if teacache_multiplier < 2 else 0.15
937
+
938
+ # 1. Check inputs. Raise error if not correct
939
+ self.check_inputs(
940
+ prompt,
941
+ height,
942
+ width,
943
+ video_length,
944
+ callback_steps,
945
+ pixel_value_llava,
946
+ uncond_pixel_value_llava,
947
+ negative_prompt,
948
+ prompt_embeds,
949
+ negative_prompt_embeds,
950
+ callback_on_step_end_tensor_inputs,
951
+ vae_ver=vae_ver
952
+ )
953
+
954
+ self._guidance_scale = guidance_scale
955
+ self.start_cfg_scale = guidance_scale
956
+ self._guidance_rescale = guidance_rescale
957
+ self._clip_skip = clip_skip
958
+ self._cross_attention_kwargs = cross_attention_kwargs
959
+ self._interrupt = False
960
+
961
+ # 2. Define call parameters
962
+ if prompt is not None and isinstance(prompt, str):
963
+ batch_size = 1
964
+ elif prompt is not None and isinstance(prompt, list):
965
+ batch_size = len(prompt)
966
+ else:
967
+ batch_size = prompt_embeds.shape[0]
968
+
969
+ device = self._execution_device
970
+
971
+ # 3. Encode input prompt
972
+ lora_scale = (
973
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
974
+ )
975
+
976
+
977
+ # ========== Encode text prompt (image prompt) ==========
978
+ prompt_embeds, negative_prompt_embeds, prompt_mask, negative_prompt_mask = \
979
+ self.encode_prompt_audio_text_base(
980
+ prompt=prompt,
981
+ uncond_prompt=negative_prompt,
982
+ pixel_value_llava=pixel_value_llava,
983
+ uncond_pixel_value_llava=uncond_pixel_value_llava,
984
+ device=device,
985
+ num_images_per_prompt=num_videos_per_prompt,
986
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
987
+ negative_prompt=negative_prompt,
988
+ prompt_embeds=prompt_embeds,
989
+ negative_prompt_embeds=negative_prompt_embeds,
990
+ lora_scale=lora_scale,
991
+ clip_skip=self.clip_skip,
992
+ text_encoder=self.text_encoder,
993
+ data_type=data_type,
994
+ name= name,
995
+ # **kwargs
996
+ )
997
+ if self.text_encoder_2 is not None:
998
+ prompt_embeds_2, negative_prompt_embeds_2, prompt_mask_2, negative_prompt_mask_2 = \
999
+ self.encode_prompt_audio_text_base(
1000
+ prompt=prompt,
1001
+ uncond_prompt=negative_prompt,
1002
+ pixel_value_llava=None,
1003
+ uncond_pixel_value_llava=None,
1004
+ device=device,
1005
+ num_images_per_prompt=num_videos_per_prompt,
1006
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1007
+ negative_prompt=negative_prompt,
1008
+ prompt_embeds=None,
1009
+ negative_prompt_embeds=None,
1010
+ lora_scale=lora_scale,
1011
+ clip_skip=self.clip_skip,
1012
+ text_encoder=self.text_encoder_2,
1013
+ # **kwargs
1014
+ )
1015
+ else:
1016
+ prompt_embeds_2 = None
1017
+ negative_prompt_embeds_2 = None
1018
+ prompt_mask_2 = None
1019
+ negative_prompt_mask_2 = None
1020
+
1021
+ if self.transformer.mixed_precision:
1022
+ latent_dtype = torch.float32
1023
+ else:
1024
+ latent_dtype = torch.bfloat16
1025
+ if prompt_embeds != None:
1026
+ prompt_embeds = prompt_embeds.to(torch.bfloat16)
1027
+ if negative_prompt_embeds != None:
1028
+ negative_prompt_embeds = negative_prompt_embeds.to(torch.bfloat16)
1029
+ if prompt_embeds_2 != None:
1030
+ prompt_embeds_2 = prompt_embeds_2.to(torch.bfloat16)
1031
+ if negative_prompt_embeds_2 != None:
1032
+ negative_prompt_embeds_2 = negative_prompt_embeds_2.to(torch.bfloat16)
1033
+ if audio_prompts != None:
1034
+ audio_prompts = audio_prompts.to(torch.bfloat16)
1035
+ if face_masks!= None:
1036
+ face_masks = face_masks.to(torch.bfloat16)
1037
+ if ref_latents != None:
1038
+ ref_latents = ref_latents.to(torch.bfloat16)
1039
+
1040
+ # For classifier free guidance, we need to do two forward passes.
1041
+ # Here we concatenate the unconditional and text embeddings into a single batch
1042
+ # to avoid doing two forward passes
1043
+ if self.do_classifier_free_guidance:
1044
+ prompt_embeds_input = torch.cat([negative_prompt_embeds, prompt_embeds])
1045
+ if prompt_mask is not None:
1046
+ prompt_mask_input = torch.cat([negative_prompt_mask, prompt_mask])
1047
+ if prompt_embeds_2 is not None:
1048
+ prompt_embeds_2_input = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
1049
+ if prompt_mask_2 is not None:
1050
+ prompt_mask_2_input = torch.cat([negative_prompt_mask_2, prompt_mask_2])
1051
+
1052
+ if self.do_classifier_free_guidance and ref_latents != None:
1053
+ ref_latents = torch.cat([ref_latents, ref_latents], dim=0)
1054
+
1055
+
1056
+ # 4. Prepare timesteps
1057
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
1058
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
1059
+ )
1060
+ timesteps, num_inference_steps = retrieve_timesteps(
1061
+ self.scheduler, num_inference_steps, device, timesteps, sigmas, **extra_set_timesteps_kwargs,
1062
+ )
1063
+
1064
+ video_length = audio_prompts.shape[1] // 4 * 4 + 1
1065
+ if "884" in vae_ver:
1066
+ video_length = (video_length - 1) // 4 + 1
1067
+ elif "888" in vae_ver:
1068
+ video_length = (video_length - 1) // 8 + 1
1069
+ else:
1070
+ video_length = video_length
1071
+
1072
+
1073
+ # 5. Prepare latent variables
1074
+ num_channels_latents = self.transformer.config.in_channels
1075
+ infer_length = (audio_prompts.shape[1] // 128 + 1) * 32 + 1
1076
+ latents = self.prepare_latents(
1077
+ batch_size * num_videos_per_prompt,
1078
+ num_channels_latents,
1079
+ height,
1080
+ width,
1081
+ infer_length,
1082
+ latent_dtype, #prompt_embeds.dtype,
1083
+ device,
1084
+ generator,
1085
+ latents,
1086
+ ref_latents[-1:] if ref_latents != None else None,
1087
+ timesteps[:1]
1088
+ )
1089
+
1090
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1091
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
1092
+ self.scheduler.step, {"generator": generator, "eta": eta},
1093
+ )
1094
+
1095
+ vae_precision = "fp16" # torch.float16
1096
+ precision = "bf16" # torch.bfloat16
1097
+ disable_autocast = True
1098
+
1099
+ target_dtype = PRECISION_TO_TYPE[precision]
1100
+ autocast_enabled = (target_dtype != torch.float32) and not disable_autocast
1101
+ vae_dtype = self.vae._model_dtype #PRECISION_TO_TYPE[vae_precision]
1102
+ vae_autocast_enabled = (vae_dtype != torch.float32) and not disable_autocast
1103
+
1104
+ # 7. Denoising loop
1105
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1106
+ self._num_timesteps = len(timesteps)
1107
+
1108
+ latents_all = latents.clone()
1109
+ pad_audio_length = (audio_prompts.shape[1] // 128 + 1) * 128 + 4 - audio_prompts.shape[1]
1110
+ audio_prompts_all = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :pad_audio_length])], dim=1)
1111
+
1112
+
1113
+ shift = 0
1114
+ shift_offset = 10
1115
+ frames_per_batch = 33
1116
+ self.cache_tensor = None
1117
+
1118
+ # If the total length is shorter than 129, shift is not required
1119
+ if video_length == 33 or infer_length == 33:
1120
+ infer_length = 33
1121
+ shift_offset = 0
1122
+ latents_all = latents_all[:, :, :33]
1123
+ audio_prompts_all = audio_prompts_all[:, :132]
1124
+ joint_pass = joint_pass or not self.do_classifier_free_guidance
1125
+
1126
+ if callback != None:
1127
+ callback(-1, None, True, override_num_inference_steps = num_inference_steps)
1128
+
1129
+ latent_items = 2 if self.do_classifier_free_guidance else 1
1130
+
1131
+ fps = torch.from_numpy(np.array(fps)).unsqueeze(0).to(dtype=torch.float16)
1132
+
1133
+ if self._interrupt:
1134
+ return [None]
1135
+
1136
+ if transformer.enable_teacache:
1137
+ cache_size = round( infer_length / frames_per_batch )
1138
+ transformer.previous_residual = [None] * latent_items
1139
+ cache_all_previous_residual = [None] * latent_items
1140
+ cache_all_previous_modulated_input = None
1141
+ cache_should_calc = [True] * cache_size
1142
+ cache_accumulated_rel_l1_distance = [0.] * cache_size
1143
+ cache_teacache_skipped_steps = [0] * cache_size
1144
+
1145
+
1146
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1147
+ for i, t in enumerate(timesteps):
1148
+ # init
1149
+ pred_latents = torch.zeros_like(
1150
+ latents_all,
1151
+ dtype=latents_all.dtype,
1152
+ )
1153
+ counter = torch.zeros(
1154
+ (latents_all.shape[0], latents_all.shape[1], infer_length, 1, 1),
1155
+ dtype=latents_all.dtype,
1156
+ ).to(device=latents_all.device)
1157
+
1158
+ cache_slot_no = 0
1159
+ for index_start in range(0, infer_length, frames_per_batch):
1160
+ self.scheduler._step_index = None
1161
+
1162
+ index_start = index_start - shift
1163
+ idx_list = [ii % latents_all.shape[2] for ii in range(index_start, index_start + frames_per_batch)]
1164
+ latents = latents_all[:, :, idx_list].clone()
1165
+
1166
+ idx_list_audio = [ii % audio_prompts_all.shape[1] for ii in range(index_start * 4, (index_start + frames_per_batch) * 4 - 3)]
1167
+ audio_prompts = audio_prompts_all[:, idx_list_audio].clone()
1168
+
1169
+ # expand the latents if we are doing classifier free guidance
1170
+ if self.do_classifier_free_guidance:
1171
+ latent_model_input = torch.cat([latents] * 2)
1172
+ else:
1173
+ latent_model_input = latents
1174
+
1175
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1176
+ embedded_hw = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * 3072
1177
+ img_ref_len = (latent_model_input.shape[-1] // 2) * (latent_model_input.shape[-2] // 2) * ( 1)
1178
+ img_all_len = (latents_all.shape[-1] // 2) * (latents_all.shape[-2] // 2) * latents_all.shape[-3]
1179
+
1180
+ if transformer.enable_teacache and cache_size > 1:
1181
+ for l in range(latent_items):
1182
+ if cache_all_previous_residual[l] != None:
1183
+ bsz = cache_all_previous_residual[l].shape[0]
1184
+ transformer.previous_residual[l][:, img_ref_len:] = cache_all_previous_residual[l].reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
1185
+ if cache_all_previous_modulated_input != None:
1186
+ transformer.previous_modulated_input[:, img_ref_len:] = cache_all_previous_modulated_input.reshape(1, -1, embedded_hw) [:, idx_list].reshape(1, -1, 3072)
1187
+ transformer.should_calc = cache_should_calc[cache_slot_no]
1188
+ transformer.accumulated_rel_l1_distance = cache_accumulated_rel_l1_distance[cache_slot_no]
1189
+ transformer.teacache_skipped_steps = cache_teacache_skipped_steps[cache_slot_no]
1190
+
1191
+
1192
+ if self.do_classifier_free_guidance:
1193
+ if i < num_inference_steps * 0.2 :
1194
+ self._guidance_scale = (1 - i / len(timesteps)) * (self.start_cfg_scale - 2) + 2
1195
+ audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
1196
+ face_masks_input = torch.cat([face_masks * 0.6] * 2, dim=0)
1197
+ else:
1198
+ # define 10-50 step cfg
1199
+ self._guidance_scale = (1 - i / len(timesteps)) * (6.5 - 3.5) + 3.5 # 5-2 +2
1200
+
1201
+ prompt_embeds_input = torch.cat([prompt_embeds, prompt_embeds])
1202
+ if prompt_mask is not None:
1203
+ prompt_mask_input = torch.cat([prompt_mask, prompt_mask])
1204
+ if prompt_embeds_2 is not None:
1205
+ prompt_embeds_2_input = torch.cat([prompt_embeds_2, prompt_embeds_2])
1206
+ if prompt_mask_2 is not None:
1207
+ prompt_mask_2_input = torch.cat([prompt_mask_2, prompt_mask_2])
1208
+ audio_prompts_input = torch.cat([uncond_audio_prompts, audio_prompts], dim=0)
1209
+ face_masks_input = torch.cat([face_masks] * 2, dim=0)
1210
+
1211
+ motion_exp_input = torch.cat([motion_exp] * 2, dim=0)
1212
+ motion_pose_input = torch.cat([motion_pose] * 2, dim=0)
1213
+ fps_input = torch.cat([fps] * 2, dim=0)
1214
+
1215
+ else:
1216
+ audio_prompts_input = audio_prompts
1217
+ face_masks_input = face_masks
1218
+ motion_exp_input = motion_exp
1219
+ motion_pose_input = motion_pose
1220
+ fps_input = fps
1221
+
1222
+ t_expand = t.repeat(latent_model_input.shape[0])
1223
+ guidance_expand = None
1224
+
1225
+ with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
1226
+ additional_kwargs = {
1227
+ "pipeline": self,
1228
+ "step_no": i,
1229
+ }
1230
+ if joint_pass:
1231
+ additional_kwargs.update({
1232
+ "motion_exp": motion_exp_input,
1233
+ "motion_pose": motion_pose_input,
1234
+ "fps": fps_input,
1235
+ "audio_prompts": audio_prompts_input,
1236
+ "face_mask": face_masks_input
1237
+ })
1238
+ noise_pred = self.transformer(latent_model_input, t_expand, ref_latents=ref_latents, text_states=prompt_embeds_input, text_mask=prompt_mask_input, text_states_2=prompt_embeds_2_input, freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, **additional_kwargs,)
1239
+ if self._interrupt:
1240
+ return [None]
1241
+ if self.do_classifier_free_guidance:
1242
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1243
+ else:
1244
+ additional_kwargs.update({
1245
+ "motion_exp": motion_exp_input[:1],
1246
+ "motion_pose": motion_pose_input[:1],
1247
+ "fps": fps_input[:1],
1248
+ "audio_prompts": audio_prompts_input[:1],
1249
+ "face_mask": face_masks_input[:1]
1250
+ })
1251
+ noise_pred_uncond = self.transformer(latent_model_input[:1], t_expand[:1], ref_latents=ref_latents[:1], text_states=prompt_embeds_input[:1], text_mask=prompt_mask_input[:1], text_states_2=prompt_embeds_2_input[:1], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 0, **additional_kwargs,)
1252
+ if self._interrupt:
1253
+ return [None]
1254
+ noise_pred_uncond = noise_pred_uncond[0]
1255
+ additional_kwargs.update({
1256
+ "motion_exp": motion_exp_input[1:],
1257
+ "motion_pose": motion_pose_input[1:],
1258
+ "fps": fps_input[1:],
1259
+ "audio_prompts": audio_prompts_input[1:],
1260
+ "face_mask": face_masks_input[1:]
1261
+ })
1262
+ noise_pred_text = self.transformer(latent_model_input[1:], t_expand[1:], ref_latents=ref_latents[1:], text_states=prompt_embeds_input[1:], text_mask=prompt_mask_input[1:], text_states_2=prompt_embeds_2_input[1:], freqs_cos=freqs_cis[0], freqs_sin=freqs_cis[1], guidance=guidance_expand, x_id = 1, **additional_kwargs,)
1263
+ if self._interrupt:
1264
+ return [None]
1265
+ noise_pred_text = noise_pred_text[0]
1266
+
1267
+ # perform guidance
1268
+ if self.do_classifier_free_guidance:
1269
+ if cfg_star_rescale:
1270
+ batch_size = 1
1271
+ positive_flat = noise_pred_text.view(batch_size, -1)
1272
+ negative_flat = noise_pred_uncond.view(batch_size, -1)
1273
+ dot_product = torch.sum(
1274
+ positive_flat * negative_flat, dim=1, keepdim=True
1275
+ )
1276
+ squared_norm = torch.sum(negative_flat**2, dim=1, keepdim=True) + 1e-8
1277
+ positive_flat, negative_flat = None, None
1278
+ alpha = dot_product / squared_norm
1279
+ noise_pred_uncond *= alpha
1280
+
1281
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1282
+
1283
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1284
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1285
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1286
+ noise_pred_text, noise_pred_uncond = None, None
1287
+ # compute the previous noisy sample x_t -> x_t-1
1288
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1289
+ noise_pred = None
1290
+
1291
+ latents = latents.to(torch.bfloat16)
1292
+ for iii in range(frames_per_batch):
1293
+ p = (index_start + iii) % pred_latents.shape[2]
1294
+ pred_latents[:, :, p] += latents[:, :, iii]
1295
+ counter[:, :, p] += 1
1296
+
1297
+ if transformer.enable_teacache and cache_size > 1:
1298
+ for l in range(latent_items):
1299
+ if transformer.previous_residual[l] != None:
1300
+ bsz = transformer.previous_residual[l].shape[0]
1301
+ if cache_all_previous_residual[l] == None:
1302
+ cache_all_previous_residual[l] = torch.zeros((bsz, img_all_len, 3072 ), device=transformer.previous_residual[l].device, dtype=transformer.previous_residual[l].dtype)
1303
+ cache_all_previous_residual[l].reshape(bsz, -1, embedded_hw)[:, idx_list] = transformer.previous_residual[l][:, img_ref_len:].reshape(bsz, -1, embedded_hw)
1304
+
1305
+ if transformer.previous_modulated_input != None:
1306
+ if cache_all_previous_modulated_input == None:
1307
+ cache_all_previous_modulated_input = torch.zeros((1, img_all_len, 3072 ), device=transformer.previous_modulated_input.device, dtype=transformer.previous_modulated_input.dtype)
1308
+ cache_all_previous_modulated_input.reshape(1, -1, embedded_hw)[:, idx_list] = transformer.previous_modulated_input[:, img_ref_len:].reshape(1, -1, embedded_hw)
1309
+ cache_should_calc[cache_slot_no] = transformer.should_calc
1310
+ cache_accumulated_rel_l1_distance[cache_slot_no] = transformer.accumulated_rel_l1_distance
1311
+ cache_teacache_skipped_steps[cache_slot_no] = transformer.teacache_skipped_steps
1312
+
1313
+ cache_slot_no += 1
1314
+
1315
+ shift += shift_offset
1316
+ shift = shift % frames_per_batch
1317
+ pred_latents = pred_latents / counter
1318
+ latents_all = pred_latents
1319
+
1320
+ if callback is not None:
1321
+ callback(i, latents_all.squeeze(0), False)
1322
+
1323
+ latents = latents_all.float()[:, :, :video_length]
1324
+
1325
+ if not output_type == "latent":
1326
+ expand_temporal_dim = False
1327
+ if len(latents.shape) == 4:
1328
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1329
+ latents = latents.unsqueeze(2)
1330
+ expand_temporal_dim = True
1331
+ elif len(latents.shape) == 5:
1332
+ pass
1333
+ else:
1334
+ raise ValueError(
1335
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
1336
+
1337
+ if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
1338
+ latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
1339
+ else:
1340
+ latents = latents / self.vae.config.scaling_factor
1341
+
1342
+ with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
1343
+ image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
1344
+ if image is None:
1345
+ return (None, )
1346
+
1347
+ if expand_temporal_dim or image.shape[2] == 1:
1348
+ image = image.squeeze(2)
1349
+
1350
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1351
+ image = image.cpu().float()
1352
+
1353
+ # Offload all models
1354
+ self.maybe_free_model_hooks()
1355
+
1356
+ if not return_dict:
1357
+ return image
1358
+
1359
+ return HunyuanVideoPipelineOutput(videos=image)
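Illustrative sketch (not part of the imported source): the denoising loop above processes the long latent sequence in wrap-around windows of `frames_per_batch` frames and moves the window start by `shift_offset` after every diffusion step, so window boundaries never stay fixed; overlapping predictions are then averaged through the `counter` tensor. The numbers below are made up for illustration.

infer_length, frames_per_batch, shift_offset = 97, 33, 10
shift = 0
for step in range(3):                                    # a few diffusion steps
    for index_start in range(0, infer_length, frames_per_batch):
        index_start -= shift
        idx_list = [i % infer_length for i in range(index_start, index_start + frames_per_batch)]
        # idx_list wraps around the end of the sequence, e.g. [..., 95, 96, 0, 1, ...]
    shift = (shift + shift_offset) % frames_per_batch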
hyvideo/diffusion/schedulers/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
hyvideo/diffusion/schedulers/scheduling_flow_match_discrete.py ADDED
@@ -0,0 +1,255 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from dataclasses import dataclass
21
+ from typing import Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import torch
25
+
26
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
27
+ from diffusers.utils import BaseOutput, logging
28
+ from diffusers.schedulers.scheduling_utils import SchedulerMixin
29
+
30
+
31
+
32
+ @dataclass
33
+ class FlowMatchDiscreteSchedulerOutput(BaseOutput):
34
+ """
35
+ Output class for the scheduler's `step` function output.
36
+
37
+ Args:
38
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
39
+ Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
40
+ denoising loop.
41
+ """
42
+
43
+ prev_sample: torch.FloatTensor
44
+
45
+
46
+ class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
47
+ """
48
+ Euler scheduler.
49
+
50
+ This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
51
+ methods the library implements for all schedulers such as loading and saving.
52
+
53
+ Args:
54
+ num_train_timesteps (`int`, defaults to 1000):
55
+ The number of diffusion steps to train the model.
56
+ timestep_spacing (`str`, defaults to `"linspace"`):
57
+ The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
58
+ Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
59
+ shift (`float`, defaults to 1.0):
60
+ The shift value for the timestep schedule.
61
+ reverse (`bool`, defaults to `True`):
62
+ Whether to reverse the timestep schedule.
63
+ """
64
+
65
+ _compatibles = []
66
+ order = 1
67
+
68
+ @register_to_config
69
+ def __init__(
70
+ self,
71
+ num_train_timesteps: int = 1000,
72
+ shift: float = 1.0,
73
+ reverse: bool = True,
74
+ solver: str = "euler",
75
+ n_tokens: Optional[int] = None,
76
+ ):
77
+ sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
78
+
79
+ if not reverse:
80
+ sigmas = sigmas.flip(0)
81
+
82
+ self.sigmas = sigmas
83
+ # the value fed to model
84
+ self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
85
+
86
+ self._step_index = None
87
+ self._begin_index = None
88
+
89
+ self.supported_solver = ["euler"]
90
+ if solver not in self.supported_solver:
91
+ raise ValueError(
92
+ f"Solver {solver} not supported. Supported solvers: {self.supported_solver}"
93
+ )
94
+
95
+ @property
96
+ def step_index(self):
97
+ """
98
+ The index counter for current timestep. It will increase 1 after each scheduler step.
99
+ """
100
+ return self._step_index
101
+
102
+ @property
103
+ def begin_index(self):
104
+ """
105
+ The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
106
+ """
107
+ return self._begin_index
108
+
109
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
110
+ def set_begin_index(self, begin_index: int = 0):
111
+ """
112
+ Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
113
+
114
+ Args:
115
+ begin_index (`int`):
116
+ The begin index for the scheduler.
117
+ """
118
+ self._begin_index = begin_index
119
+
120
+ def _sigma_to_t(self, sigma):
121
+ return sigma * self.config.num_train_timesteps
122
+
123
+ def set_timesteps(
124
+ self,
125
+ num_inference_steps: int,
126
+ device: Union[str, torch.device] = None,
127
+ n_tokens: int = None,
128
+ ):
129
+ """
130
+ Sets the discrete timesteps used for the diffusion chain (to be run before inference).
131
+
132
+ Args:
133
+ num_inference_steps (`int`):
134
+ The number of diffusion steps used when generating samples with a pre-trained model.
135
+ device (`str` or `torch.device`, *optional*):
136
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
137
+ n_tokens (`int`, *optional*):
138
+ Number of tokens in the input sequence.
139
+ """
140
+ self.num_inference_steps = num_inference_steps
141
+
142
+ sigmas = torch.linspace(1, 0, num_inference_steps + 1)
143
+ sigmas = self.sd3_time_shift(sigmas)
144
+
145
+ if not self.config.reverse:
146
+ sigmas = 1 - sigmas
147
+
148
+ self.sigmas = sigmas
149
+ self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(
150
+ dtype=torch.float32, device=device
151
+ )
152
+
153
+ # Reset step index
154
+ self._step_index = None
155
+
156
+ def index_for_timestep(self, timestep, schedule_timesteps=None):
157
+ if schedule_timesteps is None:
158
+ schedule_timesteps = self.timesteps
159
+
160
+ indices = (schedule_timesteps == timestep).nonzero()
161
+
162
+ # The sigma index that is taken for the **very** first `step`
163
+ # is always the second index (or the last index if there is only 1)
164
+ # This way we can ensure we don't accidentally skip a sigma in
165
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
166
+ pos = 1 if len(indices) > 1 else 0
167
+
168
+ return indices[pos].item()
169
+
170
+ def _init_step_index(self, timestep):
171
+ if self.begin_index is None:
172
+ if isinstance(timestep, torch.Tensor):
173
+ timestep = timestep.to(self.timesteps.device)
174
+ self._step_index = self.index_for_timestep(timestep)
175
+ else:
176
+ self._step_index = self._begin_index
177
+
178
+ def scale_model_input(
179
+ self, sample: torch.Tensor, timestep: Optional[int] = None
180
+ ) -> torch.Tensor:
181
+ return sample
182
+
183
+ def sd3_time_shift(self, t: torch.Tensor):
184
+ return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
185
+
186
+ def step(
187
+ self,
188
+ model_output: torch.FloatTensor,
189
+ timestep: Union[float, torch.FloatTensor],
190
+ sample: torch.FloatTensor,
191
+ return_dict: bool = True,
192
+ ) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
193
+ """
194
+ Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
195
+ process from the learned model outputs (most often the predicted noise).
196
+
197
+ Args:
198
+ model_output (`torch.FloatTensor`):
199
+ The direct output from learned diffusion model.
200
+ timestep (`float`):
201
+ The current discrete timestep in the diffusion chain.
202
+ sample (`torch.FloatTensor`):
203
+ A current instance of a sample created by the diffusion process.
204
+ generator (`torch.Generator`, *optional*):
205
+ A random number generator.
206
+ n_tokens (`int`, *optional*):
207
+ Number of tokens in the input sequence.
208
+ return_dict (`bool`):
209
+ Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
210
+ tuple.
211
+
212
+ Returns:
213
+ [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
214
+ If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
215
+ returned, otherwise a tuple is returned where the first element is the sample tensor.
216
+ """
217
+
218
+ if (
219
+ isinstance(timestep, int)
220
+ or isinstance(timestep, torch.IntTensor)
221
+ or isinstance(timestep, torch.LongTensor)
222
+ ):
223
+ raise ValueError(
224
+ (
225
+ "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
226
+ " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
227
+ " one of the `scheduler.timesteps` as a timestep."
228
+ ),
229
+ )
230
+
231
+ if self.step_index is None:
232
+ self._init_step_index(timestep)
233
+
234
+ # Upcast to avoid precision issues when computing prev_sample
235
+ sample = sample.to(torch.float32)
236
+
237
+ dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
238
+
239
+ if self.config.solver == "euler":
240
+ prev_sample = sample + model_output.to(torch.float32) * dt
241
+ else:
242
+ raise ValueError(
243
+ f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}"
244
+ )
245
+
246
+ # upon completion increase step index by one
247
+ self._step_index += 1
248
+
249
+ if not return_dict:
250
+ return (prev_sample,)
251
+
252
+ return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
253
+
254
+ def __len__(self):
255
+ return self.config.num_train_timesteps
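Illustrative sketch (not part of the imported source): a minimal Euler flow-matching loop with the scheduler above. `sd3_time_shift` warps each sigma as (shift * t) / (1 + (shift - 1) * t), so `shift=1.0` leaves the schedule linear while larger shifts spend more of the steps at high noise. The latent shape and `shift=7.0` below are arbitrary.

import torch
from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler

scheduler = FlowMatchDiscreteScheduler(shift=7.0)
scheduler.set_timesteps(num_inference_steps=10, device="cpu")
sample = torch.randn(1, 16, 9, 8, 8)                     # (b, c, f, h, w) latent-shaped noise
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)              # stand-in for the transformer prediction
    sample = scheduler.step(model_output, t, sample).prev_sample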
hyvideo/hunyuan.py ADDED
@@ -0,0 +1,1000 @@
1
+ import os
2
+ import time
3
+ import random
4
+ import functools
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ from pathlib import Path
8
+ from einops import rearrange
9
+ import torch
10
+ import torch.distributed as dist
11
+ from hyvideo.constants import PROMPT_TEMPLATE, NEGATIVE_PROMPT, PRECISION_TO_TYPE, NEGATIVE_PROMPT_I2V
12
+ from hyvideo.vae import load_vae
13
+ from hyvideo.modules import load_model
14
+ from hyvideo.text_encoder import TextEncoder
15
+ from hyvideo.utils.data_utils import align_to, get_closest_ratio, generate_crop_size_list
16
+ from hyvideo.modules.posemb_layers import get_nd_rotary_pos_embed, get_nd_rotary_pos_embed_new
17
+ from hyvideo.diffusion.schedulers import FlowMatchDiscreteScheduler
18
+ from hyvideo.diffusion.pipelines import HunyuanVideoPipeline
19
+ from hyvideo.diffusion.pipelines import HunyuanVideoAudioPipeline
20
+ from PIL import Image
21
+ import numpy as np
22
+ import torchvision.transforms as transforms
23
+ import cv2
24
+ from wan.utils.utils import resize_lanczos, calculate_new_dimensions
25
+ from hyvideo.data_kits.audio_preprocessor import encode_audio, get_facemask
26
+ from transformers import WhisperModel
27
+ from transformers import AutoFeatureExtractor
28
+ from hyvideo.data_kits.face_align import AlignImage
29
+ import librosa
30
+
31
+ def get_audio_feature(feature_extractor, audio_path, duration):
32
+ audio_input, sampling_rate = librosa.load(audio_path, duration=duration, sr=16000)
33
+ assert sampling_rate == 16000
34
+
35
+ audio_features = []
36
+ window = 750*640
37
+ for i in range(0, len(audio_input), window):
38
+ audio_feature = feature_extractor(audio_input[i:i+window],
39
+ sampling_rate=sampling_rate,
40
+ return_tensors="pt",
41
+ device="cuda"
42
+ ).input_features
43
+ audio_features.append(audio_feature)
44
+
45
+ audio_features = torch.cat(audio_features, dim=-1)
46
+ return audio_features, len(audio_input) // 640
47
+
48
+ def pad_image(crop_img, size, color=(255, 255, 255), resize_ratio=1):
49
+ crop_h, crop_w = crop_img.shape[:2]
50
+ target_w, target_h = size
51
+ scale_h, scale_w = target_h / crop_h, target_w / crop_w
52
+ if scale_w > scale_h:
53
+ resize_h = int(target_h*resize_ratio)
54
+ resize_w = int(crop_w / crop_h * resize_h)
55
+ else:
56
+ resize_w = int(target_w*resize_ratio)
57
+ resize_h = int(crop_h / crop_w * resize_w)
58
+ crop_img = cv2.resize(crop_img, (resize_w, resize_h))
59
+ pad_left = (target_w - resize_w) // 2
60
+ pad_top = (target_h - resize_h) // 2
61
+ pad_right = target_w - resize_w - pad_left
62
+ pad_bottom = target_h - resize_h - pad_top
63
+ crop_img = cv2.copyMakeBorder(crop_img, pad_top, pad_bottom, pad_left, pad_right, cv2.BORDER_CONSTANT, value=color)
64
+ return crop_img
65
+
66
+
67
+
68
+
69
+ def _merge_input_ids_with_image_features(self, image_features, inputs_embeds, input_ids, attention_mask, labels):
70
+ num_images, num_image_patches, embed_dim = image_features.shape
71
+ batch_size, sequence_length = input_ids.shape
72
+ left_padding = not torch.sum(input_ids[:, -1] == torch.tensor(self.pad_token_id))
73
+ # 1. Create a mask to know where special image tokens are
74
+ special_image_token_mask = input_ids == self.config.image_token_index
75
+ num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1)
76
+ # Compute the maximum embed dimension
77
+ max_embed_dim = (num_special_image_tokens.max() * (num_image_patches - 1)) + sequence_length
78
+ batch_indices, non_image_indices = torch.where(input_ids != self.config.image_token_index)
79
+
80
+ # 2. Compute the positions where text should be written
81
+ # Calculate new positions for text tokens in merged image-text sequence.
82
+ # `special_image_token_mask` identifies image tokens. Each image token will be replaced by `nb_text_tokens_per_images - 1` text tokens.
83
+ # `torch.cumsum` computes how each image token shifts subsequent text token positions.
84
+ # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
85
+ new_token_positions = torch.cumsum((special_image_token_mask * (num_image_patches - 1) + 1), -1) - 1
86
+ nb_image_pad = max_embed_dim - 1 - new_token_positions[:, -1]
87
+ if left_padding:
88
+ new_token_positions += nb_image_pad[:, None] # offset for left padding
89
+ text_to_overwrite = new_token_positions[batch_indices, non_image_indices]
90
+
91
+ # 3. Create the full embedding, already padded to the maximum position
92
+ final_embedding = torch.zeros(
93
+ batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device
94
+ )
95
+ final_attention_mask = torch.zeros(
96
+ batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device
97
+ )
98
+ if labels is not None:
99
+ final_labels = torch.full(
100
+ (batch_size, max_embed_dim), self.config.ignore_index, dtype=input_ids.dtype, device=input_ids.device
101
+ )
102
+ # In case the Vision model or the Language model has been offloaded to CPU, we need to manually
103
+ # set the corresponding tensors into their correct target device.
104
+ target_device = inputs_embeds.device
105
+ batch_indices, non_image_indices, text_to_overwrite = (
106
+ batch_indices.to(target_device),
107
+ non_image_indices.to(target_device),
108
+ text_to_overwrite.to(target_device),
109
+ )
110
+ attention_mask = attention_mask.to(target_device)
111
+
112
+ # 4. Fill the embeddings based on the mask. If we have ["hey" "<image>", "how", "are"]
113
+ # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features
114
+ final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices]
115
+ final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices]
116
+ if labels is not None:
117
+ final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices]
118
+
119
+ # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835)
120
+ image_to_overwrite = torch.full(
121
+ (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device
122
+ )
123
+ image_to_overwrite[batch_indices, text_to_overwrite] = False
124
+ image_to_overwrite &= image_to_overwrite.cumsum(-1) - 1 >= nb_image_pad[:, None].to(target_device)
125
+
126
+ if image_to_overwrite.sum() != image_features.shape[:-1].numel():
127
+ raise ValueError(
128
+ f"The inputs provided to the model are wrong. The number of image tokens is {torch.sum(special_image_token_mask)} while"
129
+ f" the number of images given to the model is {num_images}. This prevents correct indexing and breaks batch generation."
130
+ )
131
+
132
+ final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device)
133
+ final_attention_mask |= image_to_overwrite
134
+ position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1)
135
+
136
+ # 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
137
+ batch_indices, pad_indices = torch.where(input_ids == self.pad_token_id)
138
+ indices_to_mask = new_token_positions[batch_indices, pad_indices]
139
+
140
+ final_embedding[batch_indices, indices_to_mask] = 0
141
+
142
+ if labels is None:
143
+ final_labels = None
144
+
145
+ return final_embedding, final_attention_mask, final_labels, position_ids
146
+
147
+ def patched_llava_forward(
148
+ self,
149
+ input_ids: torch.LongTensor = None,
150
+ pixel_values: torch.FloatTensor = None,
151
+ attention_mask: Optional[torch.Tensor] = None,
152
+ position_ids: Optional[torch.LongTensor] = None,
153
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
154
+ inputs_embeds: Optional[torch.FloatTensor] = None,
155
+ vision_feature_layer: Optional[int] = None,
156
+ vision_feature_select_strategy: Optional[str] = None,
157
+ labels: Optional[torch.LongTensor] = None,
158
+ use_cache: Optional[bool] = None,
159
+ output_attentions: Optional[bool] = None,
160
+ output_hidden_states: Optional[bool] = None,
161
+ return_dict: Optional[bool] = None,
162
+ cache_position: Optional[torch.LongTensor] = None,
163
+ num_logits_to_keep: int = 0,
164
+ ):
165
+ from transformers.models.llava.modeling_llava import LlavaCausalLMOutputWithPast
166
+
167
+
168
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
169
+ output_hidden_states = (
170
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
171
+ )
172
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
173
+ vision_feature_layer = (
174
+ vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
175
+ )
176
+ vision_feature_select_strategy = (
177
+ vision_feature_select_strategy
178
+ if vision_feature_select_strategy is not None
179
+ else self.config.vision_feature_select_strategy
180
+ )
181
+
182
+ if (input_ids is None) ^ (inputs_embeds is not None):
183
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
184
+
185
+ if pixel_values is not None and inputs_embeds is not None:
186
+ raise ValueError(
187
+ "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
188
+ )
189
+
190
+ if inputs_embeds is None:
191
+ inputs_embeds = self.get_input_embeddings()(input_ids)
192
+
193
+ image_features = None
194
+ if pixel_values is not None:
195
+ image_features = self.get_image_features(
196
+ pixel_values=pixel_values,
197
+ vision_feature_layer=vision_feature_layer,
198
+ vision_feature_select_strategy=vision_feature_select_strategy,
199
+ )
200
+
201
+
202
+ inputs_embeds, attention_mask, labels, position_ids = self._merge_input_ids_with_image_features(
203
+ image_features, inputs_embeds, input_ids, attention_mask, labels
204
+ )
205
+ cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
206
+
207
+
208
+ outputs = self.language_model(
209
+ attention_mask=attention_mask,
210
+ position_ids=position_ids,
211
+ past_key_values=past_key_values,
212
+ inputs_embeds=inputs_embeds,
213
+ use_cache=use_cache,
214
+ output_attentions=output_attentions,
215
+ output_hidden_states=output_hidden_states,
216
+ return_dict=return_dict,
217
+ cache_position=cache_position,
218
+ num_logits_to_keep=num_logits_to_keep,
219
+ )
220
+
221
+ logits = outputs[0]
222
+
223
+ loss = None
224
+
225
+ if not return_dict:
226
+ output = (logits,) + outputs[1:]
227
+ return (loss,) + output if loss is not None else output
228
+
229
+ return LlavaCausalLMOutputWithPast(
230
+ loss=loss,
231
+ logits=logits,
232
+ past_key_values=outputs.past_key_values,
233
+ hidden_states=outputs.hidden_states,
234
+ attentions=outputs.attentions,
235
+ image_hidden_states=image_features if pixel_values is not None else None,
236
+ )
237
+
238
+ def adapt_avatar_model(model):
239
+ modules_dict= { k: m for k, m in model.named_modules()}
240
+ for model_layer, avatar_layer in model.double_stream_map.items():
241
+ module = modules_dict[f"audio_adapter_blocks.{avatar_layer}"]
242
+ target = modules_dict[f"double_blocks.{model_layer}"]
243
+ setattr(target, "audio_adapter", module )
244
+ delattr(model, "audio_adapter_blocks")
245
+
246
+ class DataPreprocess(object):
247
+ def __init__(self):
248
+ self.llava_size = (336, 336)
249
+ self.llava_transform = transforms.Compose(
250
+ [
251
+ transforms.Resize(self.llava_size, interpolation=transforms.InterpolationMode.BILINEAR),
252
+ transforms.ToTensor(),
253
+ transforms.Normalize((0.48145466, 0.4578275, 0.4082107), (0.26862954, 0.26130258, 0.27577711)),
254
+ ]
255
+ )
256
+
257
+ def get_batch(self, image , size, pad = False):
258
+ image = np.asarray(image)
259
+ if pad:
260
+ llava_item_image = pad_image(image.copy(), self.llava_size)
261
+ else:
262
+ llava_item_image = image.copy()
263
+ uncond_llava_item_image = np.ones_like(llava_item_image) * 255
264
+
265
+ if pad:
266
+ cat_item_image = pad_image(image.copy(), size)
267
+ else:
268
+ cat_item_image = image.copy()
269
+ llava_item_tensor = self.llava_transform(Image.fromarray(llava_item_image.astype(np.uint8)))
270
+ uncond_llava_item_tensor = self.llava_transform(Image.fromarray(uncond_llava_item_image))
271
+ cat_item_tensor = torch.from_numpy(cat_item_image.copy()).permute((2, 0, 1)) / 255.0
272
+ # batch = {
273
+ # "pixel_value_llava": llava_item_tensor.unsqueeze(0),
274
+ # "uncond_pixel_value_llava": uncond_llava_item_tensor.unsqueeze(0),
275
+ # 'pixel_value_ref': cat_item_tensor.unsqueeze(0),
276
+ # }
277
+ return llava_item_tensor.unsqueeze(0), uncond_llava_item_tensor.unsqueeze(0), cat_item_tensor.unsqueeze(0)
278
+
279
+ class Inference(object):
280
+ def __init__(
281
+ self,
282
+ i2v,
283
+ custom,
284
+ avatar,
285
+ enable_cfg,
286
+ vae,
287
+ vae_kwargs,
288
+ text_encoder,
289
+ model,
290
+ text_encoder_2=None,
291
+ pipeline=None,
292
+ feature_extractor=None,
293
+ wav2vec=None,
294
+ align_instance=None,
295
+ device=None,
296
+ ):
297
+ self.i2v = i2v
298
+ self.custom = custom
299
+ self.avatar = avatar
300
+ self.enable_cfg = enable_cfg
301
+ self.vae = vae
302
+ self.vae_kwargs = vae_kwargs
303
+
304
+ self.text_encoder = text_encoder
305
+ self.text_encoder_2 = text_encoder_2
306
+
307
+ self.model = model
308
+ self.pipeline = pipeline
309
+
310
+ self.feature_extractor=feature_extractor
311
+ self.wav2vec=wav2vec
312
+ self.align_instance=align_instance
313
+
314
+ self.device = "cuda"
315
+
316
+
317
+ @classmethod
318
+ def from_pretrained(cls, model_filepath, text_encoder_filepath, dtype = torch.bfloat16, VAE_dtype = torch.float16, mixed_precision_transformer =torch.bfloat16 , **kwargs):
319
+
320
+ device = "cuda"
321
+
322
+ import transformers
323
+ transformers.models.llava.modeling_llava.LlavaForConditionalGeneration.forward = patched_llava_forward # force legacy behaviour so that transformers versions newer than 4.47 can still be used
324
+ transformers.models.llava.modeling_llava.LlavaForConditionalGeneration._merge_input_ids_with_image_features = _merge_input_ids_with_image_features
325
+
326
+ torch.set_grad_enabled(False)
327
+ text_len = 512
328
+ latent_channels = 16
329
+ precision = "bf16"
330
+ vae_precision = "fp32" if VAE_dtype == torch.float32 else "bf16"
331
+ embedded_cfg_scale = 6
332
+ i2v_condition_type = None
333
+ i2v_mode = "i2v" in model_filepath[0]
334
+ custom = False
335
+ avatar = False
336
+ if i2v_mode:
337
+ model_id = "HYVideo-T/2"
338
+ i2v_condition_type = "token_replace"
339
+ elif "custom" in model_filepath[0]:
340
+ model_id = "HYVideo-T/2-custom"
341
+ custom = True
342
+ elif "avatar" in model_filepath[0]:
343
+ model_id = "HYVideo-T/2-avatar"
344
+ text_len = 256
345
+ avatar = True
346
+ else:
347
+ model_id = "HYVideo-T/2-cfgdistill"
348
+
349
+
350
+ if i2v_mode and i2v_condition_type == "latent_concat":
351
+ in_channels = latent_channels * 2 + 1
352
+ image_embed_interleave = 2
353
+ elif i2v_mode and i2v_condition_type == "token_replace":
354
+ in_channels = latent_channels
355
+ image_embed_interleave = 4
356
+ else:
357
+ in_channels = latent_channels
358
+ image_embed_interleave = 1
359
+ out_channels = latent_channels
360
+ pinToMemory = kwargs.pop("pinToMemory", False)
361
+ partialPinning = kwargs.pop("partialPinning", False)
362
+ factor_kwargs = kwargs | {"device": "meta", "dtype": PRECISION_TO_TYPE[precision]}
363
+
364
+ if embedded_cfg_scale and i2v_mode:
365
+ factor_kwargs["guidance_embed"] = True
366
+
367
+ model = load_model(
368
+ model = model_id,
369
+ i2v_condition_type = i2v_condition_type,
370
+ in_channels=in_channels,
371
+ out_channels=out_channels,
372
+ factor_kwargs=factor_kwargs,
373
+ )
374
+
375
+
376
+ from mmgp import offload
377
+ # model = Inference.load_state_dict(args, model, model_filepath)
378
+
379
+ # model_filepath ="c:/temp/avatar/mp_rank_00_model_states.pt"
380
+ offload.load_model_data(model, model_filepath, pinToMemory = pinToMemory, partialPinning = partialPinning)
381
+ pass
382
+ # offload.save_model(model, "hunyuan_video_avatar_720_bf16.safetensors")
383
+ # offload.save_model(model, "hunyuan_video_avatar_720_quanto_bf16_int8.safetensors", do_quantize= True)
384
+
385
+ model.mixed_precision = mixed_precision_transformer
386
+
387
+ if model.mixed_precision :
388
+ model._lock_dtype = torch.float32
389
+ model.lock_layers_dtypes(torch.float32)
390
+ model.eval()
391
+
392
+ # ============================= Build extra models ========================
393
+ # VAE
394
+ if custom or avatar:
395
+ vae_configpath = "ckpts/hunyuan_video_custom_VAE_config.json"
396
+ vae_filepath = "ckpts/hunyuan_video_custom_VAE_fp32.safetensors"
397
+ # elif avatar:
398
+ # vae_configpath = "ckpts/config_vae_avatar.json"
399
+ # vae_filepath = "ckpts/vae_avatar.pt"
400
+ else:
401
+ vae_configpath = "ckpts/hunyuan_video_VAE_config.json"
402
+ vae_filepath = "ckpts/hunyuan_video_VAE_fp32.safetensors"
403
+
404
+ # config = AutoencoderKLCausal3D.load_config("ckpts/hunyuan_video_VAE_config.json")
405
+ # config = AutoencoderKLCausal3D.load_config("c:/temp/hvae/config_vae.json")
406
+
407
+ vae, _, s_ratio, t_ratio = load_vae( "884-16c-hy", vae_path= vae_filepath, vae_config_path= vae_configpath, vae_precision= vae_precision, device= "cpu", )
408
+
409
+ vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else (torch.float16 if avatar else torch.bfloat16)
410
+ vae._model_dtype = torch.float32 if VAE_dtype == torch.float32 else torch.bfloat16
411
+ vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
412
+ enable_cfg = False
413
+ # Text encoder
414
+ if i2v_mode:
415
+ text_encoder = "llm-i2v"
416
+ tokenizer = "llm-i2v"
417
+ prompt_template = "dit-llm-encode-i2v"
418
+ prompt_template_video = "dit-llm-encode-video-i2v"
419
+ elif custom or avatar :
420
+ text_encoder = "llm-i2v"
421
+ tokenizer = "llm-i2v"
422
+ prompt_template = "dit-llm-encode"
423
+ prompt_template_video = "dit-llm-encode-video"
424
+ enable_cfg = True
425
+ else:
426
+ text_encoder = "llm"
427
+ tokenizer = "llm"
428
+ prompt_template = "dit-llm-encode"
429
+ prompt_template_video = "dit-llm-encode-video"
430
+
431
+ if prompt_template_video is not None:
432
+ crop_start = PROMPT_TEMPLATE[prompt_template_video].get( "crop_start", 0 )
433
+ elif prompt_template is not None:
434
+ crop_start = PROMPT_TEMPLATE[prompt_template].get("crop_start", 0)
435
+ else:
436
+ crop_start = 0
437
+ max_length = text_len + crop_start
438
+
439
+ # prompt_template
440
+ prompt_template = PROMPT_TEMPLATE[prompt_template] if prompt_template is not None else None
441
+
442
+ # prompt_template_video
443
+ prompt_template_video = PROMPT_TEMPLATE[prompt_template_video] if prompt_template_video is not None else None
444
+
445
+
446
+ text_encoder = TextEncoder(
447
+ text_encoder_type=text_encoder,
448
+ max_length=max_length,
449
+ text_encoder_precision="fp16",
450
+ tokenizer_type=tokenizer,
451
+ i2v_mode=i2v_mode,
452
+ prompt_template=prompt_template,
453
+ prompt_template_video=prompt_template_video,
454
+ hidden_state_skip_layer=2,
455
+ apply_final_norm=False,
456
+ reproduce=True,
457
+ device="cpu",
458
+ image_embed_interleave=image_embed_interleave,
459
+ text_encoder_path = text_encoder_filepath
460
+ )
461
+
462
+ text_encoder_2 = TextEncoder(
463
+ text_encoder_type="clipL",
464
+ max_length=77,
465
+ text_encoder_precision="fp16",
466
+ tokenizer_type="clipL",
467
+ reproduce=True,
468
+ device="cpu",
469
+ )
470
+
471
+ feature_extractor = None
472
+ wav2vec = None
473
+ align_instance = None
474
+
475
+ if avatar:
476
+ feature_extractor = AutoFeatureExtractor.from_pretrained("ckpts/whisper-tiny/")
477
+ wav2vec = WhisperModel.from_pretrained("ckpts/whisper-tiny/").to(device="cpu", dtype=torch.float32)
478
+ wav2vec._model_dtype = torch.float32
479
+ wav2vec.requires_grad_(False)
480
+ align_instance = AlignImage("cuda", det_path="ckpts/det_align/detface.pt")
481
+ align_instance.facedet.model.to("cpu")
482
+
483
+ adapt_avatar_model(model)
484
+
485
+ return cls(
486
+ i2v=i2v_mode,
487
+ custom=custom,
488
+ avatar=avatar,
489
+ enable_cfg = enable_cfg,
490
+ vae=vae,
491
+ vae_kwargs=vae_kwargs,
492
+ text_encoder=text_encoder,
493
+ text_encoder_2=text_encoder_2,
494
+ model=model,
495
+ feature_extractor=feature_extractor,
496
+ wav2vec=wav2vec,
497
+ align_instance=align_instance,
498
+ device=device,
499
+ )
500
+
501
+
502
+
503
+ class HunyuanVideoSampler(Inference):
504
+ def __init__(
505
+ self,
506
+ i2v,
507
+ custom,
508
+ avatar,
509
+ enable_cfg,
510
+ vae,
511
+ vae_kwargs,
512
+ text_encoder,
513
+ model,
514
+ text_encoder_2=None,
515
+ pipeline=None,
516
+ feature_extractor=None,
517
+ wav2vec=None,
518
+ align_instance=None,
519
+ device=0,
520
+ ):
521
+ super().__init__(
522
+ i2v,
523
+ custom,
524
+ avatar,
525
+ enable_cfg,
526
+ vae,
527
+ vae_kwargs,
528
+ text_encoder,
529
+ model,
530
+ text_encoder_2=text_encoder_2,
531
+ pipeline=pipeline,
532
+ feature_extractor=feature_extractor,
533
+ wav2vec=wav2vec,
534
+ align_instance=align_instance,
535
+ device=device,
536
+ )
537
+
538
+ self.i2v_mode = i2v
539
+ self.enable_cfg = enable_cfg
540
+ self.pipeline = self.load_diffusion_pipeline(
541
+ avatar = self.avatar,
542
+ vae=self.vae,
543
+ text_encoder=self.text_encoder,
544
+ text_encoder_2=self.text_encoder_2,
545
+ model=self.model,
546
+ device=self.device,
547
+ )
548
+
549
+ if self.i2v_mode:
550
+ self.default_negative_prompt = NEGATIVE_PROMPT_I2V
551
+ else:
552
+ self.default_negative_prompt = NEGATIVE_PROMPT
553
+
554
+ @property
555
+ def _interrupt(self):
556
+ return self.pipeline._interrupt
557
+
558
+ @_interrupt.setter
559
+ def _interrupt(self, value):
560
+ self.pipeline._interrupt =value
561
+
562
+ def load_diffusion_pipeline(
563
+ self,
564
+ avatar,
565
+ vae,
566
+ text_encoder,
567
+ text_encoder_2,
568
+ model,
569
+ scheduler=None,
570
+ device=None,
571
+ progress_bar_config=None,
572
+ #data_type="video",
573
+ ):
574
+ """Build the diffusion pipeline (with a default FlowMatchDiscreteScheduler) for inference."""
575
+ if scheduler is None:
576
+ scheduler = FlowMatchDiscreteScheduler(
577
+ shift=6.0,
578
+ reverse=True,
579
+ solver="euler",
580
+ )
581
+
582
+ if avatar:
583
+ pipeline = HunyuanVideoAudioPipeline(
584
+ vae=vae,
585
+ text_encoder=text_encoder,
586
+ text_encoder_2=text_encoder_2,
587
+ transformer=model,
588
+ scheduler=scheduler,
589
+ progress_bar_config=progress_bar_config,
590
+ )
591
+ else:
592
+ pipeline = HunyuanVideoPipeline(
593
+ vae=vae,
594
+ text_encoder=text_encoder,
595
+ text_encoder_2=text_encoder_2,
596
+ transformer=model,
597
+ scheduler=scheduler,
598
+ progress_bar_config=progress_bar_config,
599
+ )
600
+
601
+ return pipeline
602
+
603
+ def get_rotary_pos_embed_new(self, video_length, height, width, concat_dict={}):
604
+ target_ndim = 3
605
+ ndim = 5 - 2
606
+ latents_size = [(video_length-1)//4+1 , height//8, width//8]
607
+
608
+ if isinstance(self.model.patch_size, int):
609
+ assert all(s % self.model.patch_size == 0 for s in latents_size), \
610
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
611
+ f"but got {latents_size}."
612
+ rope_sizes = [s // self.model.patch_size for s in latents_size]
613
+ elif isinstance(self.model.patch_size, list):
614
+ assert all(s % self.model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), \
615
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), " \
616
+ f"but got {latents_size}."
617
+ rope_sizes = [s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)]
618
+
619
+ if len(rope_sizes) != target_ndim:
620
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
621
+ head_dim = self.model.hidden_size // self.model.heads_num
622
+ rope_dim_list = self.model.rope_dim_list
623
+ if rope_dim_list is None:
624
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
625
+ assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal the head_dim of the attention layer"
626
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed_new(rope_dim_list,
627
+ rope_sizes,
628
+ theta=256,
629
+ use_real=True,
630
+ theta_rescale_factor=1,
631
+ concat_dict=concat_dict)
632
+ return freqs_cos, freqs_sin
633
+
634
+ def get_rotary_pos_embed(self, video_length, height, width, enable_riflex = False):
635
+ target_ndim = 3
636
+ ndim = 5 - 2
637
+ # 884
638
+ vae = "884-16c-hy"
639
+ if "884" in vae:
640
+ latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
641
+ elif "888" in vae:
642
+ latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
643
+ else:
644
+ latents_size = [video_length, height // 8, width // 8]
645
+
646
+ if isinstance(self.model.patch_size, int):
647
+ assert all(s % self.model.patch_size == 0 for s in latents_size), (
648
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
649
+ f"but got {latents_size}."
650
+ )
651
+ rope_sizes = [s // self.model.patch_size for s in latents_size]
652
+ elif isinstance(self.model.patch_size, list):
653
+ assert all(
654
+ s % self.model.patch_size[idx] == 0
655
+ for idx, s in enumerate(latents_size)
656
+ ), (
657
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.model.patch_size}), "
658
+ f"but got {latents_size}."
659
+ )
660
+ rope_sizes = [
661
+ s // self.model.patch_size[idx] for idx, s in enumerate(latents_size)
662
+ ]
663
+
664
+ if len(rope_sizes) != target_ndim:
665
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
666
+ head_dim = self.model.hidden_size // self.model.heads_num
667
+ rope_dim_list = self.model.rope_dim_list
668
+ if rope_dim_list is None:
669
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
670
+ assert (
671
+ sum(rope_dim_list) == head_dim
672
+ ), "sum(rope_dim_list) should equal the head_dim of the attention layer"
673
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
674
+ rope_dim_list,
675
+ rope_sizes,
676
+ theta=256,
677
+ use_real=True,
678
+ theta_rescale_factor=1,
679
+ L_test = (video_length - 1) // 4 + 1,
680
+ enable_riflex = enable_riflex
681
+ )
682
+ return freqs_cos, freqs_sin
683
+
684
+
685
+ def generate(
686
+ self,
687
+ input_prompt,
688
+ input_ref_images = None,
689
+ audio_guide = None,
690
+ fps = 24,
691
+ height=192,
692
+ width=336,
693
+ frame_num=129,
694
+ seed=None,
695
+ n_prompt=None,
696
+ sampling_steps=50,
697
+ guide_scale=1.0,
698
+ shift=5.0,
699
+ embedded_guidance_scale=6.0,
700
+ batch_size=1,
701
+ num_videos_per_prompt=1,
702
+ i2v_resolution="720p",
703
+ image_start=None,
704
+ enable_riflex = False,
705
+ i2v_condition_type: str = "token_replace",
706
+ i2v_stability=True,
707
+ VAE_tile_size = None,
708
+ joint_pass = False,
709
+ cfg_star_switch = False,
710
+ fit_into_canvas = True,
711
+ **kwargs,
712
+ ):
713
+
714
+ if VAE_tile_size != None:
715
+ self.vae.tile_sample_min_tsize = VAE_tile_size["tile_sample_min_tsize"]
716
+ self.vae.tile_latent_min_tsize = VAE_tile_size["tile_latent_min_tsize"]
717
+ self.vae.tile_sample_min_size = VAE_tile_size["tile_sample_min_size"]
718
+ self.vae.tile_latent_min_size = VAE_tile_size["tile_latent_min_size"]
719
+ self.vae.tile_overlap_factor = VAE_tile_size["tile_overlap_factor"]
720
+ self.vae.enable_tiling()
721
+
722
+ i2v_mode= self.i2v_mode
723
+ if not self.enable_cfg:
724
+ guide_scale=1.0
725
+
726
+ # ========================================================================
727
+ # Arguments: seed
728
+ # ========================================================================
729
+ if isinstance(seed, torch.Tensor):
730
+ seed = seed.tolist()
731
+ if seed is None:
732
+ seeds = [
733
+ random.randint(0, 1_000_000)
734
+ for _ in range(batch_size * num_videos_per_prompt)
735
+ ]
736
+ elif isinstance(seed, int):
737
+ seeds = [
738
+ seed + i
739
+ for _ in range(batch_size)
740
+ for i in range(num_videos_per_prompt)
741
+ ]
742
+ elif isinstance(seed, (list, tuple)):
743
+ if len(seed) == batch_size:
744
+ seeds = [
745
+ int(seed[i]) + j
746
+ for i in range(batch_size)
747
+ for j in range(num_videos_per_prompt)
748
+ ]
749
+ elif len(seed) == batch_size * num_videos_per_prompt:
750
+ seeds = [int(s) for s in seed]
751
+ else:
752
+ raise ValueError(
753
+ f"Length of seed must be equal to number of prompt(batch_size) or "
754
+ f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}."
755
+ )
756
+ else:
757
+ raise ValueError(
758
+ f"Seed must be an integer, a list of integers, or None, got {seed}."
759
+ )
760
+ from wan.utils.utils import seed_everything
761
+ seed_everything(seed)
762
+ generator = [torch.Generator("cuda").manual_seed(seed) for seed in seeds]
763
+ # generator = [torch.Generator(self.device).manual_seed(seed) for seed in seeds]
764
+
765
+ # ========================================================================
766
+ # Arguments: target_width, target_height, target_frame_num
767
+ # ========================================================================
768
+ if width <= 0 or height <= 0 or frame_num <= 0:
769
+ raise ValueError(
770
+ f"`height` and `width` and `frame_num` must be positive integers, got height={height}, width={width}, frame_num={frame_num}"
771
+ )
772
+ if (frame_num - 1) % 4 != 0:
773
+ raise ValueError(
774
+ f"`frame_num-1` must be a multiple of 4, got {frame_num}"
775
+ )
776
+
777
+ target_height = align_to(height, 16)
778
+ target_width = align_to(width, 16)
779
+ target_frame_num = frame_num
780
+
781
+ if input_ref_images != None:
782
+ # ip_cfg_scale = 3.0
783
+ ip_cfg_scale = 0
784
+ denoise_strength = 1
785
+ # guide_scale=7.5
786
+ # shift=13
787
+ name = "person"
788
+ input_ref_images = input_ref_images[0]
789
+
790
+ # ========================================================================
791
+ # Arguments: prompt, new_prompt, negative_prompt
792
+ # ========================================================================
793
+ if not isinstance(input_prompt, str):
794
+ raise TypeError(f"`prompt` must be a string, but got {type(input_prompt)}")
795
+ input_prompt = [input_prompt.strip()]
796
+
797
+ # negative prompt
798
+ if n_prompt is None or n_prompt == "":
799
+ n_prompt = self.default_negative_prompt
800
+ if guide_scale == 1.0:
801
+ n_prompt = ""
802
+ if not isinstance(n_prompt, str):
803
+ raise TypeError(
804
+ f"`negative_prompt` must be a string, but got {type(n_prompt)}"
805
+ )
806
+ n_prompt = [n_prompt.strip()]
807
+
808
+ # ========================================================================
809
+ # Scheduler
810
+ # ========================================================================
811
+ scheduler = FlowMatchDiscreteScheduler(
812
+ shift=shift,
813
+ reverse=True,
814
+ solver="euler"
815
+ )
816
+ self.pipeline.scheduler = scheduler
817
+
818
+ # ---------------------------------
819
+ # Reference condition
820
+ # ---------------------------------
821
+ img_latents = None
822
+ semantic_images = None
823
+ denoise_strength = 0
824
+ ip_cfg_scale = 0
825
+ if i2v_mode:
826
+ if i2v_resolution == "720p":
827
+ bucket_hw_base_size = 960
828
+ elif i2v_resolution == "540p":
829
+ bucket_hw_base_size = 720
830
+ elif i2v_resolution == "360p":
831
+ bucket_hw_base_size = 480
832
+ else:
833
+ raise ValueError(f"i2v_resolution: {i2v_resolution} must be in [360p, 540p, 720p]")
834
+
835
+ # semantic_images = [Image.open(i2v_image_path).convert('RGB')]
836
+ semantic_images = [image_start.convert('RGB')] #
837
+ origin_size = semantic_images[0].size
838
+ h, w = origin_size
839
+ h, w = calculate_new_dimensions(height, width, h, w, fit_into_canvas)
840
+ closest_size = (w, h)
841
+ # crop_size_list = generate_crop_size_list(bucket_hw_base_size, 32)
842
+ # aspect_ratios = np.array([round(float(h)/float(w), 5) for h, w in crop_size_list])
843
+ # closest_size, closest_ratio = get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
844
+ ref_image_transform = transforms.Compose([
845
+ transforms.Resize(closest_size),
846
+ transforms.CenterCrop(closest_size),
847
+ transforms.ToTensor(),
848
+ transforms.Normalize([0.5], [0.5])
849
+ ])
850
+
851
+ semantic_image_pixel_values = [ref_image_transform(semantic_image) for semantic_image in semantic_images]
852
+ semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(self.device)
853
+
854
+ with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=True):
855
+ img_latents = self.pipeline.vae.encode(semantic_image_pixel_values).latent_dist.mode() # B, C, F, H, W
856
+ img_latents.mul_(self.pipeline.vae.config.scaling_factor)
857
+
858
+ target_height, target_width = closest_size
859
+
860
+ # ========================================================================
861
+ # Build Rope freqs
862
+ # ========================================================================
863
+
864
+ if input_ref_images == None:
865
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed(target_frame_num, target_height, target_width, enable_riflex)
866
+ else:
867
+ if self.avatar:
868
+ w, h = input_ref_images.size
869
+ target_height, target_width = calculate_new_dimensions(target_height, target_width, h, w, fit_into_canvas)
870
+ concat_dict = {'mode': 'timecat', 'bias': -1}
871
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(129, target_height, target_width, concat_dict)
872
+ else:
873
+ concat_dict = {'mode': 'timecat-w', 'bias': -1}
874
+ freqs_cos, freqs_sin = self.get_rotary_pos_embed_new(target_frame_num, target_height, target_width, concat_dict)
875
+
876
+ n_tokens = freqs_cos.shape[0]
877
+
878
+ callback = kwargs.pop("callback", None)
879
+ callback_steps = kwargs.pop("callback_steps", None)
880
+ # ========================================================================
881
+ # Pipeline inference
882
+ # ========================================================================
883
+
884
+ pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = None, None, None
885
+ if input_ref_images == None:
886
+ name = None
887
+ else:
888
+ pixel_value_llava, uncond_pixel_value_llava, pixel_value_ref = DataPreprocess().get_batch(input_ref_images, (target_width, target_height), pad = self.custom)
889
+
890
+ ref_latents, uncond_audio_prompts, audio_prompts, face_masks, motion_exp, motion_pose = None, None, None, None, None, None
891
+
892
+ if audio_guide != None:
893
+ if n_prompt == None or len(n_prompt) == 0:
894
+ n_prompt = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion, blurring, Lens changes"
895
+
896
+ uncond_pixel_value_llava = pixel_value_llava.clone()
897
+
898
+ pixel_value_ref = pixel_value_ref.unsqueeze(0)
899
+ self.align_instance.facedet.model.to("cuda")
900
+ face_masks = get_facemask(pixel_value_ref.to("cuda")*255, self.align_instance, area=3.0)
901
+ # iii = (face_masks.squeeze(0).squeeze(0).permute(1,2,0).repeat(1,1,3)*255).cpu().numpy().astype(np.uint8)
902
+ # image = Image.fromarray(iii)
903
+ # image.save("mask.png")
904
+ # jjj = (pixel_value_ref.squeeze(0).squeeze(0).permute(1,2,0)*255).cpu().numpy().astype(np.uint8)
905
+
906
+ self.align_instance.facedet.model.to("cpu")
907
+ # pixel_value_ref = pixel_value_ref.clone().repeat(1,129,1,1,1)
908
+
909
+ pixel_value_ref = pixel_value_ref.repeat(1,1+4*2,1,1,1)
910
+ pixel_value_ref = pixel_value_ref * 2 - 1
911
+ pixel_value_ref_for_vae = rearrange(pixel_value_ref, "b f c h w -> b c f h w")
912
+
913
+ vae_dtype = self.vae.dtype
914
+ with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_dtype != torch.float32):
915
+ ref_latents = self.vae.encode(pixel_value_ref_for_vae).latent_dist.sample()
916
+ ref_latents = torch.cat( [ref_latents[:,:, :1], ref_latents[:,:, 1:2].repeat(1,1,31,1,1), ref_latents[:,:, -1:]], dim=2)
917
+ pixel_value_ref, pixel_value_ref_for_vae = None, None
918
+
919
+ if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor:
920
+ ref_latents.sub_(self.vae.config.shift_factor).mul_(self.vae.config.scaling_factor)
921
+ else:
922
+ ref_latents.mul_(self.vae.config.scaling_factor)
923
+
924
+ # out_latents= ref_latents / self.vae.config.scaling_factor
925
+ # image = self.vae.decode(out_latents, return_dict=False, generator=generator)[0]
926
+ # image = image.clamp(-1, 1)
927
+ # from wan.utils.utils import cache_video
928
+ # cache_video( tensor=image, save_file="decode.mp4", fps=25, nrow=1, normalize=True, value_range=(-1, 1))
929
+
930
+
931
+ face_masks = torch.nn.functional.interpolate(face_masks.float().squeeze(2),
932
+ (ref_latents.shape[-2],
933
+ ref_latents.shape[-1]),
934
+ mode="bilinear").unsqueeze(2).to(dtype=ref_latents.dtype)
935
+
936
+ audio_input, audio_len = get_audio_feature(self.feature_extractor, audio_guide, duration = frame_num/fps )
937
+ audio_prompts = audio_input[0]
938
+ weight_dtype = audio_prompts.dtype
939
+
940
+ motion_pose = np.array([25] * 4)
941
+ motion_exp = np.array([30] * 4)
942
+ motion_pose = torch.from_numpy(motion_pose).unsqueeze(0)
943
+ motion_exp = torch.from_numpy(motion_exp).unsqueeze(0)
944
+ audio_prompts = encode_audio(self.wav2vec, audio_prompts.to(dtype=self.wav2vec.dtype), fps, num_frames=audio_len)
945
+ audio_prompts = audio_prompts.to(self.model.dtype)
946
+ if audio_prompts.shape[1] <= 129:
947
+ audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1,129-audio_prompts.shape[1], 1, 1, 1)], dim=1)
948
+ else:
949
+ audio_prompts = torch.cat([audio_prompts, torch.zeros_like(audio_prompts[:, :1]).repeat(1, 5, 1, 1, 1)], dim=1)
950
+ uncond_audio_prompts = torch.zeros_like(audio_prompts[:,:129])
951
+ # target_frame_num = min(target_frame_num, audio_len)
952
+ samples = self.pipeline(
953
+ prompt=input_prompt,
954
+ height=target_height,
955
+ width=target_width,
956
+ video_length=target_frame_num,
957
+ num_inference_steps=sampling_steps,
958
+ guidance_scale=guide_scale,
959
+ negative_prompt=n_prompt,
960
+ num_videos_per_prompt=num_videos_per_prompt,
961
+ generator=generator,
962
+ output_type="pil",
963
+ name = name,
964
+
965
+ pixel_value_ref = pixel_value_ref,
966
+ ref_latents=ref_latents, # [1, 16, 1, h//8, w//8]
967
+ pixel_value_llava=pixel_value_llava, # [1, 3, 336, 336]
968
+ uncond_pixel_value_llava=uncond_pixel_value_llava,
969
+ face_masks=face_masks, # [b f h w]
970
+ audio_prompts=audio_prompts,
971
+ uncond_audio_prompts=uncond_audio_prompts,
972
+ motion_exp=motion_exp,
973
+ motion_pose=motion_pose,
974
+ fps= torch.from_numpy(np.array(fps)),
975
+
976
+ denoise_strength=denoise_strength,
977
+ ip_cfg_scale=ip_cfg_scale,
978
+ freqs_cis=(freqs_cos, freqs_sin),
979
+ n_tokens=n_tokens,
980
+ embedded_guidance_scale=embedded_guidance_scale,
981
+ data_type="video" if target_frame_num > 1 else "image",
982
+ is_progress_bar=True,
983
+ vae_ver="884-16c-hy",
984
+ enable_tiling=True,
985
+ i2v_mode=i2v_mode,
986
+ i2v_condition_type=i2v_condition_type,
987
+ i2v_stability=i2v_stability,
988
+ img_latents=img_latents,
989
+ semantic_images=semantic_images,
990
+ joint_pass = joint_pass,
991
+ cfg_star_rescale = cfg_star_switch,
992
+ callback = callback,
993
+ callback_steps = callback_steps,
994
+ )[0]
995
+
996
+ if samples == None:
997
+ return None
998
+ samples = samples.squeeze(0)
999
+
1000
+ return samples
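For orientation, here is a minimal sketch of the latent-grid arithmetic that get_rotary_pos_embed performs above for the "884-16c-hy" VAE: the temporal axis is compressed 4x (plus one leading frame) and both spatial axes 8x, after which each axis is divided by the transformer patch size. The patch size of [1, 2, 2] below is an illustrative assumption; the real value is read from self.model.patch_size.

# Illustrative sketch only; mirrors the "884" branch of get_rotary_pos_embed.
video_length, height, width = 129, 720, 1280                            # a typical 720p request
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]   # -> [33, 90, 160]

patch_size = [1, 2, 2]                                                   # assumed for illustration
assert all(s % p == 0 for s, p in zip(latents_size, patch_size))
rope_sizes = [s // p for s, p in zip(latents_size, patch_size)]          # -> [33, 45, 80]
print(latents_size, rope_sizes)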
hyvideo/modules/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from .models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG
2
+
3
+
4
+ def load_model(model, i2v_condition_type, in_channels, out_channels, factor_kwargs):
5
+ """load hunyuan video model
6
+
7
+ Args:
8
+ model (str): model name; must be a key of HUNYUAN_VIDEO_CONFIG
9
+ in_channels (int): number of input channels
10
+ out_channels (int): number of output channels
11
+ factor_kwargs (dict): extra kwargs (e.g. device, dtype) forwarded to the transformer
12
+
13
+ Returns:
14
+ model (nn.Module): The hunyuan video model
15
+ """
16
+ if model in HUNYUAN_VIDEO_CONFIG.keys():
17
+ model = HYVideoDiffusionTransformer(
18
+ i2v_condition_type = i2v_condition_type,
19
+ in_channels=in_channels,
20
+ out_channels=out_channels,
21
+ **HUNYUAN_VIDEO_CONFIG[model],
22
+ **factor_kwargs,
23
+ )
24
+ return model
25
+ else:
26
+ raise NotImplementedError()
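As a hedged usage sketch, this is how the from_pretrained path in hyvideo/hunyuan.py drives the factory above for the text-to-video checkpoint: the name must be a key of HUNYUAN_VIDEO_CONFIG, and the remaining kwargs are forwarded to HYVideoDiffusionTransformer. The concrete values are copied from that code path for illustration, not a recommendation.

# Usage sketch (assumes the WanGP repo is on PYTHONPATH).
import torch
from hyvideo.modules import load_model

model = load_model(
    model="HYVideo-T/2-cfgdistill",   # must be a key of HUNYUAN_VIDEO_CONFIG
    i2v_condition_type=None,
    in_channels=16,                   # latent channels
    out_channels=16,
    factor_kwargs={"device": "meta", "dtype": torch.bfloat16},
)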
hyvideo/modules/activation_layers.py ADDED
@@ -0,0 +1,23 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def get_activation_layer(act_type):
5
+ """get activation layer
6
+
7
+ Args:
8
+ act_type (str): the activation type
9
+
10
+ Returns:
11
+ Callable[[], nn.Module]: a zero-argument factory that builds the activation module
12
+ """
13
+ if act_type == "gelu":
14
+ return lambda: nn.GELU()
15
+ elif act_type == "gelu_tanh":
16
+ # Approximate `tanh` requires torch >= 1.13
17
+ return lambda: nn.GELU(approximate="tanh")
18
+ elif act_type == "relu":
19
+ return nn.ReLU
20
+ elif act_type == "silu":
21
+ return nn.SiLU
22
+ else:
23
+ raise ValueError(f"Unknown activation type: {act_type}")
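Note that get_activation_layer returns a factory (a class or a lambda), not a module instance, so callers instantiate it with a trailing (). A minimal usage sketch, assuming the module is importable as hyvideo.modules.activation_layers:

import torch
from hyvideo.modules.activation_layers import get_activation_layer

act_layer = get_activation_layer("gelu_tanh")   # returns a factory, not an instance
act = act_layer()                               # builds nn.GELU(approximate="tanh")
y = act(torch.randn(2, 8))
print(type(act).__name__, y.shape)              # GELU torch.Size([2, 8])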
hyvideo/modules/attenion.py ADDED
@@ -0,0 +1,362 @@
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from importlib.metadata import version
8
+
9
+ def clear_list(l):
10
+ for i in range(len(l)):
11
+ l[i] = None
12
+
13
+ try:
14
+ import flash_attn
15
+ from flash_attn.flash_attn_interface import _flash_attn_forward
16
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
17
+ except ImportError:
18
+ flash_attn = None
19
+ flash_attn_varlen_func = None
20
+ _flash_attn_forward = None
21
+
22
+ try:
23
+ from xformers.ops import memory_efficient_attention
24
+ except ImportError:
25
+ memory_efficient_attention = None
26
+
27
+ try:
28
+ from sageattention import sageattn_varlen
29
+ def sageattn_varlen_wrapper(
30
+ q,
31
+ k,
32
+ v,
33
+ cu_seqlens_q,
34
+ cu_seqlens_kv,
35
+ max_seqlen_q,
36
+ max_seqlen_kv,
37
+ ):
38
+ return sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
39
+ except ImportError:
40
+ sageattn_varlen_wrapper = None
41
+
42
+ try:
43
+ from sageattention import sageattn
44
+ @torch.compiler.disable()
45
+ def sageattn_wrapper(
46
+ qkv_list,
47
+ attention_length
48
+ ):
49
+ q,k, v = qkv_list
50
+ padding_length = q.shape[1] -attention_length
51
+ q = q[:, :attention_length, :, : ]
52
+ k = k[:, :attention_length, :, : ]
53
+ v = v[:, :attention_length, :, : ]
54
+
55
+ o = sageattn(q, k, v, tensor_layout="NHD")
56
+ del q, k ,v
57
+ clear_list(qkv_list)
58
+
59
+ if padding_length > 0:
60
+ o = torch.cat([o, torch.empty( (o.shape[0], padding_length, *o.shape[-2:]), dtype= o.dtype, device=o.device ) ], 1)
61
+
62
+ return o
63
+
64
+ except ImportError:
65
+ sageattn = None
66
+
67
+
68
+ def get_attention_modes():
69
+ ret = ["sdpa", "auto"]
70
+ if flash_attn != None:
71
+ ret.append("flash")
72
+ if memory_efficient_attention != None:
73
+ ret.append("xformers")
74
+ if sageattn_varlen_wrapper != None:
75
+ ret.append("sage")
76
+ if sageattn != None and version("sageattention").startswith("2") :
77
+ ret.append("sage2")
78
+
79
+ return ret
80
+
81
+
82
+
83
+ MEMORY_LAYOUT = {
84
+ "sdpa": (
85
+ lambda x: x.transpose(1, 2),
86
+ lambda x: x.transpose(1, 2),
87
+ ),
88
+ "xformers": (
89
+ lambda x: x,
90
+ lambda x: x,
91
+ ),
92
+ "sage2": (
93
+ lambda x: x,
94
+ lambda x: x,
95
+ ),
96
+ "sage": (
97
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
98
+ lambda x: x,
99
+ ),
100
+ "flash": (
101
+ lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
102
+ lambda x: x,
103
+ ),
104
+ "torch": (
105
+ lambda x: x.transpose(1, 2),
106
+ lambda x: x.transpose(1, 2),
107
+ ),
108
+ "vanilla": (
109
+ lambda x: x.transpose(1, 2),
110
+ lambda x: x.transpose(1, 2),
111
+ ),
112
+ }
113
+
114
+ @torch.compiler.disable()
115
+ def sdpa_wrapper(
116
+ qkv_list,
117
+ attention_length
118
+ ):
119
+ q,k, v = qkv_list
120
+ padding_length = q.shape[2] -attention_length
121
+ q = q[:, :, :attention_length, :]
122
+ k = k[:, :, :attention_length, :]
123
+ v = v[:, :, :attention_length, :]
124
+
125
+ o = F.scaled_dot_product_attention(
126
+ q, k, v, attn_mask=None, is_causal=False
127
+ )
128
+ del q, k ,v
129
+ clear_list(qkv_list)
130
+
131
+ if padding_length > 0:
132
+ o = torch.cat([o, torch.empty( (*o.shape[:2], padding_length, o.shape[-1]), dtype= o.dtype, device=o.device ) ], 2)
133
+
134
+ return o
135
+
136
+ def get_cu_seqlens(text_mask, img_len):
137
+ """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len
138
+
139
+ Args:
140
+ text_mask (torch.Tensor): the mask of text
141
+ img_len (int): the length of image
142
+
143
+ Returns:
144
+ torch.Tensor: the calculated cu_seqlens for flash attention
145
+ """
146
+ batch_size = text_mask.shape[0]
147
+ text_len = text_mask.sum(dim=1)
148
+ max_len = text_mask.shape[1] + img_len
149
+
150
+ cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device="cuda")
151
+
152
+ for i in range(batch_size):
153
+ s = text_len[i] + img_len
154
+ s1 = i * max_len + s
155
+ s2 = (i + 1) * max_len
156
+ cu_seqlens[2 * i + 1] = s1
157
+ cu_seqlens[2 * i + 2] = s2
158
+
159
+ return cu_seqlens
160
+
161
+
162
+ def attention(
163
+ qkv_list,
164
+ mode="flash",
165
+ drop_rate=0,
166
+ attn_mask=None,
167
+ causal=False,
168
+ cu_seqlens_q=None,
169
+ cu_seqlens_kv=None,
170
+ max_seqlen_q=None,
171
+ max_seqlen_kv=None,
172
+ batch_size=1,
173
+ ):
174
+ """
175
+ Perform QKV self attention.
176
+
177
+ Args:
178
+ qkv_list (list[torch.Tensor]): [q, k, v]; the list is cleared in place to release the tensors early.
179
+ q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads;
180
+ k and v (torch.Tensor): key/value tensors with shape [b, s1, a, d].
181
+ mode (str): Attention mode. One of 'flash', 'sdpa', 'xformers', 'sage', 'sage2', 'torch', or 'vanilla'.
182
+ drop_rate (float): Dropout rate in attention map. (default: 0)
183
+ attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
184
+ (default: None)
185
+ causal (bool): Whether to use causal attention. (default: False)
186
+ cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
187
+ used to index into q.
188
+ cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
189
+ used to index into kv.
190
+ max_seqlen_q (int): The maximum sequence length in the batch of q.
191
+ max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
192
+
193
+ Returns:
194
+ torch.Tensor: Output tensor after self attention with shape [b, s, ad]
195
+ """
196
+ pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
197
+ q , k , v = qkv_list
198
+ clear_list(qkv_list)
199
+ del qkv_list
200
+ padding_length = 0
201
+ # if attn_mask == None and mode == "sdpa":
202
+ # padding_length = q.shape[1] - cu_seqlens_q
203
+ # q = q[:, :cu_seqlens_q, ... ]
204
+ # k = k[:, :cu_seqlens_kv, ... ]
205
+ # v = v[:, :cu_seqlens_kv, ... ]
206
+
207
+ q = pre_attn_layout(q)
208
+ k = pre_attn_layout(k)
209
+ v = pre_attn_layout(v)
210
+
211
+ if mode == "torch":
212
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
213
+ attn_mask = attn_mask.to(q.dtype)
214
+ x = F.scaled_dot_product_attention(
215
+ q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
216
+ )
217
+
218
+ elif mode == "sdpa":
219
+ # if attn_mask is not None and attn_mask.dtype != torch.bool:
220
+ # attn_mask = attn_mask.to(q.dtype)
221
+ # x = F.scaled_dot_product_attention(
222
+ # q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
223
+ # )
224
+ assert attn_mask==None
225
+ qkv_list = [q, k, v]
226
+ del q, k , v
227
+ x = sdpa_wrapper( qkv_list, cu_seqlens_q )
228
+
229
+ elif mode == "xformers":
230
+ x = memory_efficient_attention(
231
+ q, k, v , attn_bias= attn_mask
232
+ )
233
+
234
+ elif mode == "sage2":
235
+ qkv_list = [q, k, v]
236
+ del q, k , v
237
+ x = sageattn_wrapper(qkv_list, cu_seqlens_q)
238
+
239
+ elif mode == "sage":
240
+ x = sageattn_varlen_wrapper(
241
+ q,
242
+ k,
243
+ v,
244
+ cu_seqlens_q,
245
+ cu_seqlens_kv,
246
+ max_seqlen_q,
247
+ max_seqlen_kv,
248
+ )
249
+ # x with shape [(bxs), a, d]
250
+ x = x.view(
251
+ batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
252
+ ) # reshape x to [b, s, a, d]
253
+
254
+ elif mode == "flash":
255
+ x = flash_attn_varlen_func(
256
+ q,
257
+ k,
258
+ v,
259
+ cu_seqlens_q,
260
+ cu_seqlens_kv,
261
+ max_seqlen_q,
262
+ max_seqlen_kv,
263
+ )
264
+ # x with shape [(bxs), a, d]
265
+ x = x.view(
266
+ batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]
267
+ ) # reshape x to [b, s, a, d]
268
+ elif mode == "vanilla":
269
+ scale_factor = 1 / math.sqrt(q.size(-1))
270
+
271
+ b, a, s, _ = q.shape
272
+ s1 = k.size(2)
273
+ attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
274
+ if causal:
275
+ # Only applied to self attention
276
+ assert (
277
+ attn_mask is None
278
+ ), "Causal mask and attn_mask cannot be used together"
279
+ temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
280
+ diagonal=0
281
+ )
282
+ attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
283
+ attn_bias.to(q.dtype)
284
+
285
+ if attn_mask is not None:
286
+ if attn_mask.dtype == torch.bool:
287
+ attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
288
+ else:
289
+ attn_bias += attn_mask
290
+
291
+ # TODO: Maybe force q and k to be float32 to avoid numerical overflow
292
+ attn = (q @ k.transpose(-2, -1)) * scale_factor
293
+ attn += attn_bias
294
+ attn = attn.softmax(dim=-1)
295
+ attn = torch.dropout(attn, p=drop_rate, train=True)
296
+ x = attn @ v
297
+ else:
298
+ raise NotImplementedError(f"Unsupported attention mode: {mode}")
299
+
300
+ x = post_attn_layout(x)
301
+ b, s, a, d = x.shape
302
+ out = x.reshape(b, s, -1)
303
+ if padding_length > 0 :
304
+ out = torch.cat([out, torch.empty( (out.shape[0], padding_length, out.shape[2]), dtype= out.dtype, device=out.device ) ], 1)
305
+
306
+ return out
307
+
308
+
309
+ def parallel_attention(
310
+ hybrid_seq_parallel_attn,
311
+ q,
312
+ k,
313
+ v,
314
+ img_q_len,
315
+ img_kv_len,
316
+ cu_seqlens_q,
317
+ cu_seqlens_kv
318
+ ):
319
+ attn1 = hybrid_seq_parallel_attn(
320
+ None,
321
+ q[:, :img_q_len, :, :],
322
+ k[:, :img_kv_len, :, :],
323
+ v[:, :img_kv_len, :, :],
324
+ dropout_p=0.0,
325
+ causal=False,
326
+ joint_tensor_query=q[:,img_q_len:cu_seqlens_q[1]],
327
+ joint_tensor_key=k[:,img_kv_len:cu_seqlens_kv[1]],
328
+ joint_tensor_value=v[:,img_kv_len:cu_seqlens_kv[1]],
329
+ joint_strategy="rear",
330
+ )
331
+ if flash_attn.__version__ >= '2.7.0':
332
+ attn2, *_ = _flash_attn_forward(
333
+ q[:,cu_seqlens_q[1]:],
334
+ k[:,cu_seqlens_kv[1]:],
335
+ v[:,cu_seqlens_kv[1]:],
336
+ dropout_p=0.0,
337
+ softmax_scale=q.shape[-1] ** (-0.5),
338
+ causal=False,
339
+ window_size_left=-1,
340
+ window_size_right=-1,
341
+ softcap=0.0,
342
+ alibi_slopes=None,
343
+ return_softmax=False,
344
+ )
345
+ else:
346
+ attn2, *_ = _flash_attn_forward(
347
+ q[:,cu_seqlens_q[1]:],
348
+ k[:,cu_seqlens_kv[1]:],
349
+ v[:,cu_seqlens_kv[1]:],
350
+ dropout_p=0.0,
351
+ softmax_scale=q.shape[-1] ** (-0.5),
352
+ causal=False,
353
+ window_size=(-1, -1),
354
+ softcap=0.0,
355
+ alibi_slopes=None,
356
+ return_softmax=False,
357
+ )
358
+ attn = torch.cat([attn1, attn2], dim=1)
359
+ b, s, a, d = attn.shape
360
+ attn = attn.reshape(b, s, -1)
361
+
362
+ return attn
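To make the cu_seqlens layout above concrete, here is a small worked re-derivation of get_cu_seqlens (built on CPU, since the original hard-codes device="cuda"): for each sample, one entry marks the end of its valid text+image tokens and the next marks the end of its padded slot.

# Worked example of the layout produced by get_cu_seqlens (CPU re-derivation, illustration only).
import torch

text_mask = torch.tensor([[1, 1, 1, 0, 0],      # sample 0: 3 valid text tokens
                          [1, 1, 1, 1, 1]])     # sample 1: 5 valid text tokens
img_len = 10                                    # image/video tokens per sample
batch_size, text_pad_len = text_mask.shape
max_len = text_pad_len + img_len                # 15 tokens per padded sample

cu_seqlens = torch.zeros(2 * batch_size + 1, dtype=torch.int32)
for i in range(batch_size):
    s = int(text_mask[i].sum()) + img_len       # valid tokens of sample i
    cu_seqlens[2 * i + 1] = i * max_len + s     # end of the valid segment
    cu_seqlens[2 * i + 2] = (i + 1) * max_len   # end of the padded segment
print(cu_seqlens.tolist())                      # [0, 13, 15, 30, 30]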
hyvideo/modules/audio_adapters.py ADDED
@@ -0,0 +1,220 @@
1
+ """
2
+ This module provides the implementation of an Audio Projection Model, which is designed for
3
+ audio processing tasks. The model takes audio embeddings as input and outputs context tokens
4
+ that can be used for various downstream applications, such as audio analysis or synthesis.
5
+
6
+ The AudioProjModel class is based on the ModelMixin class from the diffusers library, which
7
+ provides a foundation for building custom models. This implementation includes multiple linear
8
+ layers with ReLU activation functions and a LayerNorm for normalization.
9
+
10
+ Key Features:
11
+ - Audio embedding input with flexible sequence length and block structure.
12
+ - Multiple linear layers for feature transformation.
13
+ - ReLU activation for non-linear transformation.
14
+ - LayerNorm for stabilizing and speeding up training.
15
+ - Rearrangement of input embeddings to match the model's expected input shape.
16
+ - Customizable number of blocks, channels, and context tokens for adaptability.
17
+
18
+ The module is structured to be easily integrated into larger systems or used as a standalone
19
+ component for audio feature extraction and processing.
20
+
21
+ Classes:
22
+ - AudioProjModel: A class representing the audio projection model with configurable parameters.
23
+
24
+ Functions:
25
+ - (none)
26
+
27
+ Dependencies:
28
+ - torch: For tensor operations and neural network components.
29
+ - diffusers: For the ModelMixin base class.
30
+ - einops: For tensor rearrangement operations.
31
+
32
+ """
33
+
34
+ import torch
35
+ from diffusers import ModelMixin
36
+ from einops import rearrange
37
+
38
+ import math
39
+ import torch.nn as nn
40
+
41
+ class AudioProjNet2(ModelMixin):
42
+ """Audio Projection Model
43
+
44
+ This class defines an audio projection model that takes audio embeddings as input
45
+ and produces context tokens as output. The model is based on the ModelMixin class
46
+ and consists of multiple linear layers and activation functions. It can be used
47
+ for various audio processing tasks.
48
+
49
+ Attributes:
50
+ seq_len (int): The length of the audio sequence.
51
+ blocks (int): The number of blocks in the audio projection model.
52
+ channels (int): The number of channels in the audio projection model.
53
+ intermediate_dim (int): The intermediate dimension of the model.
54
+ context_tokens (int): The number of context tokens in the output.
55
+ output_dim (int): The output dimension of the context tokens.
56
+
57
+ Methods:
58
+ __init__(self, seq_len=5, blocks=12, channels=768, intermediate_dim=512, output_dim=768, context_tokens=4):
59
+ Initializes the AudioProjModel with the given parameters.
60
+ forward(self, audio_embeds):
61
+ Defines the forward pass for the AudioProjModel.
62
+ Parameters:
63
+ audio_embeds (torch.Tensor): The input audio embeddings with shape (batch_size, video_length, window, blocks, channels).
64
+ Returns:
65
+ context_tokens (torch.Tensor): The output context tokens with shape (batch_size, video_length, context_tokens, output_dim).
66
+
67
+ """
68
+
69
+ def __init__(
70
+ self,
71
+ seq_len=5,
72
+ blocks=12, # add a new parameter blocks
73
+ channels=768, # add a new parameter channels
74
+ intermediate_dim=512,
75
+ output_dim=768,
76
+ context_tokens=4,
77
+ ):
78
+ super().__init__()
79
+
80
+ self.seq_len = seq_len
81
+ self.blocks = blocks
82
+ self.channels = channels
83
+ self.input_dim = (
84
+ seq_len * blocks * channels
85
+ )
86
+ self.intermediate_dim = intermediate_dim
87
+ self.context_tokens = context_tokens
88
+ self.output_dim = output_dim
89
+
90
+ # define multiple linear layers
91
+ self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
92
+ self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
93
+ self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
94
+
95
+ self.norm = nn.LayerNorm(output_dim)
96
+
97
+
98
+ def forward(self, audio_embeds):
99
+
100
+ video_length = audio_embeds.shape[1]
101
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
102
+ batch_size, window_size, blocks, channels = audio_embeds.shape
103
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
104
+
105
+ audio_embeds = torch.relu(self.proj1(audio_embeds))
106
+ audio_embeds = torch.relu(self.proj2(audio_embeds))
107
+
108
+ context_tokens = self.proj3(audio_embeds).reshape(
109
+ batch_size, self.context_tokens, self.output_dim
110
+ )
111
+ context_tokens = self.norm(context_tokens)
112
+ out_all = rearrange(
113
+ context_tokens, "(bz f) m c -> bz f m c", f=video_length
114
+ )
115
+
116
+ return out_all
117
+
118
+
119
+ def reshape_tensor(x, heads):
120
+ bs, length, width = x.shape
121
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
122
+ x = x.view(bs, length, heads, -1)
123
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
124
+ x = x.transpose(1, 2)
125
+ # (bs, n_heads, length, dim_per_head)
126
+ x = x.reshape(bs, heads, length, -1)
127
+ return x
128
+
129
+
130
+ class PerceiverAttentionCA(nn.Module):
131
+ def __init__(self, *, dim=3072, dim_head=1024, heads=33):
132
+ super().__init__()
133
+ self.scale = dim_head ** -0.5
134
+ self.dim_head = dim_head
135
+ self.heads = heads
136
+ inner_dim = dim_head #* heads
137
+
138
+ self.norm1 = nn.LayerNorm(dim)
139
+ self.norm2 = nn.LayerNorm(dim)
140
+
141
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
142
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
143
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
144
+
145
+ import torch.nn.init as init
146
+ init.zeros_(self.to_out.weight)
147
+ if self.to_out.bias is not None:
148
+ init.zeros_(self.to_out.bias)
149
+
150
+ def forward(self, x, latents):
151
+ """
152
+ Args:
153
+ x (torch.Tensor): image features
154
+ shape (b, t, aa, D)
155
+ latent (torch.Tensor): latent features
156
+ shape (b, t, hw, D)
157
+ """
158
+ x = self.norm1(x)
159
+ latents = self.norm2(latents)
160
+ # print("latents shape: ", latents.shape)
161
+ # print("x shape: ", x.shape)
162
+ q = self.to_q(latents)
163
+ k, v = self.to_kv(x).chunk(2, dim=-1)
164
+
165
+
166
+ # attention
167
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
168
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
169
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
170
+ out = weight @ v
171
+
172
+ # out = out.permute(0, 2, 1, 3)
173
+ return self.to_out(out)
174
+ #def forward(self, x, latents):
175
+ # """
176
+ # Args:
177
+ # x (torch.Tensor): image features
178
+ # shape (b, t, aa, D)
179
+ # latent (torch.Tensor): latent features
180
+ # shape (b, t, hw, D)
181
+ # """
182
+ # if get_sequence_parallel_state():
183
+ # sp_size = nccl_info.sp_size
184
+ # sp_rank = nccl_info.rank_within_group
185
+ # print("rank:", latents.shape, sp_size, sp_rank)
186
+ # latents = torch.chunk(latents, sp_size, dim=1)[sp_rank]
187
+
188
+ # x = self.norm1(x)
189
+ # latents = self.norm2(latents)
190
+ # # print("latents shape: ", latents.shape)
191
+ # # print("x shape: ", x.shape)
192
+ # q = self.to_q(latents)
193
+ # k, v = self.to_kv(x).chunk(2, dim=-1)
194
+
195
+ # # print("q, k, v: ", q.shape, k.shape, v.shape)
196
+
197
+ # # attention
198
+ # #scale = 1 / math.sqrt(math.sqrt(self.dim_head))
199
+ # #weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
200
+ # #weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
201
+ # #out = weight @ v
202
+ # def shrink_head(encoder_state, dim):
203
+ # local_heads = encoder_state.shape[dim] // nccl_info.sp_size
204
+ # return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
205
+
206
+ # if get_sequence_parallel_state():
207
+ # # batch_size, seq_len, attn_heads, head_dim
208
+ # q = all_to_all_4D(q, scatter_dim=2, gather_dim=1) # [2, 32256, 24, 128]
209
+ # k = shrink_head(k ,dim=2)
210
+ # v = shrink_head(v ,dim=2)
211
+ # qkv = torch.stack([query, key, value], dim=2)
212
+ # attn = flash_attn_no_pad(qkv, causal=False, dropout_p=0.0, softmax_scale=None)
213
+ # # out = out.permute(0, 2, 1, 3)
214
+ # #b, s, a, d = attn.shape
215
+ # #attn = attn.reshape(b, s, -1)
216
+ #
217
+ # out = self.to_out(attn)
218
+ # if get_sequence_parallel_state():
219
+ # out = all_gather(out, dim=1)
220
+ # return out
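A shape-only sketch of AudioProjNet2: per frame, a (window, blocks, channels) slab of audio features is flattened, pushed through the three linear layers, and reshaped into context_tokens tokens of width output_dim. The sizes below are arbitrary illustration values (the module itself needs diffusers and einops installed), and the window dimension must equal seq_len so the flattened width matches proj1's input_dim.

# Shape check, illustrative sizes only.
import torch
from hyvideo.modules.audio_adapters import AudioProjNet2

proj = AudioProjNet2(seq_len=10, blocks=5, channels=384,
                     intermediate_dim=1024, output_dim=3072, context_tokens=4)
audio_embeds = torch.randn(2, 129, 10, 5, 384)  # (batch, frames, window, blocks, channels)
out = proj(audio_embeds)
print(out.shape)                                # torch.Size([2, 129, 4, 3072])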
hyvideo/modules/embed_layers.py ADDED
@@ -0,0 +1,158 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange, repeat
5
+
6
+ from ..utils.helpers import to_2tuple
7
+
8
+
9
+ class PatchEmbed(nn.Module):
10
+ """2D Image to Patch Embedding
11
+
12
+ Image to Patch Embedding using Conv2d
13
+
14
+ A convolution based approach to patchifying a 2D image w/ embedding projection.
15
+
16
+ Based on the impl in https://github.com/google-research/vision_transformer
17
+
18
+ Hacked together by / Copyright 2020 Ross Wightman
19
+
20
+ The _assert call is removed from the forward function for compatibility with multi-resolution images.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ patch_size=16,
26
+ in_chans=3,
27
+ embed_dim=768,
28
+ norm_layer=None,
29
+ flatten=True,
30
+ bias=True,
31
+ dtype=None,
32
+ device=None,
33
+ ):
34
+ factory_kwargs = {"dtype": dtype, "device": device}
35
+ super().__init__()
36
+ patch_size = to_2tuple(patch_size)
37
+ self.patch_size = patch_size
38
+ self.flatten = flatten
39
+
40
+ self.proj = nn.Conv3d(
41
+ in_chans,
42
+ embed_dim,
43
+ kernel_size=patch_size,
44
+ stride=patch_size,
45
+ bias=bias,
46
+ **factory_kwargs
47
+ )
48
+ nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
49
+ if bias:
50
+ nn.init.zeros_(self.proj.bias)
51
+
52
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
53
+
54
+ def forward(self, x):
55
+ x = self.proj(x)
56
+ shape = x.shape
57
+ if self.flatten:
58
+ x = x.flatten(2).transpose(1, 2) # BCTHW -> BNC
59
+ x = self.norm(x)
60
+ return x, shape
61
+
62
+
63
+ class TextProjection(nn.Module):
64
+ """
65
+ Projects text embeddings. Also handles dropout for classifier-free guidance.
66
+
67
+ Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
68
+ """
69
+
70
+ def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
71
+ factory_kwargs = {"dtype": dtype, "device": device}
72
+ super().__init__()
73
+ self.linear_1 = nn.Linear(
74
+ in_features=in_channels,
75
+ out_features=hidden_size,
76
+ bias=True,
77
+ **factory_kwargs
78
+ )
79
+ self.act_1 = act_layer()
80
+ self.linear_2 = nn.Linear(
81
+ in_features=hidden_size,
82
+ out_features=hidden_size,
83
+ bias=True,
84
+ **factory_kwargs
85
+ )
86
+
87
+ def forward(self, caption):
88
+ hidden_states = self.linear_1(caption)
89
+ hidden_states = self.act_1(hidden_states)
90
+ hidden_states = self.linear_2(hidden_states)
91
+ return hidden_states
92
+
93
+
94
+ def timestep_embedding(t, dim, max_period=10000):
95
+ """
96
+ Create sinusoidal timestep embeddings.
97
+
98
+ Args:
99
+ t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
100
+ dim (int): the dimension of the output.
101
+ max_period (int): controls the minimum frequency of the embeddings.
102
+
103
+ Returns:
104
+ embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
105
+
106
+ .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
107
+ """
108
+ half = dim // 2
109
+ freqs = torch.exp(
110
+ -math.log(max_period)
111
+ * torch.arange(start=0, end=half, dtype=torch.float32)
112
+ / half
113
+ ).to(device=t.device)
114
+ args = t[:, None].float() * freqs[None]
115
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
116
+ if dim % 2:
117
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
118
+ return embedding
119
+
120
+
121
+ class TimestepEmbedder(nn.Module):
122
+ """
123
+ Embeds scalar timesteps into vector representations.
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ hidden_size,
129
+ act_layer,
130
+ frequency_embedding_size=256,
131
+ max_period=10000,
132
+ out_size=None,
133
+ dtype=None,
134
+ device=None,
135
+ ):
136
+ factory_kwargs = {"dtype": dtype, "device": device}
137
+ super().__init__()
138
+ self.frequency_embedding_size = frequency_embedding_size
139
+ self.max_period = max_period
140
+ if out_size is None:
141
+ out_size = hidden_size
142
+
143
+ self.mlp = nn.Sequential(
144
+ nn.Linear(
145
+ frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
146
+ ),
147
+ act_layer(),
148
+ nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
149
+ )
150
+ nn.init.normal_(self.mlp[0].weight, std=0.02)
151
+ nn.init.normal_(self.mlp[2].weight, std=0.02)
152
+
153
+ def forward(self, t):
154
+ t_freq = timestep_embedding(
155
+ t, self.frequency_embedding_size, self.max_period
156
+ ).type(self.mlp[0].weight.dtype)
157
+ t_emb = self.mlp(t_freq)
158
+ return t_emb
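Not part of the diff: a minimal usage sketch for the embedders above. It assumes nn.SiLU can stand in for the get_activation_layer("silu") factory used elsewhere in this codebase; the 3072 hidden size matches the transformer config in models.py.

import torch
import torch.nn as nn
from hyvideo.modules.embed_layers import TimestepEmbedder, timestep_embedding

t = torch.tensor([0.0, 500.0, 999.0])        # scalar timesteps, one per batch element
sin_emb = timestep_embedding(t, dim=256)     # sinusoidal features, shape (3, 256)
assert sin_emb.shape == (3, 256)

embedder = TimestepEmbedder(hidden_size=3072, act_layer=nn.SiLU)
vec = embedder(t)                            # MLP-projected modulation vector, shape (3, 3072)
assert vec.shape == (3, 3072)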
hyvideo/modules/mlp_layers.py ADDED
@@ -0,0 +1,131 @@
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate_
10
+ from ..utils.helpers import to_2tuple
11
+
12
+
13
+ class MLP(nn.Module):
14
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
15
+
16
+ def __init__(
17
+ self,
18
+ in_channels,
19
+ hidden_channels=None,
20
+ out_features=None,
21
+ act_layer=nn.GELU,
22
+ norm_layer=None,
23
+ bias=True,
24
+ drop=0.0,
25
+ use_conv=False,
26
+ device=None,
27
+ dtype=None,
28
+ ):
29
+ factory_kwargs = {"device": device, "dtype": dtype}
30
+ super().__init__()
31
+ out_features = out_features or in_channels
32
+ hidden_channels = hidden_channels or in_channels
33
+ bias = to_2tuple(bias)
34
+ drop_probs = to_2tuple(drop)
35
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
36
+
37
+ self.fc1 = linear_layer(
38
+ in_channels, hidden_channels, bias=bias[0], **factory_kwargs
39
+ )
40
+ self.act = act_layer()
41
+ self.drop1 = nn.Dropout(drop_probs[0])
42
+ self.norm = (
43
+ norm_layer(hidden_channels, **factory_kwargs)
44
+ if norm_layer is not None
45
+ else nn.Identity()
46
+ )
47
+ self.fc2 = linear_layer(
48
+ hidden_channels, out_features, bias=bias[1], **factory_kwargs
49
+ )
50
+ self.drop2 = nn.Dropout(drop_probs[1])
51
+
52
+ def forward(self, x):
53
+ x = self.fc1(x)
54
+ x = self.act(x)
55
+ x = self.drop1(x)
56
+ x = self.norm(x)
57
+ x = self.fc2(x)
58
+ x = self.drop2(x)
59
+ return x
60
+
61
+ def apply_(self, x, divide = 4):
62
+ x_shape = x.shape
63
+ x = x.view(-1, x.shape[-1])
64
+ chunk_size = int(x_shape[1]/divide)
65
+ x_chunks = torch.split(x, chunk_size)
66
+ for i, x_chunk in enumerate(x_chunks):
67
+ mlp_chunk = self.fc1(x_chunk)
68
+ mlp_chunk = self.act(mlp_chunk)
69
+ mlp_chunk = self.drop1(mlp_chunk)
70
+ mlp_chunk = self.norm(mlp_chunk)
71
+ mlp_chunk = self.fc2(mlp_chunk)
72
+ x_chunk[...] = self.drop2(mlp_chunk)
73
+ return x
74
+
75
+ #
76
+ class MLPEmbedder(nn.Module):
77
+ """copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
78
+ def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
79
+ factory_kwargs = {"device": device, "dtype": dtype}
80
+ super().__init__()
81
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
82
+ self.silu = nn.SiLU()
83
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ return self.out_layer(self.silu(self.in_layer(x)))
87
+
88
+
89
+ class FinalLayer(nn.Module):
90
+ """The final layer of DiT."""
91
+
92
+ def __init__(
93
+ self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
94
+ ):
95
+ factory_kwargs = {"device": device, "dtype": dtype}
96
+ super().__init__()
97
+
98
+ # Just use LayerNorm for the final layer
99
+ self.norm_final = nn.LayerNorm(
100
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
101
+ )
102
+ if isinstance(patch_size, int):
103
+ self.linear = nn.Linear(
104
+ hidden_size,
105
+ patch_size * patch_size * out_channels,
106
+ bias=True,
107
+ **factory_kwargs
108
+ )
109
+ else:
110
+ self.linear = nn.Linear(
111
+ hidden_size,
112
+ patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
113
+ bias=True,
114
+ )
115
+ nn.init.zeros_(self.linear.weight)
116
+ nn.init.zeros_(self.linear.bias)
117
+
118
+ # Here we don't distinguish between the modulate types. Just use the simple one.
119
+ self.adaLN_modulation = nn.Sequential(
120
+ act_layer(),
121
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
122
+ )
123
+ # Zero-initialize the modulation
124
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
125
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
126
+
127
+ def forward(self, x, c):
128
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
129
+ x = modulate_(self.norm_final(x), shift=shift, scale=scale)
130
+ x = self.linear(x)
131
+ return x
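Not part of the diff: a shape sketch for MLPEmbedder as it is used by vector_in in models.py, where the 768-dim pooled text embedding (text_states_dim_2) is lifted to the 3072-dim hidden size.

import torch
from hyvideo.modules.mlp_layers import MLPEmbedder

pooled_text = torch.randn(2, 768)                 # batch of 2 pooled text embeddings
vector_in = MLPEmbedder(in_dim=768, hidden_dim=3072)
vec = vector_in(pooled_text)                      # Linear -> SiLU -> Linear
assert vec.shape == (2, 3072)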
hyvideo/modules/models.py ADDED
@@ -0,0 +1,1109 @@
1
+ from typing import Any, List, Tuple, Optional, Union, Dict
2
+ from einops import rearrange
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.models import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+ from .activation_layers import get_activation_layer
12
+ from .norm_layers import get_norm_layer
13
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
14
+ from .attenion import attention, parallel_attention, get_cu_seqlens
15
+ from .posemb_layers import apply_rotary_emb
16
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
17
+ from .modulate_layers import ModulateDiT, modulate, modulate_ , apply_gate, apply_gate_and_accumulate_
18
+ from .token_refiner import SingleTokenRefiner
19
+ import numpy as np
20
+ from mmgp import offload
21
+ from wan.modules.attention import pay_attention
22
+ from .audio_adapters import AudioProjNet2, PerceiverAttentionCA
23
+
24
+ def get_linear_split_map():
25
+ hidden_size = 3072
26
+ split_linear_modules_map = {
27
+ "img_attn_qkv" : {"mapped_modules" : ["img_attn_q", "img_attn_k", "img_attn_v"] , "split_sizes": [hidden_size, hidden_size, hidden_size]},
28
+ "linear1" : {"mapped_modules" : ["linear1_attn_q", "linear1_attn_k", "linear1_attn_v", "linear1_mlp"] , "split_sizes": [hidden_size, hidden_size, hidden_size, 7*hidden_size- 3*hidden_size]}
29
+ }
30
+ return split_linear_modules_map
31
+ try:
32
+ from xformers.ops.fmha.attn_bias import BlockDiagonalPaddedKeysMask
33
+ except ImportError:
34
+ BlockDiagonalPaddedKeysMask = None
35
+
36
+
37
+ class MMDoubleStreamBlock(nn.Module):
38
+ """
39
+ A multimodal DiT block with separate modulation for
40
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
41
+ (Flux.1): https://github.com/black-forest-labs/flux
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ hidden_size: int,
47
+ heads_num: int,
48
+ mlp_width_ratio: float,
49
+ mlp_act_type: str = "gelu_tanh",
50
+ qk_norm: bool = True,
51
+ qk_norm_type: str = "rms",
52
+ qkv_bias: bool = False,
53
+ dtype: Optional[torch.dtype] = None,
54
+ device: Optional[torch.device] = None,
55
+ attention_mode: str = "sdpa",
56
+ ):
57
+ factory_kwargs = {"device": device, "dtype": dtype}
58
+ super().__init__()
59
+
60
+ self.attention_mode = attention_mode
61
+ self.deterministic = False
62
+ self.heads_num = heads_num
63
+ head_dim = hidden_size // heads_num
64
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
65
+
66
+ self.img_mod = ModulateDiT(
67
+ hidden_size,
68
+ factor=6,
69
+ act_layer=get_activation_layer("silu"),
70
+ **factory_kwargs,
71
+ )
72
+ self.img_norm1 = nn.LayerNorm(
73
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
74
+ )
75
+
76
+ self.img_attn_qkv = nn.Linear(
77
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
78
+ )
79
+ qk_norm_layer = get_norm_layer(qk_norm_type)
80
+ self.img_attn_q_norm = (
81
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
82
+ if qk_norm
83
+ else nn.Identity()
84
+ )
85
+ self.img_attn_k_norm = (
86
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
87
+ if qk_norm
88
+ else nn.Identity()
89
+ )
90
+ self.img_attn_proj = nn.Linear(
91
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
92
+ )
93
+
94
+ self.img_norm2 = nn.LayerNorm(
95
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
96
+ )
97
+ self.img_mlp = MLP(
98
+ hidden_size,
99
+ mlp_hidden_dim,
100
+ act_layer=get_activation_layer(mlp_act_type),
101
+ bias=True,
102
+ **factory_kwargs,
103
+ )
104
+
105
+ self.txt_mod = ModulateDiT(
106
+ hidden_size,
107
+ factor=6,
108
+ act_layer=get_activation_layer("silu"),
109
+ **factory_kwargs,
110
+ )
111
+ self.txt_norm1 = nn.LayerNorm(
112
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
113
+ )
114
+
115
+ self.txt_attn_qkv = nn.Linear(
116
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
117
+ )
118
+ self.txt_attn_q_norm = (
119
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
120
+ if qk_norm
121
+ else nn.Identity()
122
+ )
123
+ self.txt_attn_k_norm = (
124
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
125
+ if qk_norm
126
+ else nn.Identity()
127
+ )
128
+ self.txt_attn_proj = nn.Linear(
129
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
130
+ )
131
+
132
+ self.txt_norm2 = nn.LayerNorm(
133
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
134
+ )
135
+ self.txt_mlp = MLP(
136
+ hidden_size,
137
+ mlp_hidden_dim,
138
+ act_layer=get_activation_layer(mlp_act_type),
139
+ bias=True,
140
+ **factory_kwargs,
141
+ )
142
+ self.hybrid_seq_parallel_attn = None
143
+
144
+ def enable_deterministic(self):
145
+ self.deterministic = True
146
+
147
+ def disable_deterministic(self):
148
+ self.deterministic = False
149
+
150
+ def forward(
151
+ self,
152
+ img: torch.Tensor,
153
+ txt: torch.Tensor,
154
+ vec: torch.Tensor,
155
+ attn_mask = None,
156
+ seqlens_q: Optional[torch.Tensor] = None,
157
+ seqlens_kv: Optional[torch.Tensor] = None,
158
+ freqs_cis: tuple = None,
159
+ condition_type: str = None,
160
+ token_replace_vec: torch.Tensor = None,
161
+ frist_frame_token_num: int = None,
162
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
163
+
164
+ if condition_type == "token_replace":
165
+ img_mod1, token_replace_img_mod1 = self.img_mod(vec, condition_type=condition_type, \
166
+ token_replace_vec=token_replace_vec)
167
+ (img_mod1_shift,
168
+ img_mod1_scale,
169
+ img_mod1_gate,
170
+ img_mod2_shift,
171
+ img_mod2_scale,
172
+ img_mod2_gate) = img_mod1.chunk(6, dim=-1)
173
+ (tr_img_mod1_shift,
174
+ tr_img_mod1_scale,
175
+ tr_img_mod1_gate,
176
+ tr_img_mod2_shift,
177
+ tr_img_mod2_scale,
178
+ tr_img_mod2_gate) = token_replace_img_mod1.chunk(6, dim=-1)
179
+ else:
180
+ (
181
+ img_mod1_shift,
182
+ img_mod1_scale,
183
+ img_mod1_gate,
184
+ img_mod2_shift,
185
+ img_mod2_scale,
186
+ img_mod2_gate,
187
+ ) = self.img_mod(vec).chunk(6, dim=-1)
188
+ (
189
+ txt_mod1_shift,
190
+ txt_mod1_scale,
191
+ txt_mod1_gate,
192
+ txt_mod2_shift,
193
+ txt_mod2_scale,
194
+ txt_mod2_gate,
195
+ ) = self.txt_mod(vec).chunk(6, dim=-1)
196
+
197
+ ##### Enjoy these spaghetti VRAM optimizations done by DeepBeepMeep!
198
+ # I am sure you are a nice person and as you copy this code, you will give me officially proper credits:
199
+ # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
200
+
201
+ # Prepare image for attention.
202
+ img_modulated = self.img_norm1(img)
203
+ img_modulated = img_modulated.to(torch.bfloat16)
204
+
205
+ if condition_type == "token_replace":
206
+ modulate_(img_modulated[:, :frist_frame_token_num], shift=tr_img_mod1_shift, scale=tr_img_mod1_scale)
207
+ modulate_(img_modulated[:, frist_frame_token_num:], shift=img_mod1_shift, scale=img_mod1_scale)
208
+ else:
209
+ modulate_( img_modulated, shift=img_mod1_shift, scale=img_mod1_scale )
210
+
211
+ shape = (*img_modulated.shape[:2], self.heads_num, int(img_modulated.shape[-1] / self.heads_num) )
212
+ img_q = self.img_attn_q(img_modulated).view(*shape)
213
+ img_k = self.img_attn_k(img_modulated).view(*shape)
214
+ img_v = self.img_attn_v(img_modulated).view(*shape)
215
+ del img_modulated
216
+
217
+ # Apply QK-Norm if needed
218
+ self.img_attn_q_norm.apply_(img_q).to(img_v)
219
+ img_q_len = img_q.shape[1]
220
+ self.img_attn_k_norm.apply_(img_k).to(img_v)
221
+ img_kv_len= img_k.shape[1]
222
+ batch_size = img_k.shape[0]
223
+ # Apply RoPE if needed.
224
+ qklist = [img_q, img_k]
225
+ del img_q, img_k
226
+ img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
227
+ # Prepare txt for attention.
228
+ txt_modulated = self.txt_norm1(txt)
229
+ modulate_(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale )
230
+
231
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
232
+ del txt_modulated
233
+ txt_q, txt_k, txt_v = rearrange(
234
+ txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
235
+ )
236
+ del txt_qkv
237
+ # Apply QK-Norm if needed.
238
+ self.txt_attn_q_norm.apply_(txt_q).to(txt_v)
239
+ self.txt_attn_k_norm.apply_(txt_k).to(txt_v)
240
+
241
+ # Run actual attention.
242
+ q = torch.cat((img_q, txt_q), dim=1)
243
+ del img_q, txt_q
244
+ k = torch.cat((img_k, txt_k), dim=1)
245
+ del img_k, txt_k
246
+ v = torch.cat((img_v, txt_v), dim=1)
247
+ del img_v, txt_v
248
+
249
+ # attention computation start
250
+ qkv_list = [q,k,v]
251
+ del q, k, v
252
+
253
+ attn = pay_attention(
254
+ qkv_list,
255
+ attention_mask=attn_mask,
256
+ q_lens=seqlens_q,
257
+ k_lens=seqlens_kv,
258
+ )
259
+ b, s, a, d = attn.shape
260
+ attn = attn.reshape(b, s, -1)
261
+ del qkv_list
262
+
263
+ # attention computation end
264
+
265
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
266
+ del attn
267
+ # Calculate the img blocks.
268
+
269
+ if condition_type == "token_replace":
270
+ img_attn = self.img_attn_proj(img_attn)
271
+ apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_attn[:, :frist_frame_token_num], gate=tr_img_mod1_gate)
272
+ apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_attn[:, frist_frame_token_num:], gate=img_mod1_gate)
273
+ del img_attn
274
+ img_modulated = self.img_norm2(img)
275
+ img_modulated = img_modulated.to(torch.bfloat16)
276
+ modulate_( img_modulated[:, :frist_frame_token_num], shift=tr_img_mod2_shift, scale=tr_img_mod2_scale)
277
+ modulate_( img_modulated[:, frist_frame_token_num:], shift=img_mod2_shift, scale=img_mod2_scale)
278
+ self.img_mlp.apply_(img_modulated)
279
+ apply_gate_and_accumulate_(img[:, :frist_frame_token_num], img_modulated[:, :frist_frame_token_num], gate=tr_img_mod2_gate)
280
+ apply_gate_and_accumulate_(img[:, frist_frame_token_num:], img_modulated[:, frist_frame_token_num:], gate=img_mod2_gate)
281
+ del img_modulated
282
+ else:
283
+ img_attn = self.img_attn_proj(img_attn)
284
+ apply_gate_and_accumulate_(img, img_attn, gate=img_mod1_gate)
285
+ del img_attn
286
+ img_modulated = self.img_norm2(img)
287
+ img_modulated = img_modulated.to(torch.bfloat16)
288
+ modulate_( img_modulated , shift=img_mod2_shift, scale=img_mod2_scale)
289
+ self.img_mlp.apply_(img_modulated)
290
+ apply_gate_and_accumulate_(img, img_modulated, gate=img_mod2_gate)
291
+ del img_modulated
292
+
293
+ # Calculate the txt blocks.
294
+ txt_attn = self.txt_attn_proj(txt_attn)
295
+ apply_gate_and_accumulate_(txt, txt_attn, gate=txt_mod1_gate)
296
+ del txt_attn
297
+ txt_modulated = self.txt_norm2(txt)
298
+ txt_modulated = txt_modulated.to(torch.bfloat16)
299
+ modulate_(txt_modulated, shift=txt_mod2_shift, scale=txt_mod2_scale)
300
+ txt_mlp = self.txt_mlp(txt_modulated)
301
+ del txt_modulated
302
+ apply_gate_and_accumulate_(txt, txt_mlp, gate=txt_mod2_gate)
303
+ return img, txt
304
+
305
+
306
+ class MMSingleStreamBlock(nn.Module):
307
+ """
308
+ A DiT block with parallel linear layers as described in
309
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
310
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
311
+ (Flux.1): https://github.com/black-forest-labs/flux
312
+ """
313
+
314
+ def __init__(
315
+ self,
316
+ hidden_size: int,
317
+ heads_num: int,
318
+ mlp_width_ratio: float = 4.0,
319
+ mlp_act_type: str = "gelu_tanh",
320
+ qk_norm: bool = True,
321
+ qk_norm_type: str = "rms",
322
+ qk_scale: float = None,
323
+ dtype: Optional[torch.dtype] = None,
324
+ device: Optional[torch.device] = None,
325
+ attention_mode: str = "sdpa",
326
+ ):
327
+ factory_kwargs = {"device": device, "dtype": dtype}
328
+ super().__init__()
329
+ self.attention_mode = attention_mode
330
+ self.deterministic = False
331
+ self.hidden_size = hidden_size
332
+ self.heads_num = heads_num
333
+ head_dim = hidden_size // heads_num
334
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
335
+ self.mlp_hidden_dim = mlp_hidden_dim
336
+ self.scale = qk_scale or head_dim ** -0.5
337
+
338
+ # qkv and mlp_in
339
+ self.linear1 = nn.Linear(
340
+ hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
341
+ )
342
+ # proj and mlp_out
343
+ self.linear2 = nn.Linear(
344
+ hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
345
+ )
346
+
347
+ qk_norm_layer = get_norm_layer(qk_norm_type)
348
+ self.q_norm = (
349
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
350
+ if qk_norm
351
+ else nn.Identity()
352
+ )
353
+ self.k_norm = (
354
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
355
+ if qk_norm
356
+ else nn.Identity()
357
+ )
358
+
359
+ self.pre_norm = nn.LayerNorm(
360
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
361
+ )
362
+
363
+ self.mlp_act = get_activation_layer(mlp_act_type)()
364
+ self.modulation = ModulateDiT(
365
+ hidden_size,
366
+ factor=3,
367
+ act_layer=get_activation_layer("silu"),
368
+ **factory_kwargs,
369
+ )
370
+ self.hybrid_seq_parallel_attn = None
371
+
372
+ def enable_deterministic(self):
373
+ self.deterministic = True
374
+
375
+ def disable_deterministic(self):
376
+ self.deterministic = False
377
+
378
+ def forward(
379
+ self,
380
+ # x: torch.Tensor,
381
+ img: torch.Tensor,
382
+ txt: torch.Tensor,
383
+ vec: torch.Tensor,
384
+ txt_len: int,
385
+ attn_mask= None,
386
+ seqlens_q: Optional[torch.Tensor] = None,
387
+ seqlens_kv: Optional[torch.Tensor] = None,
388
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
389
+ condition_type: str = None,
390
+ token_replace_vec: torch.Tensor = None,
391
+ frist_frame_token_num: int = None,
392
+ ) -> torch.Tensor:
393
+
394
+ ##### More spaghetti VRAM optimizations done by DeepBeepMeep!
395
+ # I am sure you are a nice person and as you copy this code, you will give me proper credits:
396
+ # Please link to https://github.com/deepbeepmeep/HunyuanVideoGP and @deepbeepmeep on twitter
397
+
398
+ if condition_type == "token_replace":
399
+ mod, tr_mod = self.modulation(vec,
400
+ condition_type=condition_type,
401
+ token_replace_vec=token_replace_vec)
402
+ (mod_shift,
403
+ mod_scale,
404
+ mod_gate) = mod.chunk(3, dim=-1)
405
+ (tr_mod_shift,
406
+ tr_mod_scale,
407
+ tr_mod_gate) = tr_mod.chunk(3, dim=-1)
408
+ else:
409
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
410
+
411
+ img_mod = self.pre_norm(img)
412
+ img_mod = img_mod.to(torch.bfloat16)
413
+ if condition_type == "token_replace":
414
+ modulate_(img_mod[:, :frist_frame_token_num], shift=tr_mod_shift, scale=tr_mod_scale)
415
+ modulate_(img_mod[:, frist_frame_token_num:], shift=mod_shift, scale=mod_scale)
416
+ else:
417
+ modulate_(img_mod, shift=mod_shift, scale=mod_scale)
418
+ txt_mod = self.pre_norm(txt)
419
+ txt_mod = txt_mod.to(torch.bfloat16)
420
+ modulate_(txt_mod, shift=mod_shift, scale=mod_scale)
421
+
422
+ shape = (*img_mod.shape[:2], self.heads_num, int(img_mod.shape[-1] / self.heads_num) )
423
+ img_q = self.linear1_attn_q(img_mod).view(*shape)
424
+ img_k = self.linear1_attn_k(img_mod).view(*shape)
425
+ img_v = self.linear1_attn_v(img_mod).view(*shape)
426
+
427
+ shape = (*txt_mod.shape[:2], self.heads_num, int(txt_mod.shape[-1] / self.heads_num) )
428
+ txt_q = self.linear1_attn_q(txt_mod).view(*shape)
429
+ txt_k = self.linear1_attn_k(txt_mod).view(*shape)
430
+ txt_v = self.linear1_attn_v(txt_mod).view(*shape)
431
+
432
+ batch_size = img_mod.shape[0]
433
+
434
+ # Apply QK-Norm if needed.
435
+ # q = self.q_norm(q).to(v)
436
+ self.q_norm.apply_(img_q)
437
+ self.k_norm.apply_(img_k)
438
+ self.q_norm.apply_(txt_q)
439
+ self.k_norm.apply_(txt_k)
440
+
441
+ qklist = [img_q, img_k]
442
+ del img_q, img_k
443
+ img_q, img_k = apply_rotary_emb(qklist, freqs_cis, head_first=False)
444
+ img_q_len=img_q.shape[1]
445
+ q = torch.cat((img_q, txt_q), dim=1)
446
+ del img_q, txt_q
447
+ k = torch.cat((img_k, txt_k), dim=1)
448
+ img_kv_len=img_k.shape[1]
449
+ del img_k, txt_k
450
+
451
+ v = torch.cat((img_v, txt_v), dim=1)
452
+ del img_v, txt_v
453
+
454
+ # attention computation start
455
+ qkv_list = [q,k,v]
456
+ del q, k, v
457
+ attn = pay_attention(
458
+ qkv_list,
459
+ attention_mask=attn_mask,
460
+ q_lens = seqlens_q,
461
+ k_lens = seqlens_kv,
462
+ )
463
+ b, s, a, d = attn.shape
464
+ attn = attn.reshape(b, s, -1)
465
+ del qkv_list
466
+ # attention computation end
467
+
468
+ x_mod = torch.cat((img_mod, txt_mod), 1)
469
+ del img_mod, txt_mod
470
+ x_mod_shape = x_mod.shape
471
+ x_mod = x_mod.view(-1, x_mod.shape[-1])
472
+ chunk_size = int(x_mod_shape[1]/6)
473
+ x_chunks = torch.split(x_mod, chunk_size)
474
+ attn = attn.view(-1, attn.shape[-1])
475
+ attn_chunks =torch.split(attn, chunk_size)
476
+ for x_chunk, attn_chunk in zip(x_chunks, attn_chunks):
477
+ mlp_chunk = self.linear1_mlp(x_chunk)
478
+ mlp_chunk = self.mlp_act(mlp_chunk)
479
+ attn_mlp_chunk = torch.cat((attn_chunk, mlp_chunk), -1)
480
+ del attn_chunk, mlp_chunk
481
+ x_chunk[...] = self.linear2(attn_mlp_chunk)
482
+ del attn_mlp_chunk
483
+ x_mod = x_mod.view(x_mod_shape)
484
+
485
+ if condition_type == "token_replace":
486
+ apply_gate_and_accumulate_(img[:, :frist_frame_token_num, :], x_mod[:, :frist_frame_token_num, :], gate=tr_mod_gate)
487
+ apply_gate_and_accumulate_(img[:, frist_frame_token_num:, :], x_mod[:, frist_frame_token_num:-txt_len, :], gate=mod_gate)
488
+ else:
489
+ apply_gate_and_accumulate_(img, x_mod[:, :-txt_len, :], gate=mod_gate)
490
+
491
+ apply_gate_and_accumulate_(txt, x_mod[:, -txt_len:, :], gate=mod_gate)
492
+
493
+ return img, txt
494
+
495
+ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
496
+ def preprocess_loras(self, model_filename, sd):
497
+ if not "i2v" in model_filename:
498
+ return sd
499
+ new_sd = {}
500
+ for k,v in sd.items():
501
+ repl_list = ["double_blocks", "single_blocks", "final_layer", "img_mlp", "img_attn_qkv", "img_attn_proj","img_mod", "txt_mlp", "txt_attn_qkv","txt_attn_proj", "txt_mod", "linear1",
502
+ "linear2", "modulation", "mlp_fc1"]
503
+ src_list = [k +"_" for k in repl_list] + ["_" + k for k in repl_list]
504
+ tgt_list = [k +"." for k in repl_list] + ["." + k for k in repl_list]
505
+ if k.startswith("Hunyuan_video_I2V_lora_"):
506
+ # crappy conversion script for non reversible lora naming
507
+ k = k.replace("Hunyuan_video_I2V_lora_","diffusion_model.")
508
+ k = k.replace("lora_up","lora_B")
509
+ k = k.replace("lora_down","lora_A")
510
+ if "txt_in_individual" in k:
511
+ pass
512
+ for s,t in zip(src_list, tgt_list):
513
+ k = k.replace(s,t)
514
+ if "individual_token_refiner" in k:
515
+ k = k.replace("txt_in_individual_token_refiner_blocks_", "txt_in.individual_token_refiner.blocks.")
516
+ k = k.replace("_mlp_fc", ".mlp.fc",)
517
+ k = k.replace(".mlp_fc", ".mlp.fc",)
518
+ new_sd[k] = v
519
+ return new_sd
520
+ """
521
+ HunyuanVideo Transformer backbone
522
+
523
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
524
+
525
+ Reference:
526
+ [1] Flux.1: https://github.com/black-forest-labs/flux
527
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
528
+
529
+ Parameters
530
+ ----------
531
+ args: argparse.Namespace
532
+ The arguments parsed by argparse.
533
+ patch_size: list
534
+ The size of the patch.
535
+ in_channels: int
536
+ The number of input channels.
537
+ out_channels: int
538
+ The number of output channels.
539
+ hidden_size: int
540
+ The hidden size of the transformer backbone.
541
+ heads_num: int
542
+ The number of attention heads.
543
+ mlp_width_ratio: float
544
+ The ratio of the hidden size of the MLP in the transformer block.
545
+ mlp_act_type: str
546
+ The activation function of the MLP in the transformer block.
547
+ mm_double_blocks_depth: int
548
+ The number of transformer blocks in the double blocks.
549
+ mm_single_blocks_depth: int
550
+ The number of transformer blocks in the single blocks.
551
+ rope_dim_list: list
552
+ The dimension of the rotary embedding for t, h, w.
553
+ qkv_bias: bool
554
+ Whether to use bias in the qkv linear layer.
555
+ qk_norm: bool
556
+ Whether to use qk norm.
557
+ qk_norm_type: str
558
+ The type of qk norm.
559
+ guidance_embed: bool
560
+ Whether to use guidance embedding for distillation.
561
+ text_projection: str
562
+ The type of the text projection, default is single_refiner.
563
+ use_attention_mask: bool
564
+ Whether to use attention mask for text encoder.
565
+ dtype: torch.dtype
566
+ The dtype of the model.
567
+ device: torch.device
568
+ The device of the model.
569
+ """
570
+
571
+ @register_to_config
572
+ def __init__(
573
+ self,
574
+ i2v_condition_type,
575
+ patch_size: list = [1, 2, 2],
576
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
577
+ out_channels: int = None,
578
+ hidden_size: int = 3072,
579
+ heads_num: int = 24,
580
+ mlp_width_ratio: float = 4.0,
581
+ mlp_act_type: str = "gelu_tanh",
582
+ mm_double_blocks_depth: int = 20,
583
+ mm_single_blocks_depth: int = 40,
584
+ rope_dim_list: List[int] = [16, 56, 56],
585
+ qkv_bias: bool = True,
586
+ qk_norm: bool = True,
587
+ qk_norm_type: str = "rms",
588
+ guidance_embed: bool = False, # For modulation.
589
+ text_projection: str = "single_refiner",
590
+ use_attention_mask: bool = True,
591
+ dtype: Optional[torch.dtype] = None,
592
+ device: Optional[torch.device] = None,
593
+ attention_mode: Optional[str] = "sdpa",
594
+ avatar = False,
595
+ ):
596
+ factory_kwargs = {"device": device, "dtype": dtype}
597
+ super().__init__()
598
+
599
+ # mm_double_blocks_depth , mm_single_blocks_depth = 5, 5
600
+
601
+ self.patch_size = patch_size
602
+ self.in_channels = in_channels
603
+ self.out_channels = in_channels if out_channels is None else out_channels
604
+ self.unpatchify_channels = self.out_channels
605
+ self.guidance_embed = guidance_embed
606
+ self.rope_dim_list = rope_dim_list
607
+ self.i2v_condition_type = i2v_condition_type
608
+ self.attention_mode = attention_mode
609
+
610
+ # Text projection. Default to linear projection.
611
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
612
+ self.use_attention_mask = use_attention_mask
613
+ self.text_projection = text_projection
614
+
615
+ self.text_states_dim = 4096
616
+ self.text_states_dim_2 = 768
617
+
618
+ if hidden_size % heads_num != 0:
619
+ raise ValueError(
620
+ f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
621
+ )
622
+ pe_dim = hidden_size // heads_num
623
+ if sum(rope_dim_list) != pe_dim:
624
+ raise ValueError(
625
+ f"Got {rope_dim_list} but expected positional dim {pe_dim}"
626
+ )
627
+ self.hidden_size = hidden_size
628
+ self.heads_num = heads_num
629
+
630
+ # image projection
631
+ self.img_in = PatchEmbed(
632
+ self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
633
+ )
634
+
635
+ # text projection
636
+ if self.text_projection == "linear":
637
+ self.txt_in = TextProjection(
638
+ self.text_states_dim,
639
+ self.hidden_size,
640
+ get_activation_layer("silu"),
641
+ **factory_kwargs,
642
+ )
643
+ elif self.text_projection == "single_refiner":
644
+ self.txt_in = SingleTokenRefiner(
645
+ self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
646
+ )
647
+ else:
648
+ raise NotImplementedError(
649
+ f"Unsupported text_projection: {self.text_projection}"
650
+ )
651
+
652
+ # time modulation
653
+ self.time_in = TimestepEmbedder(
654
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
655
+ )
656
+
657
+ # text modulation
658
+ self.vector_in = MLPEmbedder(
659
+ self.text_states_dim_2, self.hidden_size, **factory_kwargs
660
+ )
661
+
662
+ # guidance modulation
663
+ self.guidance_in = (
664
+ TimestepEmbedder(
665
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
666
+ )
667
+ if guidance_embed
668
+ else None
669
+ )
670
+
671
+ # double blocks
672
+ self.double_blocks = nn.ModuleList(
673
+ [
674
+ MMDoubleStreamBlock(
675
+ self.hidden_size,
676
+ self.heads_num,
677
+ mlp_width_ratio=mlp_width_ratio,
678
+ mlp_act_type=mlp_act_type,
679
+ qk_norm=qk_norm,
680
+ qk_norm_type=qk_norm_type,
681
+ qkv_bias=qkv_bias,
682
+ attention_mode = attention_mode,
683
+ **factory_kwargs,
684
+ )
685
+ for _ in range(mm_double_blocks_depth)
686
+ ]
687
+ )
688
+
689
+ # single blocks
690
+ self.single_blocks = nn.ModuleList(
691
+ [
692
+ MMSingleStreamBlock(
693
+ self.hidden_size,
694
+ self.heads_num,
695
+ mlp_width_ratio=mlp_width_ratio,
696
+ mlp_act_type=mlp_act_type,
697
+ qk_norm=qk_norm,
698
+ qk_norm_type=qk_norm_type,
699
+ attention_mode = attention_mode,
700
+ **factory_kwargs,
701
+ )
702
+ for _ in range(mm_single_blocks_depth)
703
+ ]
704
+ )
705
+
706
+ self.final_layer = FinalLayer(
707
+ self.hidden_size,
708
+ self.patch_size,
709
+ self.out_channels,
710
+ get_activation_layer("silu"),
711
+ **factory_kwargs,
712
+ )
713
+ avatar_audio = avatar
714
+ if avatar_audio:
715
+ self.ref_in = PatchEmbed(
716
+ self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
717
+ )
718
+
719
+ # -------------------- audio_proj_model --------------------
720
+ self.audio_proj = AudioProjNet2(seq_len=10, blocks=5, channels=384, intermediate_dim=1024, output_dim=3072, context_tokens=4)
721
+
722
+ # -------------------- motion-embeder --------------------
723
+ self.motion_exp = TimestepEmbedder(
724
+ self.hidden_size // 4,
725
+ get_activation_layer("silu"),
726
+ **factory_kwargs
727
+ )
728
+ self.motion_pose = TimestepEmbedder(
729
+ self.hidden_size // 4,
730
+ get_activation_layer("silu"),
731
+ **factory_kwargs
732
+ )
733
+
734
+ self.fps_proj = TimestepEmbedder(
735
+ self.hidden_size,
736
+ get_activation_layer("silu"),
737
+ **factory_kwargs
738
+ )
739
+
740
+ self.before_proj = nn.Linear(self.hidden_size, self.hidden_size)
741
+
742
+ # -------------------- audio_insert_model --------------------
743
+ self.double_stream_list = [1, 3, 5, 7, 9, 11, 13, 15, 17, 19]
744
+ self.single_stream_list = []
745
+ self.double_stream_map = {str(i): j for j, i in enumerate(self.double_stream_list)}
746
+ self.single_stream_map = {str(i): j+len(self.double_stream_list) for j, i in enumerate(self.single_stream_list)}
747
+
748
+ self.audio_adapter_blocks = nn.ModuleList([
749
+ PerceiverAttentionCA(dim=3072, dim_head=1024, heads=33) for _ in range(len(self.double_stream_list) + len(self.single_stream_list))
750
+ ])
751
+
752
+
753
+
754
+ def lock_layers_dtypes(self, dtype = torch.float32):
755
+ layer_list = [self.final_layer, self.final_layer.linear, self.final_layer.adaLN_modulation[1]]
756
+ target_dtype = dtype
757
+
758
+ for current_layer_list, current_dtype in zip([layer_list], [target_dtype]):
759
+ for layer in current_layer_list:
760
+ layer._lock_dtype = dtype
761
+
762
+ if hasattr(layer, "weight") and layer.weight.dtype != current_dtype :
763
+ layer.weight.data = layer.weight.data.to(current_dtype)
764
+ if hasattr(layer, "bias"):
765
+ layer.bias.data = layer.bias.data.to(current_dtype)
766
+
767
+ self._lock_dtype = dtype
768
+
769
+ def enable_deterministic(self):
770
+ for block in self.double_blocks:
771
+ block.enable_deterministic()
772
+ for block in self.single_blocks:
773
+ block.enable_deterministic()
774
+
775
+ def disable_deterministic(self):
776
+ for block in self.double_blocks:
777
+ block.disable_deterministic()
778
+ for block in self.single_blocks:
779
+ block.disable_deterministic()
780
+
781
+ def forward(
782
+ self,
783
+ x: torch.Tensor,
784
+ t: torch.Tensor, # Should be in range(0, 1000).
785
+ ref_latents: torch.Tensor=None,
786
+ text_states: torch.Tensor = None,
787
+ text_mask: torch.Tensor = None, # Now we don't use it.
788
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
789
+ freqs_cos: Optional[torch.Tensor] = None,
790
+ freqs_sin: Optional[torch.Tensor] = None,
791
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
792
+ pipeline=None,
793
+ x_id = 0,
794
+ step_no = 0,
795
+ callback = None,
796
+ audio_prompts = None,
797
+ motion_exp = None,
798
+ motion_pose = None,
799
+ fps = None,
800
+ face_mask = None,
801
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
802
+
803
+ img = x
804
+ bsz, _, ot, oh, ow = x.shape
805
+ del x
806
+ txt = text_states
807
+ tt, th, tw = (
808
+ ot // self.patch_size[0],
809
+ oh // self.patch_size[1],
810
+ ow // self.patch_size[2],
811
+ )
812
+
813
+ # Prepare modulation vectors.
814
+ vec = self.time_in(t)
815
+ if motion_exp != None:
816
+ vec += self.motion_exp(motion_exp.view(-1)).view(bsz, -1) # (b, 3072)
817
+ if motion_pose != None:
818
+ vec += self.motion_pose(motion_pose.view(-1)).view(bsz, -1) # (b, 3072)
819
+ if fps != None:
820
+ vec += self.fps_proj(fps) # (b, 3072)
821
+ if audio_prompts != None:
822
+ audio_feature_all = self.audio_proj(audio_prompts)
823
+ audio_feature_pad = audio_feature_all[:,:1].repeat(1,3,1,1)
824
+ audio_feature_all_insert = torch.cat([audio_feature_pad, audio_feature_all], dim=1).view(bsz, ot, 16, 3072)
825
+ audio_feature_all = None
826
+
827
+ if self.i2v_condition_type == "token_replace":
828
+ token_replace_t = torch.zeros_like(t)
829
+ token_replace_vec = self.time_in(token_replace_t)
830
+ frist_frame_token_num = th * tw
831
+ else:
832
+ token_replace_vec = None
833
+ frist_frame_token_num = None
834
+ # token_replace_mask_img = None
835
+ # token_replace_mask_txt = None
836
+
837
+ # text modulation
838
+ vec_2 = self.vector_in(text_states_2)
839
+ del text_states_2
840
+ vec += vec_2
841
+ if self.i2v_condition_type == "token_replace":
842
+ token_replace_vec += vec_2
843
+ del vec_2
844
+
845
+ # guidance modulation
846
+ if self.guidance_embed:
847
+ if guidance is None:
848
+ raise ValueError(
849
+ "Didn't get guidance strength for guidance distilled model."
850
+ )
851
+
852
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
853
+ vec += self.guidance_in(guidance)
854
+
855
+ # Embed image and text.
856
+ img, shape_mask = self.img_in(img)
857
+ if audio_prompts != None:
858
+ ref_latents_first = ref_latents[:, :, :1].clone()
859
+ ref_latents,_ = self.ref_in(ref_latents)
860
+ ref_latents_first,_ = self.img_in(ref_latents_first)
861
+ elif ref_latents != None:
862
+ ref_latents, _ = self.img_in(ref_latents)
863
+
864
+ if self.text_projection == "linear":
865
+ txt = self.txt_in(txt)
866
+ elif self.text_projection == "single_refiner":
867
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
868
+ else:
869
+ raise NotImplementedError(
870
+ f"Unsupported text_projection: {self.text_projection}"
871
+ )
872
+
873
+ if audio_prompts != None:
874
+ img += self.before_proj(ref_latents)
875
+ ref_length = ref_latents_first.shape[-2] # [b s c]
876
+ img = torch.cat([ref_latents_first, img], dim=-2) # t c
877
+ img_len = img.shape[1]
878
+ mask_len = img_len - ref_length
879
+ if face_mask.shape[2] == 1:
880
+ face_mask = face_mask.repeat(1,1,ot,1,1) # repeat if number of mask frame is 1
881
+ face_mask = torch.nn.functional.interpolate(face_mask, size=[ot, shape_mask[-2], shape_mask[-1]], mode="nearest")
882
+ # face_mask = face_mask.view(-1,mask_len,1).repeat(1,1,img.shape[-1]).type_as(img)
883
+ face_mask = face_mask.view(-1,mask_len,1).type_as(img)
884
+ elif ref_latents == None:
885
+ ref_length = None
886
+ else:
887
+ ref_length = ref_latents.shape[-2]
888
+ img = torch.cat([ref_latents, img], dim=-2) # t c
889
+ txt_seq_len = txt.shape[1]
890
+ img_seq_len = img.shape[1]
891
+
892
+ text_len = text_mask.sum(1)
893
+ total_len = text_len + img_seq_len
894
+ seqlens_q = seqlens_kv = total_len
895
+ attn_mask = None
896
+
897
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
898
+
899
+
900
+ if self.enable_teacache:
901
+ if x_id == 0:
902
+ self.should_calc = True
903
+ inp = img[0:1]
904
+ vec_ = vec[0:1]
905
+ ( img_mod1_shift, img_mod1_scale, _ , _ , _ , _ , ) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1)
906
+ normed_inp = self.double_blocks[0].img_norm1(inp)
907
+ normed_inp = normed_inp.to(torch.bfloat16)
908
+ modulated_inp = modulate( normed_inp, shift=img_mod1_shift, scale=img_mod1_scale )
909
+ del normed_inp, img_mod1_shift, img_mod1_scale
910
+ if step_no <= self.teacache_start_step or step_no == self.num_steps-1:
911
+ self.accumulated_rel_l1_distance = 0
912
+ else:
913
+ coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02]
914
+ rescale_func = np.poly1d(coefficients)
915
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
916
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
917
+ self.should_calc = False
918
+ self.teacache_skipped_steps += 1
919
+ else:
920
+ self.accumulated_rel_l1_distance = 0
921
+ self.previous_modulated_input = modulated_inp
922
+ else:
923
+ self.should_calc = True
924
+
925
+ if not self.should_calc:
926
+ img += self.previous_residual[x_id]
927
+ else:
928
+ if self.enable_teacache:
929
+ self.previous_residual[x_id] = None
930
+ ori_img = img[0:1].clone()
931
+ # --------------------- Pass through DiT blocks ------------------------
932
+ for layer_num, block in enumerate(self.double_blocks):
933
+ for i in range(len(img)):
934
+ if callback != None:
935
+ callback(-1, None, False, True)
936
+ if pipeline._interrupt:
937
+ return None
938
+ double_block_args = [
939
+ img[i:i+1],
940
+ txt[i:i+1],
941
+ vec[i:i+1],
942
+ attn_mask,
943
+ seqlens_q[i:i+1],
944
+ seqlens_kv[i:i+1],
945
+ freqs_cis,
946
+ self.i2v_condition_type,
947
+ token_replace_vec,
948
+ frist_frame_token_num,
949
+ ]
950
+
951
+ img[i], txt[i] = block(*double_block_args)
952
+ double_block_args = None
953
+ # insert audio feature to img
954
+ if audio_prompts != None:
955
+ audio_adapter = getattr(self.double_blocks[layer_num], "audio_adapter", None)
956
+ if audio_adapter != None:
957
+ real_img = img[i:i+1,ref_length:].view(1, ot, -1, 3072)
958
+ real_img = audio_adapter(audio_feature_all_insert[i:i+1], real_img).view(1, -1, 3072)
959
+ real_img *= face_mask[i:i+1]
960
+ img[i:i+1, ref_length:] += real_img
961
+ real_img = None
962
+
963
+
964
+ for _, block in enumerate(self.single_blocks):
965
+ for i in range(len(img)):
966
+ if callback != None:
967
+ callback(-1, None, False, True)
968
+ if pipeline._interrupt:
969
+ return None
970
+ single_block_args = [
971
+ # x,
972
+ img[i:i+1],
973
+ txt[i:i+1],
974
+ vec[i:i+1],
975
+ txt_seq_len,
976
+ attn_mask,
977
+ seqlens_q[i:i+1],
978
+ seqlens_kv[i:i+1],
979
+ (freqs_cos, freqs_sin),
980
+ self.i2v_condition_type,
981
+ token_replace_vec,
982
+ frist_frame_token_num,
983
+ ]
984
+
985
+ img[i], txt[i] = block(*single_block_args)
986
+ single_block_args = None
987
+
988
+ # img = x[:, :img_seq_len, ...]
989
+ if self.enable_teacache:
990
+ if len(img) > 1:
991
+ self.previous_residual[0] = torch.empty_like(img)
992
+ for i, (x, residual) in enumerate(zip(img, self.previous_residual[0])):
993
+ if i < len(img) - 1:
994
+ residual[...] = torch.sub(x, ori_img)
995
+ else:
996
+ residual[...] = ori_img
997
+ torch.sub(x, ori_img, out=residual)
998
+ x = None
999
+ else:
1000
+ self.previous_residual[x_id] = ori_img
1001
+ torch.sub(img, ori_img, out=self.previous_residual[x_id])
1002
+
1003
+
1004
+ if ref_length != None:
1005
+ img = img[:, ref_length:]
1006
+ # ---------------------------- Final layer ------------------------------
1007
+ out_dtype = self.final_layer.linear.weight.dtype
1008
+ vec = vec.to(out_dtype)
1009
+ img_list = []
1010
+ for img_chunk, vec_chunk in zip(img,vec):
1011
+ img_list.append( self.final_layer(img_chunk.to(out_dtype).unsqueeze(0), vec_chunk.unsqueeze(0))) # (N, T, patch_size ** 2 * out_channels)
1012
+ img = torch.cat(img_list)
1013
+ img_list = None
1014
+
1015
+ # img = self.unpatchify(img, tt, th, tw)
1016
+ img = self.unpatchify(img, tt, th, tw)
1017
+
1018
+ return img
1019
+
1020
+ def unpatchify(self, x, t, h, w):
1021
+ """
1022
+ x: (N, T, patch_size**2 * C)
1023
+ imgs: (N, H, W, C)
1024
+ """
1025
+ c = self.unpatchify_channels
1026
+ pt, ph, pw = self.patch_size
1027
+ assert t * h * w == x.shape[1]
1028
+
1029
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
1030
+ x = torch.einsum("nthwcopq->nctohpwq", x)
1031
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
1032
+
1033
+ return imgs
1034
+
1035
+ def params_count(self):
1036
+ counts = {
1037
+ "double": sum(
1038
+ [
1039
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
1040
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
1041
+ + sum(p.numel() for p in block.img_mlp.parameters())
1042
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
1043
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
1044
+ + sum(p.numel() for p in block.txt_mlp.parameters())
1045
+ for block in self.double_blocks
1046
+ ]
1047
+ ),
1048
+ "single": sum(
1049
+ [
1050
+ sum(p.numel() for p in block.linear1.parameters())
1051
+ + sum(p.numel() for p in block.linear2.parameters())
1052
+ for block in self.single_blocks
1053
+ ]
1054
+ ),
1055
+ "total": sum(p.numel() for p in self.parameters()),
1056
+ }
1057
+ counts["attn+mlp"] = counts["double"] + counts["single"]
1058
+ return counts
1059
+
1060
+
1061
+ #################################################################################
1062
+ # HunyuanVideo Configs #
1063
+ #################################################################################
1064
+
1065
+ HUNYUAN_VIDEO_CONFIG = {
1066
+ "HYVideo-T/2": {
1067
+ "mm_double_blocks_depth": 20,
1068
+ "mm_single_blocks_depth": 40,
1069
+ "rope_dim_list": [16, 56, 56],
1070
+ "hidden_size": 3072,
1071
+ "heads_num": 24,
1072
+ "mlp_width_ratio": 4,
1073
+ },
1074
+ "HYVideo-T/2-cfgdistill": {
1075
+ "mm_double_blocks_depth": 20,
1076
+ "mm_single_blocks_depth": 40,
1077
+ "rope_dim_list": [16, 56, 56],
1078
+ "hidden_size": 3072,
1079
+ "heads_num": 24,
1080
+ "mlp_width_ratio": 4,
1081
+ "guidance_embed": True,
1082
+ },
1083
+ "HYVideo-S/2": {
1084
+ "mm_double_blocks_depth": 6,
1085
+ "mm_single_blocks_depth": 12,
1086
+ "rope_dim_list": [12, 42, 42],
1087
+ "hidden_size": 480,
1088
+ "heads_num": 5,
1089
+ "mlp_width_ratio": 4,
1090
+ },
1091
+ 'HYVideo-T/2-custom': { # 9.0B / 12.5B
1092
+ "mm_double_blocks_depth": 20,
1093
+ "mm_single_blocks_depth": 40,
1094
+ "rope_dim_list": [16, 56, 56],
1095
+ "hidden_size": 3072,
1096
+ "heads_num": 24,
1097
+ "mlp_width_ratio": 4,
1098
+ },
1099
+ 'HYVideo-T/2-avatar': { # 9.0B / 12.5B
1100
+ 'mm_double_blocks_depth': 20,
1101
+ 'mm_single_blocks_depth': 40,
1102
+ 'rope_dim_list': [16, 56, 56],
1103
+ 'hidden_size': 3072,
1104
+ 'heads_num': 24,
1105
+ 'mlp_width_ratio': 4,
1106
+ 'avatar': True,
1107
+ },
1108
+
1109
+ }
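Not part of the diff: a small consistency check of the configs above against the constraint enforced in HYVideoDiffusionTransformer.__init__, namely that sum(rope_dim_list) must equal the per-head dimension hidden_size // heads_num.

from hyvideo.modules.models import HUNYUAN_VIDEO_CONFIG

for name, cfg in HUNYUAN_VIDEO_CONFIG.items():
    assert cfg["hidden_size"] % cfg["heads_num"] == 0, name
    pe_dim = cfg["hidden_size"] // cfg["heads_num"]
    assert sum(cfg["rope_dim_list"]) == pe_dim, name
# e.g. "HYVideo-T/2": 3072 // 24 == 128 == 16 + 56 + 56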
hyvideo/modules/modulate_layers.py ADDED
@@ -0,0 +1,136 @@
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import math
6
+
7
+ class ModulateDiT(nn.Module):
8
+ """Modulation layer for DiT."""
9
+ def __init__(
10
+ self,
11
+ hidden_size: int,
12
+ factor: int,
13
+ act_layer: Callable,
14
+ dtype=None,
15
+ device=None,
16
+ ):
17
+ factory_kwargs = {"dtype": dtype, "device": device}
18
+ super().__init__()
19
+ self.act = act_layer()
20
+ self.linear = nn.Linear(
21
+ hidden_size, factor * hidden_size, bias=True, **factory_kwargs
22
+ )
23
+ # Zero-initialize the modulation
24
+ nn.init.zeros_(self.linear.weight)
25
+ nn.init.zeros_(self.linear.bias)
26
+
27
+ def forward(self, x: torch.Tensor, condition_type=None, token_replace_vec=None) -> torch.Tensor:
28
+ x_out = self.linear(self.act(x))
29
+
30
+ if condition_type == "token_replace":
31
+ x_token_replace_out = self.linear(self.act(token_replace_vec))
32
+ return x_out, x_token_replace_out
33
+ else:
34
+ return x_out
35
+
36
+ def modulate(x, shift=None, scale=None):
37
+ """modulate by shift and scale
38
+
39
+ Args:
40
+ x (torch.Tensor): input tensor.
41
+ shift (torch.Tensor, optional): shift tensor. Defaults to None.
42
+ scale (torch.Tensor, optional): scale tensor. Defaults to None.
43
+
44
+ Returns:
45
+ torch.Tensor: the output tensor after modulate.
46
+ """
47
+ if scale is None and shift is None:
48
+ return x
49
+ elif shift is None:
50
+ return x * (1 + scale.unsqueeze(1))
51
+ elif scale is None:
52
+ return x + shift.unsqueeze(1)
53
+ else:
54
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
55
+
56
+ def modulate_(x, shift=None, scale=None):
57
+
58
+ if scale is None and shift is None:
59
+ return x
60
+ elif shift is None:
61
+ scale = scale + 1
62
+ scale = scale.unsqueeze(1)
63
+ return x.mul_(scale)
64
+ elif scale is None:
65
+ return x + shift.unsqueeze(1)
66
+ else:
67
+ scale = scale + 1
68
+ scale = scale.unsqueeze(1)
69
+ # return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
70
+ torch.addcmul(shift.unsqueeze(1), x, scale, out =x )
71
+ return x
72
+
73
+ def modulate(x, shift=None, scale=None, condition_type=None,
74
+ tr_shift=None, tr_scale=None,
75
+ frist_frame_token_num=None):
76
+ if condition_type == "token_replace":
77
+ x_zero = x[:, :frist_frame_token_num] * (1 + tr_scale.unsqueeze(1)) + tr_shift.unsqueeze(1)
78
+ x_orig = x[:, frist_frame_token_num:] * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
79
+ x = torch.concat((x_zero, x_orig), dim=1)
80
+ return x
81
+ else:
82
+ if scale is None and shift is None:
83
+ return x
84
+ elif shift is None:
85
+ return x * (1 + scale.unsqueeze(1))
86
+ elif scale is None:
87
+ return x + shift.unsqueeze(1)
88
+ else:
89
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
90
+
91
+ def apply_gate(x, gate=None, tanh=False, condition_type=None, tr_gate=None, frist_frame_token_num=None):
92
+ """AI is creating summary for apply_gate
93
+
94
+ Args:
95
+ x (torch.Tensor): input tensor.
96
+ gate (torch.Tensor, optional): gate tensor. Defaults to None.
97
+ tanh (bool, optional): whether to use tanh function. Defaults to False.
98
+
99
+ Returns:
100
+ torch.Tensor: the output tensor after apply gate.
101
+ """
102
+ if condition_type == "token_replace":
103
+ if gate is None:
104
+ return x
105
+ if tanh:
106
+ x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1).tanh()
107
+ x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1).tanh()
108
+ x = torch.concat((x_zero, x_orig), dim=1)
109
+ return x
110
+ else:
111
+ x_zero = x[:, :frist_frame_token_num] * tr_gate.unsqueeze(1)
112
+ x_orig = x[:, frist_frame_token_num:] * gate.unsqueeze(1)
113
+ x = torch.concat((x_zero, x_orig), dim=1)
114
+ return x
115
+ else:
116
+ if gate is None:
117
+ return x
118
+ if tanh:
119
+ return x * gate.unsqueeze(1).tanh()
120
+ else:
121
+ return x * gate.unsqueeze(1)
122
+
123
+ def apply_gate_and_accumulate_(accumulator, x, gate=None, tanh=False):
124
+ if gate is None:
125
+ return accumulator
126
+ if tanh:
127
+ return accumulator.addcmul_(x, gate.unsqueeze(1).tanh())
128
+ else:
129
+ return accumulator.addcmul_(x, gate.unsqueeze(1))
130
+
131
+ def ckpt_wrapper(module):
132
+ def ckpt_forward(*inputs):
133
+ outputs = module(*inputs)
134
+ return outputs
135
+
136
+ return ckpt_forward
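Not part of the diff: a quick check that the in-place modulate_ above computes the same affine modulation x * (1 + scale) + shift as the out-of-place variants, just via torch.addcmul on the input buffer.

import torch
from hyvideo.modules.modulate_layers import modulate_

x = torch.randn(2, 10, 3072)
shift = torch.randn(2, 3072)
scale = torch.randn(2, 3072)

expected = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
out = modulate_(x.clone(), shift=shift, scale=scale)     # mutates its argument in place
assert torch.allclose(out, expected, atol=1e-5)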
hyvideo/modules/norm_layers.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class RMSNorm(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ elementwise_affine=True,
10
+ eps: float = 1e-6,
11
+ device=None,
12
+ dtype=None,
13
+ ):
14
+ """
15
+ Initialize the RMSNorm normalization layer.
16
+
17
+ Args:
18
+ dim (int): The dimension of the input tensor.
19
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
20
+
21
+ Attributes:
22
+ eps (float): A small value added to the denominator for numerical stability.
23
+ weight (nn.Parameter): Learnable scaling parameter.
24
+
25
+ """
26
+ factory_kwargs = {"device": device, "dtype": dtype}
27
+ super().__init__()
28
+ self.eps = eps
29
+ if elementwise_affine:
30
+ self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
31
+
32
+ def _norm(self, x):
33
+ """
34
+ Apply the RMSNorm normalization to the input tensor.
35
+
36
+ Args:
37
+ x (torch.Tensor): The input tensor.
38
+
39
+ Returns:
40
+ torch.Tensor: The normalized tensor.
41
+
42
+ """
43
+
44
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
45
+
46
+ def forward(self, x):
47
+ """
48
+ Forward pass through the RMSNorm layer.
49
+
50
+ Args:
51
+ x (torch.Tensor): The input tensor.
52
+
53
+ Returns:
54
+ torch.Tensor: The output tensor after applying RMSNorm.
55
+
56
+ """
57
+ output = self._norm(x.float()).type_as(x)
58
+ if hasattr(self, "weight"):
59
+ output = output * self.weight
60
+ return output
61
+
62
+ def apply_(self, x):
63
+ y = x.pow(2).mean(-1, keepdim=True)
64
+ y.add_(self.eps)
65
+ y.rsqrt_()
66
+ x.mul_(y)
67
+ del y
68
+ if hasattr(self, "weight"):
69
+ x.mul_(self.weight)
70
+ return x
71
+
72
+
73
+ def get_norm_layer(norm_layer):
74
+ """
75
+ Get the normalization layer.
76
+
77
+ Args:
78
+ norm_layer (str): The type of normalization layer.
79
+
80
+ Returns:
81
+ norm_layer (nn.Module): The normalization layer.
82
+ """
83
+ if norm_layer == "layer":
84
+ return nn.LayerNorm
85
+ elif norm_layer == "rms":
86
+ return RMSNorm
87
+ else:
88
+ raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
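
A minimal usage sketch of the normalization layers just added, assuming the repository root is on the Python path:

```python
import torch
from hyvideo.modules.norm_layers import get_norm_layer

norm_cls = get_norm_layer("rms")              # -> RMSNorm
norm = norm_cls(64, elementwise_affine=True, eps=1e-6)

x = torch.randn(2, 5, 64)
y = norm(x)
# RMSNorm rescales each feature vector to (approximately) unit RMS and then
# applies the learnable per-channel weight (ones at initialization).
print(y.pow(2).mean(-1).sqrt())               # ~1.0 everywhere at init
```
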
hyvideo/modules/original models.py ADDED
@@ -0,0 +1,760 @@
1
+ from typing import Any, List, Tuple, Optional, Union, Dict
2
+ from einops import rearrange
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from diffusers.models import ModelMixin
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+
11
+ from .activation_layers import get_activation_layer
12
+ from .norm_layers import get_norm_layer
13
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
14
+ from .attenion import attention, parallel_attention, get_cu_seqlens
15
+ from .posemb_layers import apply_rotary_emb
16
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
17
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
18
+ from .token_refiner import SingleTokenRefiner
19
+
20
+
21
+ class MMDoubleStreamBlock(nn.Module):
22
+ """
23
+ A multimodal DiT block with separate modulation for
24
+ text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
25
+ (Flux.1): https://github.com/black-forest-labs/flux
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ hidden_size: int,
31
+ heads_num: int,
32
+ mlp_width_ratio: float,
33
+ mlp_act_type: str = "gelu_tanh",
34
+ qk_norm: bool = True,
35
+ qk_norm_type: str = "rms",
36
+ qkv_bias: bool = False,
37
+ dtype: Optional[torch.dtype] = None,
38
+ device: Optional[torch.device] = None,
39
+ ):
40
+ factory_kwargs = {"device": device, "dtype": dtype}
41
+ super().__init__()
42
+
43
+ self.deterministic = False
44
+ self.heads_num = heads_num
45
+ head_dim = hidden_size // heads_num
46
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
47
+
48
+ self.img_mod = ModulateDiT(
49
+ hidden_size,
50
+ factor=6,
51
+ act_layer=get_activation_layer("silu"),
52
+ **factory_kwargs,
53
+ )
54
+ self.img_norm1 = nn.LayerNorm(
55
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
56
+ )
57
+
58
+ self.img_attn_qkv = nn.Linear(
59
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
60
+ )
61
+ qk_norm_layer = get_norm_layer(qk_norm_type)
62
+ self.img_attn_q_norm = (
63
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
64
+ if qk_norm
65
+ else nn.Identity()
66
+ )
67
+ self.img_attn_k_norm = (
68
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
69
+ if qk_norm
70
+ else nn.Identity()
71
+ )
72
+ self.img_attn_proj = nn.Linear(
73
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
74
+ )
75
+
76
+ self.img_norm2 = nn.LayerNorm(
77
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
78
+ )
79
+ self.img_mlp = MLP(
80
+ hidden_size,
81
+ mlp_hidden_dim,
82
+ act_layer=get_activation_layer(mlp_act_type),
83
+ bias=True,
84
+ **factory_kwargs,
85
+ )
86
+
87
+ self.txt_mod = ModulateDiT(
88
+ hidden_size,
89
+ factor=6,
90
+ act_layer=get_activation_layer("silu"),
91
+ **factory_kwargs,
92
+ )
93
+ self.txt_norm1 = nn.LayerNorm(
94
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
95
+ )
96
+
97
+ self.txt_attn_qkv = nn.Linear(
98
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
99
+ )
100
+ self.txt_attn_q_norm = (
101
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
102
+ if qk_norm
103
+ else nn.Identity()
104
+ )
105
+ self.txt_attn_k_norm = (
106
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
107
+ if qk_norm
108
+ else nn.Identity()
109
+ )
110
+ self.txt_attn_proj = nn.Linear(
111
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
112
+ )
113
+
114
+ self.txt_norm2 = nn.LayerNorm(
115
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
116
+ )
117
+ self.txt_mlp = MLP(
118
+ hidden_size,
119
+ mlp_hidden_dim,
120
+ act_layer=get_activation_layer(mlp_act_type),
121
+ bias=True,
122
+ **factory_kwargs,
123
+ )
124
+ self.hybrid_seq_parallel_attn = None
125
+
126
+ def enable_deterministic(self):
127
+ self.deterministic = True
128
+
129
+ def disable_deterministic(self):
130
+ self.deterministic = False
131
+
132
+ def forward(
133
+ self,
134
+ img: torch.Tensor,
135
+ txt: torch.Tensor,
136
+ vec: torch.Tensor,
137
+ cu_seqlens_q: Optional[torch.Tensor] = None,
138
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
139
+ max_seqlen_q: Optional[int] = None,
140
+ max_seqlen_kv: Optional[int] = None,
141
+ freqs_cis: tuple = None,
142
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
143
+ (
144
+ img_mod1_shift,
145
+ img_mod1_scale,
146
+ img_mod1_gate,
147
+ img_mod2_shift,
148
+ img_mod2_scale,
149
+ img_mod2_gate,
150
+ ) = self.img_mod(vec).chunk(6, dim=-1)
151
+ (
152
+ txt_mod1_shift,
153
+ txt_mod1_scale,
154
+ txt_mod1_gate,
155
+ txt_mod2_shift,
156
+ txt_mod2_scale,
157
+ txt_mod2_gate,
158
+ ) = self.txt_mod(vec).chunk(6, dim=-1)
159
+
160
+ # Prepare image for attention.
161
+ img_modulated = self.img_norm1(img)
162
+ img_modulated = modulate(
163
+ img_modulated, shift=img_mod1_shift, scale=img_mod1_scale
164
+ )
165
+ img_qkv = self.img_attn_qkv(img_modulated)
166
+ img_q, img_k, img_v = rearrange(
167
+ img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
168
+ )
169
+ # Apply QK-Norm if needed
170
+ img_q = self.img_attn_q_norm(img_q).to(img_v)
171
+ img_k = self.img_attn_k_norm(img_k).to(img_v)
172
+
173
+ # Apply RoPE if needed.
174
+ if freqs_cis is not None:
175
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
176
+ assert (
177
+ img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
178
+ ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
179
+ img_q, img_k = img_qq, img_kk
180
+
181
+ # Prepare txt for attention.
182
+ txt_modulated = self.txt_norm1(txt)
183
+ txt_modulated = modulate(
184
+ txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale
185
+ )
186
+ txt_qkv = self.txt_attn_qkv(txt_modulated)
187
+ txt_q, txt_k, txt_v = rearrange(
188
+ txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
189
+ )
190
+ # Apply QK-Norm if needed.
191
+ txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
192
+ txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
193
+
194
+ # Run actual attention.
195
+ q = torch.cat((img_q, txt_q), dim=1)
196
+ k = torch.cat((img_k, txt_k), dim=1)
197
+ v = torch.cat((img_v, txt_v), dim=1)
198
+ assert (
199
+ cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
200
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"
201
+
202
+ # attention computation start
203
+ if not self.hybrid_seq_parallel_attn:
204
+ attn = attention(
205
+ q,
206
+ k,
207
+ v,
208
+ cu_seqlens_q=cu_seqlens_q,
209
+ cu_seqlens_kv=cu_seqlens_kv,
210
+ max_seqlen_q=max_seqlen_q,
211
+ max_seqlen_kv=max_seqlen_kv,
212
+ batch_size=img_k.shape[0],
213
+ )
214
+ else:
215
+ attn = parallel_attention(
216
+ self.hybrid_seq_parallel_attn,
217
+ q,
218
+ k,
219
+ v,
220
+ img_q_len=img_q.shape[1],
221
+ img_kv_len=img_k.shape[1],
222
+ cu_seqlens_q=cu_seqlens_q,
223
+ cu_seqlens_kv=cu_seqlens_kv
224
+ )
225
+
226
+ # attention computation end
227
+
228
+ img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
229
+
230
+ # Calculate the img blocks.
231
+ img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
232
+ img = img + apply_gate(
233
+ self.img_mlp(
234
+ modulate(
235
+ self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale
236
+ )
237
+ ),
238
+ gate=img_mod2_gate,
239
+ )
240
+
241
+ # Calculate the txt blocks.
242
+ txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
243
+ txt = txt + apply_gate(
244
+ self.txt_mlp(
245
+ modulate(
246
+ self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale
247
+ )
248
+ ),
249
+ gate=txt_mod2_gate,
250
+ )
251
+
252
+ return img, txt
253
+
254
+
255
+ class MMSingleStreamBlock(nn.Module):
256
+ """
257
+ A DiT block with parallel linear layers as described in
258
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
259
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
260
+ (Flux.1): https://github.com/black-forest-labs/flux
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ hidden_size: int,
266
+ heads_num: int,
267
+ mlp_width_ratio: float = 4.0,
268
+ mlp_act_type: str = "gelu_tanh",
269
+ qk_norm: bool = True,
270
+ qk_norm_type: str = "rms",
271
+ qk_scale: float = None,
272
+ dtype: Optional[torch.dtype] = None,
273
+ device: Optional[torch.device] = None,
274
+ ):
275
+ factory_kwargs = {"device": device, "dtype": dtype}
276
+ super().__init__()
277
+
278
+ self.deterministic = False
279
+ self.hidden_size = hidden_size
280
+ self.heads_num = heads_num
281
+ head_dim = hidden_size // heads_num
282
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
283
+ self.mlp_hidden_dim = mlp_hidden_dim
284
+ self.scale = qk_scale or head_dim ** -0.5
285
+
286
+ # qkv and mlp_in
287
+ self.linear1 = nn.Linear(
288
+ hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs
289
+ )
290
+ # proj and mlp_out
291
+ self.linear2 = nn.Linear(
292
+ hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs
293
+ )
294
+
295
+ qk_norm_layer = get_norm_layer(qk_norm_type)
296
+ self.q_norm = (
297
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
298
+ if qk_norm
299
+ else nn.Identity()
300
+ )
301
+ self.k_norm = (
302
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
303
+ if qk_norm
304
+ else nn.Identity()
305
+ )
306
+
307
+ self.pre_norm = nn.LayerNorm(
308
+ hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
309
+ )
310
+
311
+ self.mlp_act = get_activation_layer(mlp_act_type)()
312
+ self.modulation = ModulateDiT(
313
+ hidden_size,
314
+ factor=3,
315
+ act_layer=get_activation_layer("silu"),
316
+ **factory_kwargs,
317
+ )
318
+ self.hybrid_seq_parallel_attn = None
319
+
320
+ def enable_deterministic(self):
321
+ self.deterministic = True
322
+
323
+ def disable_deterministic(self):
324
+ self.deterministic = False
325
+
326
+ def forward(
327
+ self,
328
+ x: torch.Tensor,
329
+ vec: torch.Tensor,
330
+ txt_len: int,
331
+ cu_seqlens_q: Optional[torch.Tensor] = None,
332
+ cu_seqlens_kv: Optional[torch.Tensor] = None,
333
+ max_seqlen_q: Optional[int] = None,
334
+ max_seqlen_kv: Optional[int] = None,
335
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
336
+ ) -> torch.Tensor:
337
+ mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
338
+ x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
339
+ qkv, mlp = torch.split(
340
+ self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
341
+ )
342
+
343
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
344
+
345
+ # Apply QK-Norm if needed.
346
+ q = self.q_norm(q).to(v)
347
+ k = self.k_norm(k).to(v)
348
+
349
+ # Apply RoPE if needed.
350
+ if freqs_cis is not None:
351
+ img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
352
+ img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
353
+ img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
354
+ assert (
355
+ img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
356
+ ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
357
+ img_q, img_k = img_qq, img_kk
358
+ q = torch.cat((img_q, txt_q), dim=1)
359
+ k = torch.cat((img_k, txt_k), dim=1)
360
+
361
+ # Compute attention.
362
+ assert (
363
+ cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
364
+ ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"
365
+
366
+ # attention computation start
367
+ if not self.hybrid_seq_parallel_attn:
368
+ attn = attention(
369
+ q,
370
+ k,
371
+ v,
372
+ cu_seqlens_q=cu_seqlens_q,
373
+ cu_seqlens_kv=cu_seqlens_kv,
374
+ max_seqlen_q=max_seqlen_q,
375
+ max_seqlen_kv=max_seqlen_kv,
376
+ batch_size=x.shape[0],
377
+ )
378
+ else:
379
+ attn = parallel_attention(
380
+ self.hybrid_seq_parallel_attn,
381
+ q,
382
+ k,
383
+ v,
384
+ img_q_len=img_q.shape[1],
385
+ img_kv_len=img_k.shape[1],
386
+ cu_seqlens_q=cu_seqlens_q,
387
+ cu_seqlens_kv=cu_seqlens_kv
388
+ )
389
+ # attention computation end
390
+
391
+ # Compute activation in mlp stream, cat again and run second linear layer.
392
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
393
+ return x + apply_gate(output, gate=mod_gate)
394
+
395
+
396
+ class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
397
+ """
398
+ HunyuanVideo Transformer backbone
399
+
400
+ Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
401
+
402
+ Reference:
403
+ [1] Flux.1: https://github.com/black-forest-labs/flux
404
+ [2] MMDiT: http://arxiv.org/abs/2403.03206
405
+
406
+ Parameters
407
+ ----------
408
+ args: argparse.Namespace
409
+ The arguments parsed by argparse.
410
+ patch_size: list
411
+ The size of the patch.
412
+ in_channels: int
413
+ The number of input channels.
414
+ out_channels: int
415
+ The number of output channels.
416
+ hidden_size: int
417
+ The hidden size of the transformer backbone.
418
+ heads_num: int
419
+ The number of attention heads.
420
+ mlp_width_ratio: float
421
+ The ratio of the hidden size of the MLP in the transformer block.
422
+ mlp_act_type: str
423
+ The activation function of the MLP in the transformer block.
424
+ depth_double_blocks: int
425
+ The number of transformer blocks in the double blocks.
426
+ depth_single_blocks: int
427
+ The number of transformer blocks in the single blocks.
428
+ rope_dim_list: list
429
+ The dimension of the rotary embedding for t, h, w.
430
+ qkv_bias: bool
431
+ Whether to use bias in the qkv linear layer.
432
+ qk_norm: bool
433
+ Whether to use qk norm.
434
+ qk_norm_type: str
435
+ The type of qk norm.
436
+ guidance_embed: bool
437
+ Whether to use guidance embedding for distillation.
438
+ text_projection: str
439
+ The type of the text projection, default is single_refiner.
440
+ use_attention_mask: bool
441
+ Whether to use attention mask for text encoder.
442
+ dtype: torch.dtype
443
+ The dtype of the model.
444
+ device: torch.device
445
+ The device of the model.
446
+ """
447
+
448
+ @register_to_config
449
+ def __init__(
450
+ self,
451
+ args: Any,
452
+ patch_size: list = [1, 2, 2],
453
+ in_channels: int = 4, # Should be VAE.config.latent_channels.
454
+ out_channels: int = None,
455
+ hidden_size: int = 3072,
456
+ heads_num: int = 24,
457
+ mlp_width_ratio: float = 4.0,
458
+ mlp_act_type: str = "gelu_tanh",
459
+ mm_double_blocks_depth: int = 20,
460
+ mm_single_blocks_depth: int = 40,
461
+ rope_dim_list: List[int] = [16, 56, 56],
462
+ qkv_bias: bool = True,
463
+ qk_norm: bool = True,
464
+ qk_norm_type: str = "rms",
465
+ guidance_embed: bool = False, # For modulation.
466
+ text_projection: str = "single_refiner",
467
+ use_attention_mask: bool = True,
468
+ dtype: Optional[torch.dtype] = None,
469
+ device: Optional[torch.device] = None,
470
+ ):
471
+ factory_kwargs = {"device": device, "dtype": dtype}
472
+ super().__init__()
473
+
474
+ self.patch_size = patch_size
475
+ self.in_channels = in_channels
476
+ self.out_channels = in_channels if out_channels is None else out_channels
477
+ self.unpatchify_channels = self.out_channels
478
+ self.guidance_embed = guidance_embed
479
+ self.rope_dim_list = rope_dim_list
480
+
481
+ # Text projection. Default to linear projection.
482
+ # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
483
+ self.use_attention_mask = use_attention_mask
484
+ self.text_projection = text_projection
485
+
486
+ self.text_states_dim = args.text_states_dim
487
+ self.text_states_dim_2 = args.text_states_dim_2
488
+
489
+ if hidden_size % heads_num != 0:
490
+ raise ValueError(
491
+ f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}"
492
+ )
493
+ pe_dim = hidden_size // heads_num
494
+ if sum(rope_dim_list) != pe_dim:
495
+ raise ValueError(
496
+ f"Got {rope_dim_list} but expected positional dim {pe_dim}"
497
+ )
498
+ self.hidden_size = hidden_size
499
+ self.heads_num = heads_num
500
+
501
+ # image projection
502
+ self.img_in = PatchEmbed(
503
+ self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs
504
+ )
505
+
506
+ # text projection
507
+ if self.text_projection == "linear":
508
+ self.txt_in = TextProjection(
509
+ self.text_states_dim,
510
+ self.hidden_size,
511
+ get_activation_layer("silu"),
512
+ **factory_kwargs,
513
+ )
514
+ elif self.text_projection == "single_refiner":
515
+ self.txt_in = SingleTokenRefiner(
516
+ self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs
517
+ )
518
+ else:
519
+ raise NotImplementedError(
520
+ f"Unsupported text_projection: {self.text_projection}"
521
+ )
522
+
523
+ # time modulation
524
+ self.time_in = TimestepEmbedder(
525
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
526
+ )
527
+
528
+ # text modulation
529
+ self.vector_in = MLPEmbedder(
530
+ self.text_states_dim_2, self.hidden_size, **factory_kwargs
531
+ )
532
+
533
+ # guidance modulation
534
+ self.guidance_in = (
535
+ TimestepEmbedder(
536
+ self.hidden_size, get_activation_layer("silu"), **factory_kwargs
537
+ )
538
+ if guidance_embed
539
+ else None
540
+ )
541
+
542
+ # double blocks
543
+ self.double_blocks = nn.ModuleList(
544
+ [
545
+ MMDoubleStreamBlock(
546
+ self.hidden_size,
547
+ self.heads_num,
548
+ mlp_width_ratio=mlp_width_ratio,
549
+ mlp_act_type=mlp_act_type,
550
+ qk_norm=qk_norm,
551
+ qk_norm_type=qk_norm_type,
552
+ qkv_bias=qkv_bias,
553
+ **factory_kwargs,
554
+ )
555
+ for _ in range(mm_double_blocks_depth)
556
+ ]
557
+ )
558
+
559
+ # single blocks
560
+ self.single_blocks = nn.ModuleList(
561
+ [
562
+ MMSingleStreamBlock(
563
+ self.hidden_size,
564
+ self.heads_num,
565
+ mlp_width_ratio=mlp_width_ratio,
566
+ mlp_act_type=mlp_act_type,
567
+ qk_norm=qk_norm,
568
+ qk_norm_type=qk_norm_type,
569
+ **factory_kwargs,
570
+ )
571
+ for _ in range(mm_single_blocks_depth)
572
+ ]
573
+ )
574
+
575
+ self.final_layer = FinalLayer(
576
+ self.hidden_size,
577
+ self.patch_size,
578
+ self.out_channels,
579
+ get_activation_layer("silu"),
580
+ **factory_kwargs,
581
+ )
582
+
583
+ def enable_deterministic(self):
584
+ for block in self.double_blocks:
585
+ block.enable_deterministic()
586
+ for block in self.single_blocks:
587
+ block.enable_deterministic()
588
+
589
+ def disable_deterministic(self):
590
+ for block in self.double_blocks:
591
+ block.disable_deterministic()
592
+ for block in self.single_blocks:
593
+ block.disable_deterministic()
594
+
595
+ def forward(
596
+ self,
597
+ x: torch.Tensor,
598
+ t: torch.Tensor, # Should be in range(0, 1000).
599
+ text_states: torch.Tensor = None,
600
+ text_mask: torch.Tensor = None, # Now we don't use it.
601
+ text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
602
+ freqs_cos: Optional[torch.Tensor] = None,
603
+ freqs_sin: Optional[torch.Tensor] = None,
604
+ guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
605
+ return_dict: bool = True,
606
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
607
+ out = {}
608
+ img = x
609
+ txt = text_states
610
+ _, _, ot, oh, ow = x.shape
611
+ tt, th, tw = (
612
+ ot // self.patch_size[0],
613
+ oh // self.patch_size[1],
614
+ ow // self.patch_size[2],
615
+ )
616
+
617
+ # Prepare modulation vectors.
618
+ vec = self.time_in(t)
619
+
620
+ # text modulation
621
+ vec = vec + self.vector_in(text_states_2)
622
+
623
+ # guidance modulation
624
+ if self.guidance_embed:
625
+ if guidance is None:
626
+ raise ValueError(
627
+ "Didn't get guidance strength for guidance distilled model."
628
+ )
629
+
630
+ # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
631
+ vec = vec + self.guidance_in(guidance)
632
+
633
+ # Embed image and text.
634
+ img = self.img_in(img)
635
+ if self.text_projection == "linear":
636
+ txt = self.txt_in(txt)
637
+ elif self.text_projection == "single_refiner":
638
+ txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
639
+ else:
640
+ raise NotImplementedError(
641
+ f"Unsupported text_projection: {self.text_projection}"
642
+ )
643
+
644
+ txt_seq_len = txt.shape[1]
645
+ img_seq_len = img.shape[1]
646
+
647
+ # Compute cu_seqlens and max_seqlen for flash attention
648
+ cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
649
+ cu_seqlens_kv = cu_seqlens_q
650
+ max_seqlen_q = img_seq_len + txt_seq_len
651
+ max_seqlen_kv = max_seqlen_q
652
+
653
+ freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
654
+ # --------------------- Pass through DiT blocks ------------------------
655
+ for _, block in enumerate(self.double_blocks):
656
+ double_block_args = [
657
+ img,
658
+ txt,
659
+ vec,
660
+ cu_seqlens_q,
661
+ cu_seqlens_kv,
662
+ max_seqlen_q,
663
+ max_seqlen_kv,
664
+ freqs_cis,
665
+ ]
666
+
667
+ img, txt = block(*double_block_args)
668
+
669
+ # Merge txt and img to pass through single stream blocks.
670
+ x = torch.cat((img, txt), 1)
671
+ if len(self.single_blocks) > 0:
672
+ for _, block in enumerate(self.single_blocks):
673
+ single_block_args = [
674
+ x,
675
+ vec,
676
+ txt_seq_len,
677
+ cu_seqlens_q,
678
+ cu_seqlens_kv,
679
+ max_seqlen_q,
680
+ max_seqlen_kv,
681
+ (freqs_cos, freqs_sin),
682
+ ]
683
+
684
+ x = block(*single_block_args)
685
+
686
+ img = x[:, :img_seq_len, ...]
687
+
688
+ # ---------------------------- Final layer ------------------------------
689
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
690
+
691
+ img = self.unpatchify(img, tt, th, tw)
692
+ if return_dict:
693
+ out["x"] = img
694
+ return out
695
+ return img
696
+
697
+ def unpatchify(self, x, t, h, w):
698
+ """
699
+ x: (N, T, patch_size**2 * C)
700
+ imgs: (N, H, W, C)
701
+ """
702
+ c = self.unpatchify_channels
703
+ pt, ph, pw = self.patch_size
704
+ assert t * h * w == x.shape[1]
705
+
706
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
707
+ x = torch.einsum("nthwcopq->nctohpwq", x)
708
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
709
+
710
+ return imgs
711
+
712
+ def params_count(self):
713
+ counts = {
714
+ "double": sum(
715
+ [
716
+ sum(p.numel() for p in block.img_attn_qkv.parameters())
717
+ + sum(p.numel() for p in block.img_attn_proj.parameters())
718
+ + sum(p.numel() for p in block.img_mlp.parameters())
719
+ + sum(p.numel() for p in block.txt_attn_qkv.parameters())
720
+ + sum(p.numel() for p in block.txt_attn_proj.parameters())
721
+ + sum(p.numel() for p in block.txt_mlp.parameters())
722
+ for block in self.double_blocks
723
+ ]
724
+ ),
725
+ "single": sum(
726
+ [
727
+ sum(p.numel() for p in block.linear1.parameters())
728
+ + sum(p.numel() for p in block.linear2.parameters())
729
+ for block in self.single_blocks
730
+ ]
731
+ ),
732
+ "total": sum(p.numel() for p in self.parameters()),
733
+ }
734
+ counts["attn+mlp"] = counts["double"] + counts["single"]
735
+ return counts
736
+
737
+
738
+ #################################################################################
739
+ # HunyuanVideo Configs #
740
+ #################################################################################
741
+
742
+ HUNYUAN_VIDEO_CONFIG = {
743
+ "HYVideo-T/2": {
744
+ "mm_double_blocks_depth": 20,
745
+ "mm_single_blocks_depth": 40,
746
+ "rope_dim_list": [16, 56, 56],
747
+ "hidden_size": 3072,
748
+ "heads_num": 24,
749
+ "mlp_width_ratio": 4,
750
+ },
751
+ "HYVideo-T/2-cfgdistill": {
752
+ "mm_double_blocks_depth": 20,
753
+ "mm_single_blocks_depth": 40,
754
+ "rope_dim_list": [16, 56, 56],
755
+ "hidden_size": 3072,
756
+ "heads_num": 24,
757
+ "mlp_width_ratio": 4,
758
+ "guidance_embed": True,
759
+ },
760
+ }
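
For orientation, a sketch of how one of these presets expands into the constructor defined above. Note that `original models.py` is not importable as a module because of the space in its filename, so the import path below assumes the adapted copy in `hyvideo/modules/models.py` exposes the same class and preset names; the `args` namespace fields and the text-encoder dimensions are likewise assumptions inferred from what `__init__` reads.

```python
import torch
from types import SimpleNamespace
# Assumed import path; see the note above.
from hyvideo.modules.models import HYVideoDiffusionTransformer, HUNYUAN_VIDEO_CONFIG

# __init__ only reads text_states_dim / text_states_dim_2 from `args`.
args = SimpleNamespace(text_states_dim=4096, text_states_dim_2=768)

cfg = HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"]   # guidance-distilled preset
model = HYVideoDiffusionTransformer(
    args,
    in_channels=16,                        # VAE latent channels (illustrative)
    **cfg,
    dtype=torch.bfloat16,
    device=torch.device("meta"),           # build the module tree without allocating weights
)
print(model.params_count())                # per-block and total parameter counts
```
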
hyvideo/modules/placement.py ADDED
@@ -0,0 +1,389 @@
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+
5
+ def hunyuan_token_reorder_to_token_major(tensor, fix_len, reorder_len, reorder_num_frame, frame_size):
6
+ """Reorder it from frame major to token major!"""
7
+ assert reorder_len == reorder_num_frame * frame_size
8
+ assert tensor.shape[2] == fix_len + reorder_len
9
+
10
+ tensor[:, :, :-fix_len, :] = tensor[:, :, :-fix_len:, :].reshape(tensor.shape[0], tensor.shape[1], reorder_num_frame, frame_size, tensor.shape[3]) \
11
+ .transpose(2, 3).reshape(tensor.shape[0], tensor.shape[1], reorder_len, tensor.shape[3])
12
+ return tensor
13
+
14
+ def hunyuan_token_reorder_to_frame_major(tensor, fix_len, reorder_len, reorder_num_frame, frame_size):
15
+ """Reorder it from token major to frame major!"""
16
+ assert reorder_len == reorder_num_frame * frame_size
17
+ assert tensor.shape[2] == fix_len + reorder_len
18
+
19
+ tensor[:, :, :-fix_len:, :] = tensor[:, :, :-fix_len:, :].reshape(tensor.shape[0], tensor.shape[1], frame_size, reorder_num_frame, tensor.shape[3]) \
20
+ .transpose(2, 3).reshape(tensor.shape[0], tensor.shape[1], reorder_len, tensor.shape[3])
21
+ return tensor
22
+
23
+
24
+ @triton.jit
25
+ def hunyuan_sparse_head_placement_kernel(
26
+ query_ptr, key_ptr, value_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
27
+ query_out_ptr, key_out_ptr, value_out_ptr, # [cfg, num_heads, seq_len, head_dim]
28
+ best_mask_idx_ptr, # [cfg, num_heads]
29
+ query_stride_b, query_stride_h, query_stride_s, query_stride_d,
30
+ mask_idx_stride_b, mask_idx_stride_h,
31
+ seq_len: tl.constexpr,
32
+ head_dim: tl.constexpr,
33
+ context_length: tl.constexpr,
34
+ num_frame: tl.constexpr,
35
+ frame_size: tl.constexpr,
36
+ BLOCK_SIZE: tl.constexpr
37
+ ):
38
+ # Copy query, key, value to output
39
+ # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
40
+ cfg = tl.program_id(0)
41
+ head = tl.program_id(1)
42
+ block_id = tl.program_id(2)
43
+
44
+ start_id = block_id * BLOCK_SIZE
45
+ end_id = start_id + BLOCK_SIZE
46
+ end_id = tl.where(end_id > seq_len, seq_len, end_id)
47
+
48
+ # Load best mask idx (0 is spatial, 1 is temporal)
49
+ is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
50
+
51
+ offset_token = tl.arange(0, BLOCK_SIZE) + start_id
52
+ offset_mask = offset_token < seq_len
53
+ offset_d = tl.arange(0, head_dim)
54
+
55
+ if is_temporal:
56
+ frame_id = offset_token // frame_size
57
+ patch_id = offset_token - frame_id * frame_size
58
+ offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, patch_id * num_frame + frame_id)
59
+
60
+ offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
61
+ offset_query = query_ptr + offset_load
62
+ offset_key = key_ptr + offset_load
63
+ offset_value = value_ptr + offset_load
64
+
65
+ offset_store = (cfg * query_stride_b + head * query_stride_h + offset_store_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
66
+ offset_query_out = query_out_ptr + offset_store
67
+ offset_key_out = key_out_ptr + offset_store
68
+ offset_value_out = value_out_ptr + offset_store
69
+
70
+ # Maybe tune the pipeline here
71
+ query = tl.load(offset_query, mask=offset_mask[:,None])
72
+ tl.store(offset_query_out, query, mask=offset_mask[:,None])
73
+ key = tl.load(offset_key, mask=offset_mask[:,None])
74
+ tl.store(offset_key_out, key, mask=offset_mask[:,None])
75
+ value = tl.load(offset_value, mask=offset_mask[:,None])
76
+ tl.store(offset_value_out, value, mask=offset_mask[:,None])
77
+
78
+
79
+ else:
80
+ offset_load = (cfg * query_stride_b + head * query_stride_h + offset_token[:,None] * query_stride_s) + offset_d[None,:] * query_stride_d
81
+ offset_query = query_ptr + offset_load
82
+ offset_key = key_ptr + offset_load
83
+ offset_value = value_ptr + offset_load
84
+
85
+ offset_store = offset_load
86
+ offset_query_out = query_out_ptr + offset_store
87
+ offset_key_out = key_out_ptr + offset_store
88
+ offset_value_out = value_out_ptr + offset_store
89
+
90
+ # Maybe tune the pipeline here
91
+ query = tl.load(offset_query, mask=offset_mask[:,None])
92
+ tl.store(offset_query_out, query, mask=offset_mask[:,None])
93
+ key = tl.load(offset_key, mask=offset_mask[:,None])
94
+ tl.store(offset_key_out, key, mask=offset_mask[:,None])
95
+ value = tl.load(offset_value, mask=offset_mask[:,None])
96
+ tl.store(offset_value_out, value, mask=offset_mask[:,None])
97
+
98
+
99
+ def hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size):
100
+ cfg, num_heads, seq_len, head_dim = query.shape
101
+ BLOCK_SIZE = 128
102
+ assert seq_len == context_length + num_frame * frame_size
103
+
104
+ grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
105
+
106
+ hunyuan_sparse_head_placement_kernel[grid](
107
+ query, key, value,
108
+ query_out, key_out, value_out,
109
+ best_mask_idx,
110
+ query.stride(0), query.stride(1), query.stride(2), query.stride(3),
111
+ best_mask_idx.stride(0), best_mask_idx.stride(1),
112
+ seq_len, head_dim, context_length, num_frame, frame_size,
113
+ BLOCK_SIZE
114
+ )
115
+
116
+
117
+ def ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size):
118
+ cfg, num_heads, seq_len, head_dim = query.shape
119
+ assert seq_len == context_length + num_frame * frame_size
120
+
121
+ query_out = query.clone()
122
+ key_out = key.clone()
123
+ value_out = value.clone()
124
+
125
+ # Spatial
126
+ query_out[best_mask_idx == 0], key_out[best_mask_idx == 0], value_out[best_mask_idx == 0] = \
127
+ query[best_mask_idx == 0], key[best_mask_idx == 0], value[best_mask_idx == 0]
128
+
129
+ # Temporal
130
+ query_out[best_mask_idx == 1], key_out[best_mask_idx == 1], value_out[best_mask_idx == 1] = \
131
+ hunyuan_token_reorder_to_token_major(query[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0), \
132
+ hunyuan_token_reorder_to_token_major(key[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0), \
133
+ hunyuan_token_reorder_to_token_major(value[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0)
134
+
135
+ return query_out, key_out, value_out
136
+
137
+
138
+ def test_hunyuan_sparse_head_placement():
139
+
140
+ context_length = 226
141
+ num_frame = 11
142
+ frame_size = 4080
143
+
144
+ cfg = 2
145
+ num_heads = 48
146
+
147
+ seq_len = context_length + num_frame * frame_size
148
+ head_dim = 64
149
+
150
+ dtype = torch.bfloat16
151
+ device = torch.device("cuda")
152
+
153
+ query = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
154
+ key = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
155
+ value = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
156
+
157
+ best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)
158
+
159
+ query_out = torch.empty_like(query)
160
+ key_out = torch.empty_like(key)
161
+ value_out = torch.empty_like(value)
162
+
163
+ hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
164
+ ref_query_out, ref_key_out, ref_value_out = ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size)
165
+
166
+ torch.testing.assert_close(query_out, ref_query_out)
167
+ torch.testing.assert_close(key_out, ref_key_out)
168
+ torch.testing.assert_close(value_out, ref_value_out)
169
+
170
+
171
+ def benchmark_hunyuan_sparse_head_placement():
172
+ import time
173
+
174
+ context_length = 226
175
+ num_frame = 11
176
+ frame_size = 4080
177
+
178
+ cfg = 2
179
+ num_heads = 48
180
+
181
+ seq_len = context_length + num_frame * frame_size
182
+ head_dim = 64
183
+
184
+ dtype = torch.bfloat16
185
+ device = torch.device("cuda")
186
+
187
+ query = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
188
+ key = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
189
+ value = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
190
+ best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)
191
+
192
+ query_out = torch.empty_like(query)
193
+ key_out = torch.empty_like(key)
194
+ value_out = torch.empty_like(value)
195
+
196
+ warmup = 10
197
+ all_iter = 1000
198
+
199
+ # warmup
200
+ for _ in range(warmup):
201
+ hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
202
+
203
+ torch.cuda.synchronize()
204
+ start = time.time()
205
+ for _ in range(all_iter):
206
+ hunyuan_sparse_head_placement(query, key, value, query_out, key_out, value_out, best_mask_idx, context_length, num_frame, frame_size)
207
+ torch.cuda.synchronize()
208
+ end = time.time()
209
+
210
+ print(f"Triton Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
211
+ print(f"Triton Total Bandwidth: {query.nelement() * query.element_size() * 3 * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")
212
+
213
+ torch.cuda.synchronize()
214
+ start = time.time()
215
+ for _ in range(all_iter):
216
+ ref_hunyuan_sparse_head_placement(query, key, value, best_mask_idx, context_length, num_frame, frame_size)
217
+ torch.cuda.synchronize()
218
+ end = time.time()
219
+
220
+ print(f"Reference Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
221
+ print(f"Reference Total Bandwidth: {query.nelement() * query.element_size() * 3 * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")
222
+
223
+
224
+ @triton.jit
225
+ def hunyuan_hidden_states_placement_kernel(
226
+ hidden_states_ptr, # [cfg, num_heads, seq_len, head_dim] seq_len = context_length + num_frame * frame_size
227
+ hidden_states_out_ptr, # [cfg, num_heads, seq_len, head_dim]
228
+ best_mask_idx_ptr, # [cfg, num_heads]
229
+ hidden_states_stride_b, hidden_states_stride_h, hidden_states_stride_s, hidden_states_stride_d,
230
+ mask_idx_stride_b, mask_idx_stride_h,
231
+ seq_len: tl.constexpr,
232
+ head_dim: tl.constexpr,
233
+ context_length: tl.constexpr,
234
+ num_frame: tl.constexpr,
235
+ frame_size: tl.constexpr,
236
+ BLOCK_SIZE: tl.constexpr
237
+ ):
238
+ # Copy hidden_states to output
239
+ # range: [b, h, block_id * block_size: block_id * block_size + block_size, :]
240
+ cfg = tl.program_id(0)
241
+ head = tl.program_id(1)
242
+ block_id = tl.program_id(2)
243
+
244
+ start_id = block_id * BLOCK_SIZE
245
+ end_id = start_id + BLOCK_SIZE
246
+ end_id = tl.where(end_id > seq_len, seq_len, end_id)
247
+
248
+ # Load best mask idx (0 is spatial, 1 is temporal)
249
+ is_temporal = tl.load(best_mask_idx_ptr + cfg * mask_idx_stride_b + head * mask_idx_stride_h)
250
+
251
+ offset_token = tl.arange(0, BLOCK_SIZE) + start_id
252
+ offset_mask = offset_token < seq_len
253
+ offset_d = tl.arange(0, head_dim)
254
+
255
+ if is_temporal:
256
+ patch_id = offset_token // num_frame
257
+ frame_id = offset_token - patch_id * num_frame
258
+ offset_store_token = tl.where(offset_token >= seq_len - context_length, offset_token, frame_id * frame_size + patch_id)
259
+
260
+ offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
261
+ offset_hidden_states = hidden_states_ptr + offset_load
262
+
263
+ offset_store = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_store_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
264
+ offset_hidden_states_out = hidden_states_out_ptr + offset_store
265
+
266
+ # Maybe tune the pipeline here
267
+ hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:,None])
268
+ tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:,None])
269
+ else:
270
+ offset_load = (cfg * hidden_states_stride_b + head * hidden_states_stride_h + offset_token[:,None] * hidden_states_stride_s) + offset_d[None,:] * hidden_states_stride_d
271
+ offset_hidden_states = hidden_states_ptr + offset_load
272
+
273
+ offset_store = offset_load
274
+ offset_hidden_states_out = hidden_states_out_ptr + offset_store
275
+
276
+ # Maybe tune the pipeline here
277
+ hidden_states = tl.load(offset_hidden_states, mask=offset_mask[:,None])
278
+ tl.store(offset_hidden_states_out, hidden_states, mask=offset_mask[:,None])
279
+
280
+
281
+ def hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size):
282
+ cfg, num_heads, seq_len, head_dim = hidden_states.shape
283
+ BLOCK_SIZE = 128
284
+ assert seq_len == context_length + num_frame * frame_size
285
+
286
+ grid = (cfg, num_heads, (seq_len + BLOCK_SIZE - 1) // BLOCK_SIZE)
287
+
288
+
289
+ hunyuan_hidden_states_placement_kernel[grid](
290
+ hidden_states,
291
+ hidden_states_out,
292
+ best_mask_idx,
293
+ hidden_states.stride(0), hidden_states.stride(1), hidden_states.stride(2), hidden_states.stride(3),
294
+ best_mask_idx.stride(0), best_mask_idx.stride(1),
295
+ seq_len, head_dim, context_length, num_frame, frame_size,
296
+ BLOCK_SIZE
297
+ )
298
+
299
+ return hidden_states_out
300
+
301
+ def ref_hunyuan_hidden_states_placement(hidden_states, output_hidden_states, best_mask_idx, context_length, num_frame, frame_size):
302
+ cfg, num_heads, seq_len, head_dim = hidden_states.shape
303
+ assert seq_len == context_length + num_frame * frame_size
304
+
305
+ # Spatial
306
+ output_hidden_states[best_mask_idx == 0] = hidden_states[best_mask_idx == 0]
307
+ # Temporal
308
+ output_hidden_states[best_mask_idx == 1] = hunyuan_token_reorder_to_frame_major(hidden_states[best_mask_idx == 1].unsqueeze(0), context_length, num_frame * frame_size, num_frame, frame_size).squeeze(0)
309
+
310
+ def test_hunyuan_hidden_states_placement():
311
+
312
+ context_length = 226
313
+ num_frame = 11
314
+ frame_size = 4080
315
+
316
+ cfg = 2
317
+ num_heads = 48
318
+
319
+ seq_len = context_length + num_frame * frame_size
320
+ head_dim = 64
321
+
322
+ dtype = torch.bfloat16
323
+ device = torch.device("cuda")
324
+
325
+ hidden_states = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
326
+ best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)
327
+
328
+ hidden_states_out1 = torch.empty_like(hidden_states)
329
+ hidden_states_out2 = torch.empty_like(hidden_states)
330
+
331
+ hunyuan_hidden_states_placement(hidden_states, hidden_states_out1, best_mask_idx, context_length, num_frame, frame_size)
332
+ ref_hunyuan_hidden_states_placement(hidden_states, hidden_states_out2, best_mask_idx, context_length, num_frame, frame_size)
333
+
334
+ torch.testing.assert_close(hidden_states_out1, hidden_states_out2)
335
+
336
+ def benchmark_hunyuan_hidden_states_placement():
337
+ import time
338
+
339
+ context_length = 226
340
+ num_frame = 11
341
+ frame_size = 4080
342
+
343
+ cfg = 2
344
+ num_heads = 48
345
+
346
+ seq_len = context_length + num_frame * frame_size
347
+ head_dim = 64
348
+
349
+ dtype = torch.bfloat16
350
+ device = torch.device("cuda")
351
+
352
+ hidden_states = torch.randn(cfg, num_heads, seq_len, head_dim, dtype=dtype, device=device)
353
+ best_mask_idx = torch.randint(0, 2, (cfg, num_heads), device=device)
354
+
355
+ hidden_states_out = torch.empty_like(hidden_states)
356
+
357
+ warmup = 10
358
+ all_iter = 1000
359
+
360
+ # warmup
361
+ for _ in range(warmup):
362
+ hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size)
363
+
364
+ torch.cuda.synchronize()
365
+ start = time.time()
366
+ for _ in range(all_iter):
367
+ hunyuan_hidden_states_placement(hidden_states, hidden_states_out, best_mask_idx, context_length, num_frame, frame_size)
368
+ torch.cuda.synchronize()
369
+ end = time.time()
370
+
371
+ print(f"Triton Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
372
+ print(f"Triton Total Bandwidth: {hidden_states.nelement() * hidden_states.element_size() * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")
373
+
374
+ torch.cuda.synchronize()
375
+ start = time.time()
376
+ for _ in range(all_iter):
377
+ ref_hunyuan_hidden_states_placement(hidden_states, hidden_states.clone(), best_mask_idx, context_length, num_frame, frame_size)
378
+ torch.cuda.synchronize()
379
+ end = time.time()
380
+
381
+ print(f"Reference Elapsed Time: {(end - start) / all_iter * 1e3:.2f} ms")
382
+ print(f"Reference Total Bandwidth: {hidden_states.nelement() * hidden_states.element_size() * 2 * all_iter / (end - start) / 1e9:.2f} GB/s")
383
+
384
+
385
+ if __name__ == "__main__":
386
+ test_hunyuan_sparse_head_placement()
387
+ benchmark_hunyuan_sparse_head_placement()
388
+ test_hunyuan_hidden_states_placement()
389
+ benchmark_hunyuan_hidden_states_placement()
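
Before the next file, a CPU-only sketch of the frame-major ↔ token-major reordering that the Triton kernels above accelerate. The tiny sizes are illustrative, and the reshapes restate what `hunyuan_token_reorder_to_token_major` / `hunyuan_token_reorder_to_frame_major` do to the video tokens (the trailing `context_length` text tokens are left untouched).

```python
import torch

context_length, num_frame, frame_size = 4, 3, 5            # illustrative sizes
seq_len = context_length + num_frame * frame_size
x = torch.arange(seq_len, dtype=torch.float32).view(1, 1, seq_len, 1)
video = x[:, :, :-context_length]                           # frame-major video tokens

# frame-major [frame, patch] -> token-major [patch, frame]
token_major = (video.reshape(1, 1, num_frame, frame_size, 1)
                    .transpose(2, 3)
                    .reshape(1, 1, num_frame * frame_size, 1))

# token-major -> frame-major again: the round trip is the identity
frame_major = (token_major.reshape(1, 1, frame_size, num_frame, 1)
                          .transpose(2, 3)
                          .reshape(1, 1, num_frame * frame_size, 1))
torch.testing.assert_close(frame_major, video)
```
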
hyvideo/modules/posemb_layers.py ADDED
@@ -0,0 +1,475 @@
1
+ import torch
2
+ from typing import Union, Tuple, List, Optional
3
+ import numpy as np
4
+
5
+
6
+ ###### Thanks to the RIFLEx project (https://github.com/thu-ml/RIFLEx/) for this alternative pos embed for long videos
7
+ #
8
+ def get_1d_rotary_pos_embed_riflex(
9
+ dim: int,
10
+ pos: Union[np.ndarray, int],
11
+ theta: float = 10000.0,
12
+ use_real=False,
13
+ k: Optional[int] = None,
14
+ L_test: Optional[int] = None,
15
+ ):
16
+ """
17
+ RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
18
+
19
+ This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the position
20
+ indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
21
+ data type.
22
+
23
+ Args:
24
+ dim (`int`): Dimension of the frequency tensor.
25
+ pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
26
+ theta (`float`, *optional*, defaults to 10000.0):
27
+ Scaling factor for frequency computation. Defaults to 10000.0.
28
+ use_real (`bool`, *optional*):
29
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
30
+ k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
31
+ L_test (`int`, *optional*, defaults to None): the number of frames for inference
32
+ Returns:
33
+ `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
34
+ """
35
+ assert dim % 2 == 0
36
+
37
+ if isinstance(pos, int):
38
+ pos = torch.arange(pos)
39
+ if isinstance(pos, np.ndarray):
40
+ pos = torch.from_numpy(pos) # type: ignore # [S]
41
+
42
+ freqs = 1.0 / (
43
+ theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)
44
+ ) # [D/2]
45
+
46
+ # === Riflex modification start ===
47
+ # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
48
+ # Empirical observations show that a few videos may exhibit repetition in the tail frames.
49
+ # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
50
+ if k is not None:
51
+ freqs[k-1] = 0.9 * 2 * torch.pi / L_test
52
+ # === Riflex modification end ===
53
+
54
+ freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
55
+ if use_real:
56
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
57
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
58
+ return freqs_cos, freqs_sin
59
+ else:
60
+ # lumina
61
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
62
+ return freqs_cis
63
+
64
+ def identify_k( b: float, d: int, N: int):
65
+ """
66
+ This function identifies the index of the intrinsic frequency component in a RoPE-based pre-trained diffusion transformer.
67
+
68
+ Args:
69
+ b (`float`): The base frequency for RoPE.
70
+ d (`int`): Dimension of the frequency tensor
71
+ N (`int`): the first observed repetition frame in latent space
72
+ Returns:
73
+ k (`int`): the index of intrinsic frequency component
74
+ N_k (`int`): the period of intrinsic frequency component in latent space
75
+ Example:
76
+ In HunyuanVideo, b=256 and d=16, the repetition occurs approximately 8s (N=48 in latent space).
77
+ k, N_k = identify_k(b=256, d=16, N=48)
78
+ In this case, the intrinsic frequency index k is 4, and the period N_k is 50.
79
+ """
80
+
81
+ # Compute the period of each frequency in RoPE according to Eq.(4)
82
+ periods = []
83
+ for j in range(1, d // 2 + 1):
84
+ theta_j = 1.0 / (b ** (2 * (j - 1) / d))
85
+ N_j = round(2 * torch.pi / theta_j)
86
+ periods.append(N_j)
87
+
88
+ # Identify the intrinsic frequency whose period is closest to N (see Eq. (7))
89
+ diffs = [abs(N_j - N) for N_j in periods]
90
+ k = diffs.index(min(diffs)) + 1
91
+ N_k = periods[k-1]
92
+ return k, N_k
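
A short sketch of how `identify_k` feeds the RIFLEx embedding above, reusing the numbers from the docstring example; the latent frame count `L_test` and `theta` value are illustrative assumptions, and the import assumes the repository root is on the Python path.

```python
from hyvideo.modules.posemb_layers import identify_k, get_1d_rotary_pos_embed_riflex

k, N_k = identify_k(b=256, d=16, N=48)       # -> k = 4, N_k = 50 (docstring example)

L_test = 66                                   # illustrative: latent frames to generate
cos, sin = get_1d_rotary_pos_embed_riflex(
    dim=16, pos=L_test, theta=256.0, use_real=True, k=k, L_test=L_test
)
print(cos.shape, sin.shape)                   # (66, 16) each
```
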
93
+
94
+ def _to_tuple(x, dim=2):
95
+ if isinstance(x, int):
96
+ return (x,) * dim
97
+ elif len(x) == dim:
98
+ return x
99
+ else:
100
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
101
+
102
+
103
+ def get_meshgrid_nd(start, *args, dim=2):
104
+ """
105
+ Get n-D meshgrid with start, stop and num.
106
+
107
+ Args:
108
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
109
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
110
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
111
+ n-tuples.
112
+ *args: See above.
113
+ dim (int): Dimension of the meshgrid. Defaults to 2.
114
+
115
+ Returns:
116
+ grid (np.ndarray): [dim, ...]
117
+ """
118
+ if len(args) == 0:
119
+ # start is grid_size
120
+ num = _to_tuple(start, dim=dim)
121
+ start = (0,) * dim
122
+ stop = num
123
+ elif len(args) == 1:
124
+ # start is start, args[0] is stop, step is 1
125
+ start = _to_tuple(start, dim=dim)
126
+ stop = _to_tuple(args[0], dim=dim)
127
+ num = [stop[i] - start[i] for i in range(dim)]
128
+ elif len(args) == 2:
129
+ # start is start, args[0] is stop, args[1] is num
130
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
131
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
132
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
133
+ else:
134
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
135
+
136
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
137
+ axis_grid = []
138
+ for i in range(dim):
139
+ a, b, n = start[i], stop[i], num[i]
140
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
141
+ axis_grid.append(g)
142
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
143
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
144
+
145
+ return grid
146
+
147
+
148
+ #################################################################################
149
+ # Rotary Positional Embedding Functions #
150
+ #################################################################################
151
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
152
+
153
+
154
+ def reshape_for_broadcast(
155
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
156
+ x: torch.Tensor,
157
+ head_first=False,
158
+ ):
159
+ """
160
+ Reshape frequency tensor for broadcasting it with another tensor.
161
+
162
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
163
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
164
+
165
+ Notes:
166
+ When using FlashMHAModified, head_first should be False.
167
+ When using Attention, head_first should be True.
168
+
169
+ Args:
170
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
171
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
172
+ head_first (bool): head dimension first (except batch dim) or not.
173
+
174
+ Returns:
175
+ torch.Tensor: Reshaped frequency tensor.
176
+
177
+ Raises:
178
+ AssertionError: If the frequency tensor doesn't match the expected shape.
179
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
180
+ """
181
+ ndim = x.ndim
182
+ assert 0 <= 1 < ndim
183
+
184
+ if isinstance(freqs_cis, tuple):
185
+ # freqs_cis: (cos, sin) in real space
186
+ if head_first:
187
+ assert freqs_cis[0].shape == (
188
+ x.shape[-2],
189
+ x.shape[-1],
190
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
191
+ shape = [
192
+ d if i == ndim - 2 or i == ndim - 1 else 1
193
+ for i, d in enumerate(x.shape)
194
+ ]
195
+ else:
196
+ assert freqs_cis[0].shape == (
197
+ x.shape[1],
198
+ x.shape[-1],
199
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
200
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
201
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
202
+ else:
203
+ # freqs_cis: values in complex space
204
+ if head_first:
205
+ assert freqs_cis.shape == (
206
+ x.shape[-2],
207
+ x.shape[-1],
208
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
209
+ shape = [
210
+ d if i == ndim - 2 or i == ndim - 1 else 1
211
+ for i, d in enumerate(x.shape)
212
+ ]
213
+ else:
214
+ assert freqs_cis.shape == (
215
+ x.shape[1],
216
+ x.shape[-1],
217
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
218
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
219
+ return freqs_cis.view(*shape)
220
+
221
+
222
+ def rotate_half(x):
223
+ x_real, x_imag = (
224
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
225
+ ) # [B, S, H, D//2]
226
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
227
+
228
+
229
+ def apply_rotary_emb( qklist,
230
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
231
+ head_first: bool = False,
232
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
233
+ """
234
+ Apply rotary embeddings to input tensors using the given frequency tensor.
235
+
236
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
237
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
238
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
239
+ returned as real tensors.
240
+
241
+ Args:
242
+ qklist (list of torch.Tensor): two-element list [xq, xk] with the query and key tensors of shape [B, S, H, D];
+ the list is cleared in place so the original references can be freed early.
244
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
245
+ head_first (bool): head dimension first (except batch dim) or not.
246
+
247
+ Returns:
248
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
249
+
250
+ """
251
+ xq, xk = qklist
252
+ qklist.clear()
253
+ xk_out = None
254
+ if isinstance(freqs_cis, tuple):
255
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
256
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
257
+ # real * cos - imag * sin
258
+ # imag * cos + real * sin
259
+ xq_dtype = xq.dtype
260
+ xq_out = xq.to(torch.float)
261
+ xq = None
262
+ xq_rot = rotate_half(xq_out)
263
+ xq_out *= cos
264
+ xq_rot *= sin
265
+ xq_out += xq_rot
266
+ del xq_rot
267
+ xq_out = xq_out.to(xq_dtype)
268
+
269
+ xk_out = xk.to(torch.float)
270
+ xk = None
271
+ xk_rot = rotate_half(xk_out)
272
+ xk_out *= cos
273
+ xk_rot *= sin
274
+ xk_out += xk_rot
275
+ del xk_rot
276
+ xk_out = xk_out.to(xq_dtype)
277
+ else:
278
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
279
+ xq_ = torch.view_as_complex(
280
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
281
+ ) # [B, S, H, D//2]
282
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
283
+ xq.device
284
+ ) # [S, D//2] --> [1, S, 1, D//2]
285
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
286
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
287
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
288
+ xk_ = torch.view_as_complex(
289
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
290
+ ) # [B, S, H, D//2]
291
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
292
+
293
+ return xq_out, xk_out
294
+
295
+ def get_nd_rotary_pos_embed_new(rope_dim_list, start, *args, theta=10000., use_real=False,
296
+ theta_rescale_factor: Union[float, List[float]]=1.0,
297
+ interpolation_factor: Union[float, List[float]]=1.0,
298
+ concat_dict={}
299
+ ):
300
+
301
+ grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
302
+ if len(concat_dict)<1:
303
+ pass
304
+ else:
305
+ if concat_dict['mode']=='timecat':
306
+ bias = grid[:,:1].clone()
307
+ bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
308
+ grid = torch.cat([bias, grid], dim=1)
309
+
310
+ elif concat_dict['mode']=='timecat-w':
311
+ bias = grid[:,:1].clone()
312
+ bias[0] = concat_dict['bias']*torch.ones_like(bias[0])
313
+ bias[2] += start[-1] ## ref https://github.com/Yuanshi9815/OminiControl/blob/main/src/generate.py#L178
314
+ grid = torch.cat([bias, grid], dim=1)
315
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
316
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
317
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
318
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
319
+ assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
320
+
321
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
322
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
323
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
324
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
325
+ assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
326
+
327
+ # use 1/ndim of dimensions to encode grid_axis
328
+ embs = []
329
+ for i in range(len(rope_dim_list)):
330
+ emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=use_real,
331
+ theta_rescale_factor=theta_rescale_factor[i],
332
+ interpolation_factor=interpolation_factor[i]) # 2 x [WHD, rope_dim_list[i]]
333
+
334
+ embs.append(emb)
335
+
336
+ if use_real:
337
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
338
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
339
+ return cos, sin
340
+ else:
341
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
342
+ return emb
343
+
344
+ def get_nd_rotary_pos_embed(
345
+ rope_dim_list,
346
+ start,
347
+ *args,
348
+ theta=10000.0,
349
+ use_real=False,
350
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
351
+ interpolation_factor: Union[float, List[float]] = 1.0,
352
+ k = 4,
353
+ L_test = 66,
354
+ enable_riflex = True
355
+ ):
356
+ """
357
+ This is an n-d version of precompute_freqs_cis, which is a RoPE for tokens with an n-d structure.
358
+
359
+ Args:
360
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
361
+ sum(rope_dim_list) should equal to head_dim of attention layer.
362
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
363
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
364
+ *args: See above.
365
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
366
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
367
+ Some libraries such as TensorRT do not support the complex64 data type. So it is useful to provide a real
368
+ part and an imaginary part separately.
369
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
370
+
371
+ Returns:
372
+ pos_embed (torch.Tensor): [HW, D/2]
373
+ """
374
+
375
+ grid = get_meshgrid_nd(
376
+ start, *args, dim=len(rope_dim_list)
377
+ ) # [3, W, H, D] / [2, W, H]
378
+
379
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
380
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
381
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
382
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
383
+ assert len(theta_rescale_factor) == len(
384
+ rope_dim_list
385
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
386
+
387
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
388
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
389
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
390
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
391
+ assert len(interpolation_factor) == len(
392
+ rope_dim_list
393
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
394
+
395
+ # use 1/ndim of dimensions to encode grid_axis
396
+ embs = []
397
+ for i in range(len(rope_dim_list)):
398
+ # emb = get_1d_rotary_pos_embed(
399
+ # rope_dim_list[i],
400
+ # grid[i].reshape(-1),
401
+ # theta,
402
+ # use_real=use_real,
403
+ # theta_rescale_factor=theta_rescale_factor[i],
404
+ # interpolation_factor=interpolation_factor[i],
405
+ # ) # 2 x [WHD, rope_dim_list[i]]
406
+
407
+
408
+ # === RIFLEx modification start ===
409
+ # apply RIFLEx for time dimension
410
+ if i == 0 and enable_riflex:
411
+ emb = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, k=k, L_test=L_test)
412
+ # === RIFLEx modification end ===
413
+ else:
414
+ emb = get_1d_rotary_pos_embed(rope_dim_list[i], grid[i].reshape(-1), theta, use_real=True, theta_rescale_factor=theta_rescale_factor[i],interpolation_factor=interpolation_factor[i],)
415
+ embs.append(emb)
416
+
417
+ if use_real:
418
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
419
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
420
+ return cos, sin
421
+ else:
422
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
423
+ return emb
424
+
425
+
426
+ def get_1d_rotary_pos_embed(
427
+ dim: int,
428
+ pos: Union[torch.FloatTensor, int],
429
+ theta: float = 10000.0,
430
+ use_real: bool = False,
431
+ theta_rescale_factor: float = 1.0,
432
+ interpolation_factor: float = 1.0,
433
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
434
+ """
435
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
436
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
437
+
438
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
439
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
440
+ The returned tensor contains complex values in complex64 data type.
441
+
442
+ Args:
443
+ dim (int): Dimension of the frequency tensor.
444
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
445
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
446
+ use_real (bool, optional): If True, return real part and imaginary part separately.
447
+ Otherwise, return complex numbers.
448
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
449
+
450
+ Returns:
451
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
452
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
453
+ """
454
+ if isinstance(pos, int):
455
+ pos = torch.arange(pos).float()
456
+
457
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
458
+ # has some connection to NTK literature
459
+ if theta_rescale_factor != 1.0:
460
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
461
+
462
+ freqs = 1.0 / (
463
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
464
+ ) # [D/2]
465
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
466
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
467
+ if use_real:
468
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
469
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
470
+ return freqs_cos, freqs_sin
471
+ else:
472
+ freqs_cis = torch.polar(
473
+ torch.ones_like(freqs), freqs
474
+ ) # complex64 # [S, D/2]
475
+ return freqs_cis
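As a quick orientation for the RoPE helpers above, here is a minimal usage sketch; the shapes are made up, it assumes the functions in this file are in scope, and it exercises the real-valued (cos, sin) path that the RIFLEx branch uses:

    import torch

    B, S, H, D = 1, 8, 4, 64                                  # batch, tokens, heads, head_dim (illustrative)
    xq = torch.randn(B, S, H, D)
    xk = torch.randn(B, S, H, D)

    cos, sin = get_1d_rotary_pos_embed(D, S, use_real=True)   # each [S, D]
    xq_rot, xk_rot = apply_rotary_emb([xq, xk], (cos, sin), head_first=False)
    assert xq_rot.shape == (B, S, H, D) and xk_rot.shape == (B, S, H, D)

The complex branch (use_real=False) performs the same rotation: multiplying the pair (real, imag) by cos + i*sin yields (real*cos - imag*sin, imag*cos + real*sin), which is exactly what rotate_half combined with the cos/sin products computes.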
hyvideo/modules/token_refiner.py ADDED
@@ -0,0 +1,237 @@
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from .activation_layers import get_activation_layer
8
+ from .attenion import attention
9
+ from .norm_layers import get_norm_layer
10
+ from .embed_layers import TimestepEmbedder, TextProjection
11
+ from .attenion import attention
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
+ class IndividualTokenRefinerBlock(nn.Module):
17
+ def __init__(
18
+ self,
19
+ hidden_size,
20
+ heads_num,
21
+ mlp_width_ratio: float = 4.0,
22
+ mlp_drop_rate: float = 0.0,
23
+ act_type: str = "silu",
24
+ qk_norm: bool = False,
25
+ qk_norm_type: str = "layer",
26
+ qkv_bias: bool = True,
27
+ dtype: Optional[torch.dtype] = None,
28
+ device: Optional[torch.device] = None,
29
+ ):
30
+ factory_kwargs = {"device": device, "dtype": dtype}
31
+ super().__init__()
32
+ self.heads_num = heads_num
33
+ head_dim = hidden_size // heads_num
34
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
35
+
36
+ self.norm1 = nn.LayerNorm(
37
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
38
+ )
39
+ self.self_attn_qkv = nn.Linear(
40
+ hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
41
+ )
42
+ qk_norm_layer = get_norm_layer(qk_norm_type)
43
+ self.self_attn_q_norm = (
44
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
45
+ if qk_norm
46
+ else nn.Identity()
47
+ )
48
+ self.self_attn_k_norm = (
49
+ qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
50
+ if qk_norm
51
+ else nn.Identity()
52
+ )
53
+ self.self_attn_proj = nn.Linear(
54
+ hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
55
+ )
56
+
57
+ self.norm2 = nn.LayerNorm(
58
+ hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
59
+ )
60
+ act_layer = get_activation_layer(act_type)
61
+ self.mlp = MLP(
62
+ in_channels=hidden_size,
63
+ hidden_channels=mlp_hidden_dim,
64
+ act_layer=act_layer,
65
+ drop=mlp_drop_rate,
66
+ **factory_kwargs,
67
+ )
68
+
69
+ self.adaLN_modulation = nn.Sequential(
70
+ act_layer(),
71
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
72
+ )
73
+ # Zero-initialize the modulation
74
+ nn.init.zeros_(self.adaLN_modulation[1].weight)
75
+ nn.init.zeros_(self.adaLN_modulation[1].bias)
76
+
77
+ def forward(
78
+ self,
79
+ x: torch.Tensor,
80
+ c: torch.Tensor, # timestep_aware_representations + context_aware_representations
81
+ attn_mask: torch.Tensor = None,
82
+ ):
83
+ gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)
84
+
85
+ norm_x = self.norm1(x)
86
+ qkv = self.self_attn_qkv(norm_x)
87
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
88
+ # Apply QK-Norm if needed
89
+ q = self.self_attn_q_norm(q).to(v)
90
+ k = self.self_attn_k_norm(k).to(v)
91
+ qkv_list = [q, k, v]
92
+ del q,k
93
+ # Self-Attention
94
+ attn = attention( qkv_list, mode="torch", attn_mask=attn_mask)
95
+
96
+ x = x + apply_gate(self.self_attn_proj(attn), gate_msa)
97
+
98
+ # FFN Layer
99
+ x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)
100
+
101
+ return x
102
+
103
+
104
+ class IndividualTokenRefiner(nn.Module):
105
+ def __init__(
106
+ self,
107
+ hidden_size,
108
+ heads_num,
109
+ depth,
110
+ mlp_width_ratio: float = 4.0,
111
+ mlp_drop_rate: float = 0.0,
112
+ act_type: str = "silu",
113
+ qk_norm: bool = False,
114
+ qk_norm_type: str = "layer",
115
+ qkv_bias: bool = True,
116
+ dtype: Optional[torch.dtype] = None,
117
+ device: Optional[torch.device] = None,
118
+ ):
119
+ factory_kwargs = {"device": device, "dtype": dtype}
120
+ super().__init__()
121
+ self.blocks = nn.ModuleList(
122
+ [
123
+ IndividualTokenRefinerBlock(
124
+ hidden_size=hidden_size,
125
+ heads_num=heads_num,
126
+ mlp_width_ratio=mlp_width_ratio,
127
+ mlp_drop_rate=mlp_drop_rate,
128
+ act_type=act_type,
129
+ qk_norm=qk_norm,
130
+ qk_norm_type=qk_norm_type,
131
+ qkv_bias=qkv_bias,
132
+ **factory_kwargs,
133
+ )
134
+ for _ in range(depth)
135
+ ]
136
+ )
137
+
138
+ def forward(
139
+ self,
140
+ x: torch.Tensor,
141
+ c: torch.LongTensor,
142
+ mask: Optional[torch.Tensor] = None,
143
+ ):
144
+ self_attn_mask = None
145
+ if mask is not None:
146
+ batch_size = mask.shape[0]
147
+ seq_len = mask.shape[1]
148
+ mask = mask.to(x.device)
149
+ # batch_size x 1 x seq_len x seq_len
150
+ self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
151
+ 1, 1, seq_len, 1
152
+ )
153
+ # batch_size x 1 x seq_len x seq_len
154
+ self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
155
+ # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
156
+ self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
157
+ # avoids self-attention weight being NaN for padding tokens
158
+ self_attn_mask[:, :, :, 0] = True
159
+
160
+ for block in self.blocks:
161
+ x = block(x, c, self_attn_mask)
162
+ return x
163
+
164
+
165
+ class SingleTokenRefiner(nn.Module):
166
+ """
167
+ A single token refiner block for llm text embedding refine.
168
+ """
169
+ def __init__(
170
+ self,
171
+ in_channels,
172
+ hidden_size,
173
+ heads_num,
174
+ depth,
175
+ mlp_width_ratio: float = 4.0,
176
+ mlp_drop_rate: float = 0.0,
177
+ act_type: str = "silu",
178
+ qk_norm: bool = False,
179
+ qk_norm_type: str = "layer",
180
+ qkv_bias: bool = True,
181
+ attn_mode: str = "torch",
182
+ dtype: Optional[torch.dtype] = None,
183
+ device: Optional[torch.device] = None,
184
+ ):
185
+ factory_kwargs = {"device": device, "dtype": dtype}
186
+ super().__init__()
187
+ self.attn_mode = attn_mode
188
+ assert self.attn_mode == "torch", "Only the 'torch' attention mode is supported for the token refiner."
189
+
190
+ self.input_embedder = nn.Linear(
191
+ in_channels, hidden_size, bias=True, **factory_kwargs
192
+ )
193
+
194
+ act_layer = get_activation_layer(act_type)
195
+ # Build timestep embedding layer
196
+ self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
197
+ # Build context embedding layer
198
+ self.c_embedder = TextProjection(
199
+ in_channels, hidden_size, act_layer, **factory_kwargs
200
+ )
201
+
202
+ self.individual_token_refiner = IndividualTokenRefiner(
203
+ hidden_size=hidden_size,
204
+ heads_num=heads_num,
205
+ depth=depth,
206
+ mlp_width_ratio=mlp_width_ratio,
207
+ mlp_drop_rate=mlp_drop_rate,
208
+ act_type=act_type,
209
+ qk_norm=qk_norm,
210
+ qk_norm_type=qk_norm_type,
211
+ qkv_bias=qkv_bias,
212
+ **factory_kwargs,
213
+ )
214
+
215
+ def forward(
216
+ self,
217
+ x: torch.Tensor,
218
+ t: torch.LongTensor,
219
+ mask: Optional[torch.LongTensor] = None,
220
+ ):
221
+ timestep_aware_representations = self.t_embedder(t)
222
+
223
+ if mask is None:
224
+ context_aware_representations = x.mean(dim=1)
225
+ else:
226
+ mask_float = mask.float().unsqueeze(-1) # [b, s1, 1]
227
+ context_aware_representations = (x * mask_float).sum(
228
+ dim=1
229
+ ) / mask_float.sum(dim=1)
230
+ context_aware_representations = self.c_embedder(context_aware_representations.to(x.dtype))
231
+ c = timestep_aware_representations + context_aware_representations
232
+
233
+ x = self.input_embedder(x)
234
+
235
+ x = self.individual_token_refiner(x, c, mask)
236
+
237
+ return x
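A rough construction sketch for the refiner above; the widths, depth, and head count are placeholders rather than the real model configuration, and it assumes the sibling modules imported at the top of this file behave as expected:

    import torch

    refiner = SingleTokenRefiner(
        in_channels=4096,        # assumed LLM hidden size
        hidden_size=1024,
        heads_num=8,
        depth=2,
    )
    x = torch.randn(2, 77, 4096)                   # [batch, tokens, in_channels] text embeddings
    t = torch.zeros(2, dtype=torch.long)           # one timestep per sample
    mask = torch.ones(2, 77, dtype=torch.long)     # 1 = real token, 0 = padding
    out = refiner(x, t, mask)                      # -> [2, 77, 1024]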
hyvideo/modules/utils.py ADDED
@@ -0,0 +1,43 @@
1
+ """Mask Mod for Image2Video"""
2
+
3
+ from math import floor
4
+ import torch
5
+ from torch import Tensor
6
+
7
+
8
+ from functools import lru_cache
9
+ from typing import Optional, List
10
+
11
+ import torch
12
+ from torch.nn.attention.flex_attention import (
13
+ create_block_mask,
14
+ )
15
+
16
+
17
+ @lru_cache
18
+ def create_block_mask_cached(score_mod, B, H, M, N, device="cuda", _compile=False):
19
+ block_mask = create_block_mask(score_mod, B, H, M, N, device=device, _compile=_compile)
20
+ return block_mask
21
+
22
+ def generate_temporal_head_mask_mod(context_length: int = 226, prompt_length: int = 226, num_frames: int = 13, token_per_frame: int = 1350, mul: int = 2):
23
+
24
+ def round_to_multiple(idx):
25
+ return floor(idx / 128) * 128
26
+
27
+ real_length = num_frames * token_per_frame + prompt_length
28
+ def temporal_mask_mod(b, h, q_idx, kv_idx):
29
+ real_mask = (kv_idx < real_length) & (q_idx < real_length)
30
+ fake_mask = (kv_idx >= real_length) & (q_idx >= real_length)
31
+
32
+ two_frame = round_to_multiple(mul * token_per_frame)
33
+ temporal_head_mask = (torch.abs(q_idx - kv_idx) < two_frame)
34
+
35
+ text_column_mask = (num_frames * token_per_frame <= kv_idx) & (kv_idx < real_length)
36
+ text_row_mask = (num_frames * token_per_frame <= q_idx) & (q_idx < real_length)
37
+
38
+ video_mask = temporal_head_mask | text_column_mask | text_row_mask
39
+ real_mask = real_mask & video_mask
40
+
41
+ return real_mask | fake_mask
42
+
43
+ return temporal_mask_mod
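A hedged sketch of how this mask_mod could be consumed with FlexAttention (requires a recent PyTorch that ships torch.nn.attention.flex_attention and a CUDA device; the counts are the defaults above, and padding the sequence length to a multiple of 128 matches the fake_mask handling in the mask):

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    num_frames, token_per_frame, prompt_length = 13, 1350, 226
    real_length = num_frames * token_per_frame + prompt_length
    padded_len = ((real_length + 127) // 128) * 128            # block-aligned sequence length

    mask_mod = generate_temporal_head_mask_mod(
        prompt_length=prompt_length,
        num_frames=num_frames,
        token_per_frame=token_per_frame,
        mul=2,
    )
    block_mask = create_block_mask_cached(mask_mod, 1, 1, padded_len, padded_len, device="cuda")

    q = k = v = torch.randn(1, 8, padded_len, 64, device="cuda", dtype=torch.bfloat16)
    out = flex_attention(q, k, v, block_mask=block_mask)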
hyvideo/prompt_rewrite.py ADDED
@@ -0,0 +1,51 @@
1
+ normal_mode_prompt = """Normal mode - Video Recaption Task:
2
+
3
+ You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
4
+
5
+ 0. Preserve ALL information, including style words and technical terms.
6
+
7
+ 1. If the input is in Chinese, translate the entire description to English.
8
+
9
+ 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
10
+
11
+ 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
12
+
13
+ 4. Output ALL must be in English.
14
+
15
+ Given Input:
16
+ input: "{input}"
17
+ """
18
+
19
+
20
+ master_mode_prompt = """Master mode - Video Recaption Task:
21
+
22
+ You are a large language model specialized in rewriting video descriptions. Your task is to modify the input description.
23
+
24
+ 0. Preserve ALL information, including style words and technical terms.
25
+
26
+ 1. If the input is in Chinese, translate the entire description to English.
27
+
28
+ 2. If the input is just one or two words describing an object or person, provide a brief, simple description focusing on basic visual characteristics. Limit the description to 1-2 short sentences.
29
+
30
+ 3. If the input does not include style, lighting, atmosphere, you can make reasonable associations.
31
+
32
+ 4. Output ALL must be in English.
33
+
34
+ Given Input:
35
+ input: "{input}"
36
+ """
37
+
38
+ def get_rewrite_prompt(ori_prompt, mode="Normal"):
39
+ if mode == "Normal":
40
+ prompt = normal_mode_prompt.format(input=ori_prompt)
41
+ elif mode == "Master":
42
+ prompt = master_mode_prompt.format(input=ori_prompt)
43
+ else:
44
+ raise Exception("Only supports Normal and Normal", mode)
45
+ return prompt
46
+
47
+ ori_prompt = "一只小狗在草地上奔跑。"
48
+ normal_prompt = get_rewrite_prompt(ori_prompt, mode="Normal")
49
+ master_prompt = get_rewrite_prompt(ori_prompt, mode="Master")
50
+
51
+ # Then you can use the normal_prompt or master_prompt to access the hunyuan-large rewrite model to get the final prompt.
hyvideo/text_encoder/__init__.py ADDED
@@ -0,0 +1,552 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+ from copy import deepcopy
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import (
7
+ CLIPTextModel,
8
+ CLIPTokenizer,
9
+ AutoTokenizer,
10
+ AutoModel,
11
+ LlavaForConditionalGeneration,
12
+ CLIPImageProcessor,
13
+ )
14
+ from transformers.utils import ModelOutput
15
+
16
+ from ..constants import TEXT_ENCODER_PATH, TOKENIZER_PATH
17
+ from ..constants import PRECISION_TO_TYPE
18
+
19
+
20
+ def use_default(value, default):
21
+ return value if value is not None else default
22
+
23
+
24
+ def load_text_encoder(
25
+ text_encoder_type,
26
+ text_encoder_precision=None,
27
+ text_encoder_path=None,
28
+ device=None,
29
+ ):
30
+ if text_encoder_path is None:
31
+ text_encoder_path = TEXT_ENCODER_PATH[text_encoder_type]
32
+
33
+ if text_encoder_type == "clipL":
34
+ text_encoder = CLIPTextModel.from_pretrained(text_encoder_path)
35
+ text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
36
+ elif text_encoder_type == "llm":
37
+ text_encoder = AutoModel.from_pretrained(
38
+ text_encoder_path, low_cpu_mem_usage=True
39
+ )
40
+ text_encoder.final_layer_norm = text_encoder.norm
41
+ elif text_encoder_type == "llm-i2v":
42
+ text_encoder = LlavaForConditionalGeneration.from_pretrained(
43
+ text_encoder_path, low_cpu_mem_usage=True
44
+ )
45
+ else:
46
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
47
+ # from_pretrained will ensure that the model is in eval mode.
48
+
49
+ if text_encoder_precision is not None:
50
+ text_encoder = text_encoder.to(dtype=PRECISION_TO_TYPE[text_encoder_precision])
51
+
52
+ text_encoder.requires_grad_(False)
53
+
54
+ if device is not None:
55
+ text_encoder = text_encoder.to(device)
56
+
57
+ return text_encoder, text_encoder_path
58
+
59
+
60
+ def load_tokenizer(
61
+ tokenizer_type, tokenizer_path=None, padding_side="right"
62
+ ):
63
+ if tokenizer_path is None:
64
+ tokenizer_path = TOKENIZER_PATH[tokenizer_type]
65
+
66
+ processor = None
67
+ if tokenizer_type == "clipL":
68
+ tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
69
+ elif tokenizer_type == "llm":
70
+ tokenizer = AutoTokenizer.from_pretrained(
71
+ tokenizer_path, padding_side=padding_side
72
+ )
73
+ elif tokenizer_type == "llm-i2v":
74
+ tokenizer = AutoTokenizer.from_pretrained(
75
+ tokenizer_path, padding_side=padding_side
76
+ )
77
+ processor = CLIPImageProcessor.from_pretrained(tokenizer_path)
78
+ else:
79
+ raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
80
+
81
+ return tokenizer, tokenizer_path, processor
82
+
83
+
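A hedged sketch of how these two loaders pair up; "clipL" is one of the branches above, and it assumes the checkpoint referenced by TEXT_ENCODER_PATH / TOKENIZER_PATH is available locally:

    text_encoder, enc_path = load_text_encoder("clipL", device="cuda")
    tokenizer, tok_path, processor = load_tokenizer("clipL")          # processor is None for clipL

    batch = tokenizer(
        "a dog running on the grass",
        padding="max_length", truncation=True, max_length=77, return_tensors="pt",
    )
    out = text_encoder(
        input_ids=batch["input_ids"].to("cuda"),
        attention_mask=batch["attention_mask"].to("cuda"),
    )
    pooled = out.pooler_output                                        # [1, hidden_size]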
84
+ @dataclass
85
+ class TextEncoderModelOutput(ModelOutput):
86
+ """
87
+ Base class for model's outputs that also contains a pooling of the last hidden states.
88
+
89
+ Args:
90
+ hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
91
+ Sequence of hidden-states at the output of the last layer of the model.
92
+ attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
93
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
94
+ hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
95
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
96
+ one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
97
+ Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
98
+ text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
99
+ List of decoded texts.
100
+ """
101
+
102
+ hidden_state: torch.FloatTensor = None
103
+ attention_mask: Optional[torch.LongTensor] = None
104
+ hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
105
+ text_outputs: Optional[list] = None
106
+
107
+
108
+ class TextEncoder(nn.Module):
109
+ def __init__(
110
+ self,
111
+ text_encoder_type: str,
112
+ max_length: int,
113
+ text_encoder_precision: Optional[str] = None,
114
+ text_encoder_path: Optional[str] = None,
115
+ tokenizer_type: Optional[str] = None,
116
+ tokenizer_path: Optional[str] = None,
117
+ output_key: Optional[str] = None,
118
+ use_attention_mask: bool = True,
119
+ i2v_mode: bool = False,
120
+ input_max_length: Optional[int] = None,
121
+ prompt_template: Optional[dict] = None,
122
+ prompt_template_video: Optional[dict] = None,
123
+ hidden_state_skip_layer: Optional[int] = None,
124
+ apply_final_norm: bool = False,
125
+ reproduce: bool = False,
126
+ device=None,
127
+ # image_embed_interleave (int): The number of times to interleave the image and text embeddings. Defaults to 2.
128
+ image_embed_interleave=2,
129
+ ):
130
+ super().__init__()
131
+ self.text_encoder_type = text_encoder_type
132
+ self.max_length = max_length
133
+ self.precision = text_encoder_precision
134
+ self.model_path = text_encoder_path
135
+ self.tokenizer_type = (
136
+ tokenizer_type if tokenizer_type is not None else text_encoder_type
137
+ )
138
+ self.tokenizer_path = (
139
+ tokenizer_path if tokenizer_path is not None else None # text_encoder_path
140
+ )
141
+ self.use_attention_mask = use_attention_mask
142
+ if prompt_template_video is not None:
143
+ assert (
144
+ use_attention_mask is True
145
+ ), "Attention mask is True required when training videos."
146
+ self.input_max_length = (
147
+ input_max_length if input_max_length is not None else max_length
148
+ )
149
+ self.prompt_template = prompt_template
150
+ self.prompt_template_video = prompt_template_video
151
+ self.hidden_state_skip_layer = hidden_state_skip_layer
152
+ self.apply_final_norm = apply_final_norm
153
+ self.i2v_mode = i2v_mode
154
+ self.reproduce = reproduce
155
+ self.image_embed_interleave = image_embed_interleave
156
+
157
+ self.use_template = self.prompt_template is not None
158
+ if self.use_template:
159
+ assert (
160
+ isinstance(self.prompt_template, dict)
161
+ and "template" in self.prompt_template
162
+ ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
163
+ assert "{}" in str(self.prompt_template["template"]), (
164
+ "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
165
+ f"got {self.prompt_template['template']}"
166
+ )
167
+
168
+ self.use_video_template = self.prompt_template_video is not None
169
+ if self.use_video_template:
170
+ if self.prompt_template_video is not None:
171
+ assert (
172
+ isinstance(self.prompt_template_video, dict)
173
+ and "template" in self.prompt_template_video
174
+ ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
175
+ assert "{}" in str(self.prompt_template_video["template"]), (
176
+ "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
177
+ f"got {self.prompt_template_video['template']}"
178
+ )
179
+
180
+ if "t5" in text_encoder_type:
181
+ self.output_key = output_key or "last_hidden_state"
182
+ elif "clip" in text_encoder_type:
183
+ self.output_key = output_key or "pooler_output"
184
+ elif "llm" in text_encoder_type or "glm" in text_encoder_type:
185
+ self.output_key = output_key or "last_hidden_state"
186
+ else:
187
+ raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
188
+
189
+ if "llm" in text_encoder_type:
190
+ from mmgp import offload
191
+ forcedConfigPath= None if "i2v" in text_encoder_type else "ckpts/llava-llama-3-8b/config.json"
192
+ self.model= offload.fast_load_transformers_model(self.model_path, modelPrefix="language_model" if forcedConfigPath != None else None, forcedConfigPath=forcedConfigPath)
193
+ if forcedConfigPath != None:
194
+ self.model.final_layer_norm = self.model.model.norm
195
+
196
+ else:
197
+ self.model, self.model_path = load_text_encoder(
198
+ text_encoder_type=self.text_encoder_type,
199
+ text_encoder_precision=self.precision,
200
+ text_encoder_path=self.model_path,
201
+ device=device,
202
+ )
203
+
204
+ self.dtype = self.model.dtype
205
+ self.device = self.model.device
206
+
207
+ self.tokenizer, self.tokenizer_path, self.processor = load_tokenizer(
208
+ tokenizer_type=self.tokenizer_type,
209
+ tokenizer_path=self.tokenizer_path,
210
+ padding_side="right",
211
+ )
212
+
213
+ def __repr__(self):
214
+ return f"{self.text_encoder_type} ({self.precision} - {self.model_path})"
215
+
216
+ @staticmethod
217
+ def apply_text_to_template(text, template, prevent_empty_text=True):
218
+ """
219
+ Apply text to template.
220
+
221
+ Args:
222
+ text (str): Input text.
223
+ template (str or list): Template string or list of chat conversation.
224
+ prevent_empty_text (bool): If True, we will prevent the user text from being empty
225
+ by adding a space. Defaults to True.
226
+ """
227
+ if isinstance(template, str):
228
+ # Will send string to tokenizer. Used for llm
229
+ return template.format(text)
230
+ else:
231
+ raise TypeError(f"Unsupported template type: {type(template)}")
232
+
233
+ def text2tokens(self, text, data_type="image", name = None):
234
+ """
235
+ Tokenize the input text.
236
+
237
+ Args:
238
+ text (str or list): Input text.
239
+ """
240
+ tokenize_input_type = "str"
241
+ if self.use_template:
242
+ if data_type == "image":
243
+ prompt_template = self.prompt_template["template"]
244
+ elif data_type == "video":
245
+ prompt_template = self.prompt_template_video["template"]
246
+ else:
247
+ raise ValueError(f"Unsupported data type: {data_type}")
248
+ if isinstance(text, (list, tuple)):
249
+ text = [
250
+ self.apply_text_to_template(one_text, prompt_template)
251
+ for one_text in text
252
+ ]
253
+ if isinstance(text[0], list):
254
+ tokenize_input_type = "list"
255
+ elif isinstance(text, str):
256
+ text = self.apply_text_to_template(text, prompt_template)
257
+ if isinstance(text, list):
258
+ tokenize_input_type = "list"
259
+ else:
260
+ raise TypeError(f"Unsupported text type: {type(text)}")
261
+
262
+ kwargs = dict(truncation=True, max_length=self.max_length, padding="max_length", return_tensors="pt")
263
+ if self.text_encoder_type == "llm-i2v" and name != None: #llava-llama-3-8b
264
+ if isinstance(text, list):
265
+ for i in range(len(text)):
266
+ text[i] = text[i] + '\nThe %s looks like<image>' % name
267
+ elif isinstance(text, str):
268
+ text = text + '\nThe %s looks like<image>' % name
269
+ else:
270
+ raise NotImplementedError
271
+
272
+ kwargs = dict(
273
+ truncation=True,
274
+ max_length=self.max_length,
275
+ padding="max_length",
276
+ return_tensors="pt",
277
+ )
278
+ if tokenize_input_type == "str":
279
+ return self.tokenizer(
280
+ text,
281
+ return_length=False,
282
+ return_overflowing_tokens=False,
283
+ return_attention_mask=True,
284
+ **kwargs,
285
+ )
286
+ elif tokenize_input_type == "list":
287
+ return self.tokenizer.apply_chat_template(
288
+ text,
289
+ add_generation_prompt=True,
290
+ tokenize=True,
291
+ return_dict=True,
292
+ **kwargs,
293
+ )
294
+ else:
295
+ raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")
296
+
297
+ def encode(
298
+ self,
299
+ batch_encoding,
300
+ use_attention_mask=None,
301
+ output_hidden_states=False,
302
+ do_sample=None,
303
+ hidden_state_skip_layer=None,
304
+ return_texts=False,
305
+ data_type="image",
306
+ semantic_images=None,
307
+ device=None,
308
+ ):
309
+ """
310
+ Args:
311
+ batch_encoding (dict): Batch encoding from tokenizer.
312
+ use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
313
+ Defaults to None.
314
+ output_hidden_states (bool): Whether to output hidden states. If False, return the value of
315
+ self.output_key. If True, return the entire output. If self.hidden_state_skip_layer is set,
316
+ output_hidden_states will be set True. Defaults to False.
317
+ do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
318
+ When self.reproduce is False, do_sample is set to True by default.
319
+ hidden_state_skip_layer (int): Number of layers to skip from the end when selecting the hidden state. 0 means the last layer.
320
+ If None, self.output_key will be used. Defaults to None.
321
+ semantic_images (PIL.Image): The reference images for i2v models. Defaults to None.
322
+ image_embed_interleave (int): Subsampling stride applied to the image embeddings before they are concatenated with the text embeddings. Defaults to 2.
323
+ return_texts (bool): Whether to return the decoded texts. Defaults to False.
324
+ """
325
+ device = self.model.device if device is None else device
326
+ use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
327
+ hidden_state_skip_layer = use_default(
328
+ hidden_state_skip_layer, self.hidden_state_skip_layer
329
+ )
330
+ do_sample = use_default(do_sample, not self.reproduce)
331
+ if not self.i2v_mode:
332
+ attention_mask = (
333
+ batch_encoding["attention_mask"].to(device)
334
+ if use_attention_mask
335
+ else None
336
+ )
337
+
338
+ if 'pixel_value_llava' in batch_encoding:
339
+ outputs = self.model(
340
+ input_ids=batch_encoding["input_ids"].to(self.model.device),
341
+ attention_mask=attention_mask,
342
+ pixel_values=batch_encoding["pixel_value_llava"].to(self.model.device),
343
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None)
344
+ else:
345
+ outputs = self.model(
346
+ input_ids=batch_encoding["input_ids"].to(self.model.device),
347
+ attention_mask=attention_mask,
348
+ output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,)
349
+
350
+ if hidden_state_skip_layer is not None:
351
+ last_hidden_state = outputs.hidden_states[
352
+ -(hidden_state_skip_layer + 1)
353
+ ]
354
+ # Real last hidden state already has layer norm applied. So here we only apply it
355
+ # for intermediate layers.
356
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
357
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
358
+ else:
359
+ last_hidden_state = outputs[self.output_key]
360
+
361
+ # Remove hidden states of instruction tokens, only keep prompt tokens.
362
+ if self.use_template:
363
+ if data_type == "image":
364
+ crop_start = self.prompt_template.get("crop_start", -1)
365
+ elif data_type == "video":
366
+ crop_start = self.prompt_template_video.get("crop_start", -1)
367
+ else:
368
+ raise ValueError(f"Unsupported data type: {data_type}")
369
+ if crop_start > 0:
370
+ last_hidden_state = last_hidden_state[:, crop_start:]
371
+ attention_mask = (
372
+ attention_mask[:, crop_start:] if use_attention_mask else None
373
+ )
374
+
375
+ if output_hidden_states:
376
+ return TextEncoderModelOutput(
377
+ last_hidden_state, attention_mask, outputs.hidden_states
378
+ )
379
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
380
+ else:
381
+ image_outputs = self.processor(semantic_images, return_tensors="pt")[
382
+ "pixel_values"
383
+ ].to(device)
384
+ attention_mask = (
385
+ batch_encoding["attention_mask"].to(device)
386
+ if use_attention_mask
387
+ else None
388
+ )
389
+ outputs = self.model(
390
+ input_ids=batch_encoding["input_ids"].to(device),
391
+ attention_mask=attention_mask,
392
+ output_hidden_states=output_hidden_states
393
+ or hidden_state_skip_layer is not None,
394
+ pixel_values=image_outputs,
395
+ )
396
+ if hidden_state_skip_layer is not None:
397
+ last_hidden_state = outputs.hidden_states[
398
+ -(hidden_state_skip_layer + 1)
399
+ ]
400
+ # Real last hidden state already has layer norm applied. So here we only apply it
401
+ # for intermediate layers.
402
+ if hidden_state_skip_layer > 0 and self.apply_final_norm:
403
+ last_hidden_state = self.model.final_layer_norm(last_hidden_state)
404
+ else:
405
+ last_hidden_state = outputs[self.output_key]
406
+ if self.use_template:
407
+ if data_type == "video":
408
+ crop_start = self.prompt_template_video.get("crop_start", -1)
409
+ text_crop_start = (
410
+ crop_start
411
+ - 1
412
+ + self.prompt_template_video.get("image_emb_len", 576)
413
+ )
414
+ image_crop_start = self.prompt_template_video.get(
415
+ "image_emb_start", 5
416
+ )
417
+ image_crop_end = self.prompt_template_video.get(
418
+ "image_emb_end", 581
419
+ )
420
+ batch_indices, last_double_return_token_indices = torch.where(
421
+ batch_encoding["input_ids"]
422
+ == self.prompt_template_video.get("double_return_token_id", 271)
423
+ )
424
+ if last_double_return_token_indices.shape[0] == 3:
425
+ # in case the prompt is too long
426
+ last_double_return_token_indices = torch.cat(
427
+ (
428
+ last_double_return_token_indices,
429
+ torch.tensor([batch_encoding["input_ids"].shape[-1]]),
430
+ )
431
+ )
432
+ batch_indices = torch.cat((batch_indices, torch.tensor([0])))
433
+ last_double_return_token_indices = (
434
+ last_double_return_token_indices.reshape(
435
+ batch_encoding["input_ids"].shape[0], -1
436
+ )[:, -1]
437
+ )
438
+ batch_indices = batch_indices.reshape(
439
+ batch_encoding["input_ids"].shape[0], -1
440
+ )[:, -1]
441
+ assistant_crop_start = (
442
+ last_double_return_token_indices
443
+ - 1
444
+ + self.prompt_template_video.get("image_emb_len", 576)
445
+ - 4
446
+ )
447
+ assistant_crop_end = (
448
+ last_double_return_token_indices
449
+ - 1
450
+ + self.prompt_template_video.get("image_emb_len", 576)
451
+ )
452
+ attention_mask_assistant_crop_start = (
453
+ last_double_return_token_indices - 4
454
+ )
455
+ attention_mask_assistant_crop_end = last_double_return_token_indices
456
+ else:
457
+ raise ValueError(f"Unsupported data type: {data_type}")
458
+ text_last_hidden_state = []
459
+
460
+ text_attention_mask = []
461
+ image_last_hidden_state = []
462
+ image_attention_mask = []
463
+ for i in range(batch_encoding["input_ids"].shape[0]):
464
+ text_last_hidden_state.append(
465
+ torch.cat(
466
+ [
467
+ last_hidden_state[
468
+ i, text_crop_start : assistant_crop_start[i].item()
469
+ ],
470
+ last_hidden_state[i, assistant_crop_end[i].item() :],
471
+ ]
472
+ )
473
+ )
474
+ text_attention_mask.append(
475
+ torch.cat(
476
+ [
477
+ attention_mask[
478
+ i,
479
+ crop_start : attention_mask_assistant_crop_start[
480
+ i
481
+ ].item(),
482
+ ],
483
+ attention_mask[
484
+ i, attention_mask_assistant_crop_end[i].item() :
485
+ ],
486
+ ]
487
+ )
488
+ if use_attention_mask
489
+ else None
490
+ )
491
+ image_last_hidden_state.append(
492
+ last_hidden_state[i, image_crop_start:image_crop_end]
493
+ )
494
+ image_attention_mask.append(
495
+ torch.ones(image_last_hidden_state[-1].shape[0])
496
+ .to(last_hidden_state.device)
497
+ .to(attention_mask.dtype)
498
+ if use_attention_mask
499
+ else None
500
+ )
501
+
502
+ text_last_hidden_state = torch.stack(text_last_hidden_state)
503
+ text_attention_mask = torch.stack(text_attention_mask)
504
+ image_last_hidden_state = torch.stack(image_last_hidden_state)
505
+ image_attention_mask = torch.stack(image_attention_mask)
506
+
507
+ if semantic_images is not None and 0 < self.image_embed_interleave < 6:
508
+ image_last_hidden_state = image_last_hidden_state[
509
+ :, ::self.image_embed_interleave, :
510
+ ]
511
+ image_attention_mask = image_attention_mask[
512
+ :, ::self.image_embed_interleave
513
+ ]
514
+
515
+ assert (
516
+ text_last_hidden_state.shape[0] == text_attention_mask.shape[0]
517
+ and image_last_hidden_state.shape[0]
518
+ == image_attention_mask.shape[0]
519
+ )
520
+
521
+ last_hidden_state = torch.cat(
522
+ [image_last_hidden_state, text_last_hidden_state], dim=1
523
+ )
524
+ attention_mask = torch.cat(
525
+ [image_attention_mask, text_attention_mask], dim=1
526
+ )
527
+ if output_hidden_states:
528
+ return TextEncoderModelOutput(
529
+ last_hidden_state,
530
+ attention_mask,
531
+ hidden_states_list=outputs.hidden_states,
532
+ )
533
+ return TextEncoderModelOutput(last_hidden_state, attention_mask)
534
+
535
+ def forward(
536
+ self,
537
+ text,
538
+ use_attention_mask=None,
539
+ output_hidden_states=False,
540
+ do_sample=False,
541
+ hidden_state_skip_layer=None,
542
+ return_texts=False,
543
+ ):
544
+ batch_encoding = self.text2tokens(text)
545
+ return self.encode(
546
+ batch_encoding,
547
+ use_attention_mask=use_attention_mask,
548
+ output_hidden_states=output_hidden_states,
549
+ do_sample=do_sample,
550
+ hidden_state_skip_layer=hidden_state_skip_layer,
551
+ return_texts=return_texts,
552
+ )
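And a sketch of the wrapper class end to end, again with placeholder settings; with "clipL" the output_key resolves to the pooled CLIP embedding, and the configured checkpoint must already exist on disk:

    encoder = TextEncoder(text_encoder_type="clipL", max_length=77, device="cuda")
    tokens = encoder.text2tokens("a dog running on the grass")
    result = encoder.encode(tokens)
    print(result.hidden_state.shape, result.attention_mask.shape)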
hyvideo/utils/__init__.py ADDED
File without changes
hyvideo/utils/data_utils.py ADDED
@@ -0,0 +1,90 @@
1
+ import numpy as np
2
+ import math
3
+ from PIL import Image
4
+ import torch
5
+ import copy
6
+ import string
7
+ import random
8
+
9
+
10
+ def align_to(value, alignment):
11
+ """align hight, width according to alignment
12
+
13
+ Args:
14
+ value (int): height or width
15
+ alignment (int): target alignment factor
16
+
17
+ Returns:
18
+ int: the aligned value
19
+ """
20
+ return int(math.ceil(value / alignment) * alignment)
21
+
22
+
23
+ def black_image(width, height):
24
+ """generate a black image
25
+
26
+ Args:
27
+ width (int): image width
28
+ height (int): image height
29
+
30
+ Returns:
31
+ PIL.Image.Image: a black image
32
+ """
33
+ black_image = Image.new("RGB", (width, height), (0, 0, 0))
34
+ return black_image
35
+
36
+
37
+ def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
38
+ """get the closest ratio in the buckets
39
+
40
+ Args:
41
+ height (float): video height
42
+ width (float): video width
43
+ ratios (list or np.ndarray): aspect ratios (h/w) of the buckets
44
+ buckets (list): buckets generated by `generate_crop_size_list`
45
+
46
+ Returns:
47
+ the closest bucket size and the corresponding aspect ratio
48
+ """
49
+ aspect_ratio = float(height) / float(width)
50
+ closest_ratio_id = np.abs(ratios - aspect_ratio).argmin()
51
+ closest_ratio = min(ratios, key=lambda ratio: abs(float(ratio) - aspect_ratio))
52
+ return buckets[closest_ratio_id], float(closest_ratio)
53
+
54
+
55
+ def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
56
+ """generate crop size list
57
+
58
+ Args:
59
+ base_size (int, optional): the base size used to generate the buckets. Defaults to 256.
60
+ patch_size (int, optional): the stride used to generate the buckets. Defaults to 32.
61
+ max_ratio (float, optional): the maximum allowed aspect ratio between width and height. Defaults to 4.0.
62
+
63
+ Returns:
64
+ list: the generated list of (width, height) crop sizes
65
+ """
66
+ num_patches = round((base_size / patch_size) ** 2)
67
+ assert max_ratio >= 1.0
68
+ crop_size_list = []
69
+ wp, hp = num_patches, 1
70
+ while wp > 0:
71
+ if max(wp, hp) / min(wp, hp) <= max_ratio:
72
+ crop_size_list.append((wp * patch_size, hp * patch_size))
73
+ if (hp + 1) * wp <= num_patches:
74
+ hp += 1
75
+ else:
76
+ wp -= 1
77
+ return crop_size_list
78
+
79
+
80
+ def align_floor_to(value, alignment):
81
+ """align hight, width according to alignment
82
+
83
+ Args:
84
+ value (int): height or width
85
+ alignment (int): target alignment factor
86
+
87
+ Returns:
88
+ int: the aligned value
89
+ """
90
+ return int(math.floor(value / alignment) * alignment)
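A small sketch tying these helpers together; since get_closest_ratio computes np.abs(ratios - aspect_ratio), the ratios are passed as a NumPy array, and the (width, height) ordering of the buckets is inferred from the wp/hp variable names above:

    import numpy as np

    buckets = generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0)
    ratios = np.array([h / w for w, h in buckets])     # aspect ratio h/w per (w, h) bucket
    size, ratio = get_closest_ratio(height=720.0, width=1280.0, ratios=ratios, buckets=buckets)

    print(size, ratio)              # bucket whose h/w is closest to 720/1280
    print(align_to(719, 16))        # -> 720
    print(align_floor_to(719, 16))  # -> 704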
hyvideo/utils/file_utils.py ADDED
@@ -0,0 +1,70 @@
1
+ import os
2
+ from pathlib import Path
3
+ from einops import rearrange
4
+
5
+ import torch
6
+ import torchvision
7
+ import numpy as np
8
+ import imageio
9
+
10
+ CODE_SUFFIXES = {
11
+ ".py", # Python codes
12
+ ".sh", # Shell scripts
13
+ ".yaml",
14
+ ".yml", # Configuration files
15
+ }
16
+
17
+
18
+ def safe_dir(path):
19
+ """
20
+ Create a directory (or the parent directory of a file) if it does not exist.
21
+
22
+ Args:
23
+ path (str or Path): Path to the directory.
24
+
25
+ Returns:
26
+ path (Path): Path object of the directory.
27
+ """
28
+ path = Path(path)
29
+ path.mkdir(exist_ok=True, parents=True)
30
+ return path
31
+
32
+
33
+ def safe_file(path):
34
+ """
35
+ Create the parent directory of a file if it does not exist.
36
+
37
+ Args:
38
+ path (str or Path): Path to the file.
39
+
40
+ Returns:
41
+ path (Path): Path object of the file.
42
+ """
43
+ path = Path(path)
44
+ path.parent.mkdir(exist_ok=True, parents=True)
45
+ return path
46
+
47
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
48
+ """save videos by video tensor
49
+ copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61
50
+
51
+ Args:
52
+ videos (torch.Tensor): video tensor predicted by the model
53
+ path (str): path to save video
54
+ rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
55
+ n_rows (int, optional): number of rows in the saved grid. Defaults to 1.
56
+ fps (int, optional): video save fps. Defaults to 24.
57
+ """
58
+ videos = rearrange(videos, "b c t h w -> t b c h w")
59
+ outputs = []
60
+ for x in videos:
61
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
62
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
63
+ if rescale:
64
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
65
+ x = torch.clamp(x, 0, 1)
66
+ x = (x * 255).numpy().astype(np.uint8)
67
+ outputs.append(x)
68
+
69
+ os.makedirs(os.path.dirname(path), exist_ok=True)
70
+ imageio.mimsave(path, outputs, fps=fps)
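Finally, a sketch for the grid writer above; a random tensor stands in for model output, and writing .mp4 through imageio typically needs the imageio-ffmpeg backend installed:

    import torch

    videos = torch.rand(1, 3, 16, 256, 256)    # [batch, channels, frames, height, width], values in [0, 1]
    save_videos_grid(videos, "outputs/demo.mp4", rescale=False, n_rows=1, fps=24)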