NeverMore0123 committed on
Commit
02e9885
·
1 Parent(s): 25c2765

copy from personal repo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. LICENSE +201 -0
  2. README.md +115 -0
  3. RELEASE.md +7 -0
  4. ar_config_base_model.py +118 -0
  5. ar_config_base_model_config.py +421 -0
  6. ar_config_base_tokenizer.py +137 -0
  7. ar_config_inference_inference_config.py +102 -0
  8. ar_diffusion_decoder_config_base_conditioner.py +61 -0
  9. ar_diffusion_decoder_config_config_latent_diffusion_decoder.py +62 -0
  10. ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py +85 -0
  11. ar_diffusion_decoder_config_registry.py +118 -0
  12. ar_diffusion_decoder_inference.py +120 -0
  13. ar_diffusion_decoder_model.py +231 -0
  14. ar_diffusion_decoder_network.py +163 -0
  15. ar_diffusion_decoder_utils.py +119 -0
  16. ar_model.py +596 -0
  17. ar_module_attention.py +262 -0
  18. ar_module_embedding.py +491 -0
  19. ar_module_mlp.py +50 -0
  20. ar_module_mm_projector.py +109 -0
  21. ar_module_normalization.py +88 -0
  22. ar_network_transformer.py +461 -0
  23. ar_network_vit.py +410 -0
  24. ar_tokenizer_discrete_video.py +360 -0
  25. ar_tokenizer_image_text_tokenizer.py +318 -0
  26. ar_tokenizer_modules.py +560 -0
  27. ar_tokenizer_networks.py +63 -0
  28. ar_tokenizer_patching.py +279 -0
  29. ar_tokenizer_quantizers.py +165 -0
  30. ar_tokenizer_text_tokenizer.py +317 -0
  31. ar_tokenizer_tokenizer.py +322 -0
  32. ar_tokenizer_utils.py +101 -0
  33. ar_utils_checkpoint.py +76 -0
  34. ar_utils_inference.py +360 -0
  35. ar_utils_misc.py +52 -0
  36. ar_utils_sampling.py +195 -0
  37. assets/cosmos-logo.png +0 -0
  38. assets/diffusion_decoder_image_output.mp4 +0 -0
  39. assets/diffusion_decoder_video_output.mp4 +0 -0
  40. assets/image_output.mp4 +0 -0
  41. assets/video_output.mp4 +0 -0
  42. base.py +116 -0
  43. base_world_generation_pipeline.py +358 -0
  44. config.json +10 -0
  45. config.py +165 -0
  46. config_helper.py +197 -0
  47. convert_pixtral_ckpt.py +209 -0
  48. cosmos1/models/POST_TRAINING.md +23 -0
  49. cosmos1/models/autoregressive/README.md +427 -0
  50. cosmos1/models/autoregressive/__init__.py +14 -0
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README.md ADDED
@@ -0,0 +1,115 @@
+ ## How to Use
+ ### Example outputs can be found in the `assets` folder
+
+ ```python
+ from transformers import AutoModel
+
+ model = AutoModel.from_pretrained(
+     "Nvidia-CMU25/ARVideo2WorldGeneration",
+     cache_dir="./cache",
+     trust_remote_code=True,
+
+     input_type="text_and_image",
+     num_input_frames=1,
+     prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.",
+     input_image_or_video_path="AutoregressiveVideo2WorldGeneration/cosmos1/models/autoregressive/assets/v1p0/input.jpg",
+     video_save_name="diffusion_decoder_image_output",
+     ar_model_dir="Cosmos-1.0-Autoregressive-5B-Video2World",
+
+     # input_type="text_and_video",
+     # num_input_frames=9,
+     # prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.",
+     # input_image_or_video_path="AutoregressiveVideo2WorldGeneration/cosmos1/models/autoregressive/assets/v1p0/input.mp4",
+     # video_save_name="diffusion_decoder_video_output",
+
+     # Turn on offloading on a machine with low GPU memory:
+     disable_diffusion_decoder=False,
+     offload_guardrail_models=True,
+     offload_diffusion_decoder=True,
+     offload_network=True,
+     offload_tokenizer=True,
+     offload_text_encoder_model=True,
+ )
+
+ model()
+ ```
+
+
+ ![Cosmos Logo](assets/cosmos-logo.png)
+
+ --------------------------------------------------------------------------------
+ ### [Website](https://www.nvidia.com/en-us/ai/cosmos/) | [HuggingFace](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) | [GPU-free Preview](https://build.nvidia.com/explore/discover) | [Paper](https://arxiv.org/abs/2501.03575) | [Paper Website](https://research.nvidia.com/labs/dir/cosmos1/)
+
+ [NVIDIA Cosmos](https://www.nvidia.com/cosmos/) is a developer-first world foundation model platform designed to help Physical AI developers build their Physical AI systems better and faster. Cosmos contains:
+
+ 1. pre-trained models, available via [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6) under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/), which allows commercial use of the models for free
+ 2. training scripts under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0), offered through the [NVIDIA NeMo Framework](https://github.com/NVIDIA/NeMo) for post-training the models for various downstream Physical AI applications
+
+ Details of the platform are described in the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai). Preview access is available at [build.nvidia.com](https://build.nvidia.com).
+
+ ## Key Features
+
+ - [Pre-trained Diffusion-based world foundation models](cosmos1/models/diffusion/README.md) for Text2World and Video2World generation, where a user can generate visual simulation based on text prompts and video prompts.
+ - [Pre-trained Autoregressive-based world foundation models](cosmos1/models/autoregressive/README.md) for Video2World generation, where a user can generate visual simulation based on video prompts and optional text prompts.
+ - [Video tokenizers](https://github.com/NVIDIA/Cosmos-Tokenizer) for tokenizing videos into continuous tokens (latent vectors) and discrete tokens (integers) efficiently and effectively.
+ - Video curation pipeline for building your own video dataset. [Coming soon]
+ - [Post-training scripts](cosmos1/models/POST_TRAINING.md) via the NeMo Framework to post-train the pre-trained world foundation models for various Physical AI setups.
+ - Pre-training scripts via the NeMo Framework for building your own world foundation model. [[Diffusion](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion)] [[Autoregressive](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/multimodal_autoregressive)] [[Tokenizer](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/diffusion/vae)]
+
+ ## Model Family
+
+ | Model name | Description | Try it out |
+ |------------|----------|----------|
+ | [Cosmos-1.0-Diffusion-7B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+ | [Cosmos-1.0-Diffusion-14B-Text2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Text2World) | Text to visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+ | [Cosmos-1.0-Diffusion-7B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-7B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+ | [Cosmos-1.0-Diffusion-14B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Diffusion-14B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/diffusion/README.md) |
+ | [Cosmos-1.0-Autoregressive-4B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-4B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+ | [Cosmos-1.0-Autoregressive-12B](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-12B) | Future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+ | [Cosmos-1.0-Autoregressive-5B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-5B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+ | [Cosmos-1.0-Autoregressive-13B-Video2World](https://huggingface.co/nvidia/Cosmos-1.0-Autoregressive-13B-Video2World) | Video + Text based future visual world generation | [Inference](cosmos1/models/autoregressive/README.md) |
+ | [Cosmos-1.0-Guardrail](https://huggingface.co/nvidia/Cosmos-1.0-Guardrail) | Guardrail contains pre-Guard and post-Guard for safe use | Embedded in model inference scripts |
+
+ ## Example Usage
+
+ ### Inference
+
+ Follow the [Cosmos Installation Guide](INSTALL.md) to set up the Docker environment. For inference with the pretrained models, please refer to [Cosmos Diffusion Inference](cosmos1/models/diffusion/README.md) and [Cosmos Autoregressive Inference](cosmos1/models/autoregressive/README.md).
+
+ The code snippet below provides a gist of the inference usage.
+
+ ```bash
+ PROMPT="A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. \
+ The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. \
+ A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, \
+ suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. \
+ The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of \
+ field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
+
+ # Example using the 7B model
+ PYTHONPATH=$(pwd) python cosmos1/models/diffusion/inference/text2world.py \
+     --checkpoint_dir checkpoints \
+     --diffusion_transformer_dir Cosmos-1.0-Diffusion-7B-Text2World \
+     --prompt "$PROMPT" \
+     --offload_prompt_upsampler \
+     --video_save_name Cosmos-1.0-Diffusion-7B-Text2World
+ ```
+
+ <video src="https://github.com/user-attachments/assets/db7bebfe-5314-40a6-b045-4f6ce0a87f2a">
+ Your browser does not support the video tag.
+ </video>
+
+ We also offer [multi-GPU inference](cosmos1/models/diffusion/nemo/inference/README.md) support for Diffusion Text2World WFM models through the NeMo Framework.
+
+ ### Post-training
+
+ The NeMo Framework provides GPU-accelerated, general post-training for both [diffusion](cosmos1/models/diffusion/nemo/post_training/README.md) and [autoregressive](cosmos1/models/autoregressive/nemo/post_training/README.md) models, with other types of post-training coming soon.
+
+ ## License and Contact
+
+ This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
+
+ NVIDIA Cosmos source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+ NVIDIA Cosmos models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please contact [[email protected]](mailto:[email protected]).
RELEASE.md ADDED
@@ -0,0 +1,7 @@
+ # Release Cadence
+
+
+ | Version | Description | Date |
+ |------------|----------|----------|
+ | [v1.0](release_notes/v0p1.md) | Initial diffusion and autoregressive WFMs release | 2025-01-06 |
+ | [v0.1](release_notes/v0p1.md) | Initial tokenizer release | 2024-11-06 |
ar_config_base_model.py ADDED
@@ -0,0 +1,118 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import attrs
+
+ from .ar_config_base_tokenizer import TokenizerConfig
+
+
+ @attrs.define
+ class ModelConfig:
+     """
+     A class to hold model configuration arguments.
+
+     Args:
+         dim (int): The dimensionality of the input and output of each transformer block.
+         n_layers (int): Number of layers in the transformer.
+         n_heads (int): Number of attention heads.
+         n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
+             `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
+         head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
+         vocab_size (int): Vocabulary size.
+         ffn_hidden_size (int): Hidden size for the feedforward network.
+         norm_eps (float): Epsilon value for normalization.
+         rope_theta (float): Theta value for rotary positional embeddings.
+         apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
+         max_batch_size (int): Maximum batch size for inference.
+         max_seq_len (int): Maximum sequence length for input text.
+         fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to False.
+         causal_mask (bool): Whether to use a causal mask. Defaults to True.
+         norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
+         precision (str): Data type for the model.
+         use_qk_normalization (bool): Whether to enable QK normalization.
+         ckpt_dir (str): Checkpoint directory.
+         ckpt_path (str): Checkpoint path.
+         apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
+         yarn_scale (Optional[float]): Scale factor for YaRN.
+         yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code).
+         yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code).
+         original_seq_len (Optional[int]): Original sequence length.
+         vision_encoder (Optional[str]): Vision encoder name.
+         mm_projector (Optional[str]): Multi-modal projector name.
+         vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder.
+             Defaults to 3; set it to a larger int for extra channels, e.g. 4 for 4-channel images whose last channel
+             is an alpha channel.
+         rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
+         pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
+         original_latent_shape (Optional[list]): Original shape of the latent tensor needed for RoPE extension.
+         pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
+         insert_cross_attn (bool): Whether to insert cross-attention layers after each multi-head self-attention (MSA) layer.
+         insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
+         context_dim (Optional[int]): The dimensionality of the cross-attention embedding, e.g., the T5 embed feature dim.
+         num_video_frames (Optional[int]): Number of video frames.
+         video_height (Optional[int]): Raw video pixel height dimension.
+         video_width (Optional[int]): Raw video pixel width dimension.
+         video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T, H, W).
+     """
+
+     dim: int = attrs.field(default=4096)
+     n_layers: int = attrs.field(default=32)
+     n_heads: int = attrs.field(default=32)
+     n_kv_heads: Optional[int] = attrs.field(default=8)
+     head_dim: Optional[int] = attrs.field(default=None)
+     vocab_size: int = attrs.field(default=128256)
+     ffn_hidden_size: int = attrs.field(default=14336)
+     norm_eps: float = attrs.field(default=1e-5)
+     rope_theta: float = attrs.field(default=500000)
+     apply_abs_pos_emb: bool = attrs.field(default=False)
+     max_batch_size: int = attrs.field(default=1)
+     max_seq_len: int = attrs.field(default=8192)
+     fuse_qkv: bool = attrs.field(default=False)
+     causal_mask: bool = attrs.field(default=True)
+     norm_type: str = attrs.field(default="rmsnorm")
+     precision: str = attrs.field(default="bfloat16")
+     use_qk_normalization: bool = False
+     tokenizer: Optional[TokenizerConfig] = None
+     ckpt_dir: Optional[str] = attrs.field(default=None)
+     ckpt_path: Optional[str] = attrs.field(
+         default=None
+     )  # If not None, load the model from this path instead of ckpt_dir
+     apply_yarn: Optional[bool] = attrs.field(default=False)
+     yarn_scale: Optional[float] = attrs.field(default=None)
+     yarn_beta_fast: Optional[int] = attrs.field(default=None)
+     yarn_beta_slow: Optional[int] = attrs.field(default=None)
+     original_seq_len: Optional[int] = attrs.field(default=None)
+     vision_encoder: Optional[str] = attrs.field(default=None)
+     vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
+     mm_projector: Optional[str] = attrs.field(default=None)
+     rope_dim: Optional[str] = attrs.field(default="1D")
+     pytorch_rope_version: Optional[str] = attrs.field(default="v2")
+     original_latent_shape: Optional[list] = None
+     pad_to_multiple_of: Optional[int] = None
+     insert_cross_attn: bool = False
+     insert_cross_attn_every_k_layers: int = 1
+     context_dim: Optional[int] = attrs.field(default=1024)
+     # For video training
+     num_video_frames: Optional[int] = None
+     # Raw video pixel dimension
+     video_height: Optional[int] = None
+     video_width: Optional[int] = None
+     # Video tokenizer output dimension, in (T, H, W); computed as num_video_frames / temporal_compression_factor,
+     # video_height / spatial_compression_factor, video_width / spatial_compression_factor
+     video_latent_shape: Optional[list] = None
+
+     def __getitem__(self, item):
+         return getattr(self, item)
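
Note (illustrative, not part of this commit): a minimal sketch of how the `ModelConfig` above can be constructed and queried. The import path is an assumption taken from the README paths; adjust it to wherever these modules are actually importable from.

```python
# Hypothetical package name; the commit ships the file as ar_config_base_model.py.
from AutoregressiveVideo2WorldGeneration.ar_config_base_model import ModelConfig

cfg = ModelConfig(n_layers=16, dim=4096, n_heads=32, max_batch_size=2)
print(cfg.dim)          # 4096 -- regular attribute access
print(cfg["n_layers"])  # 16   -- dict-style access provided by __getitem__
```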
ar_config_base_model_config.py ADDED
@@ -0,0 +1,421 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import copy
+ from typing import Callable, List, Optional, Tuple
+
+ from .ar_config_base_model import ModelConfig
+ from .ar_config_base_tokenizer import (
+     TextTokenizerConfig,
+     TokenizerConfig,
+     VideoTokenizerConfig,
+     create_discrete_video_fsq_tokenizer_state_dict_config,
+ )
+ from .ar_tokenizer_image_text_tokenizer import ImageTextTokenizer
+ from .ar_tokenizer_text_tokenizer import TextTokenizer
+ from .log import log
+ from .lazy_config_init import LazyCall as L
+
+ # Common architecture specifications
+ BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
+ COSMOS_ARCHITECTURES = {
+     "4b": {
+         "n_layers": 16,
+         "dim": 4096,
+         "n_heads": 32,
+     },
+     "12b": {
+         "n_layers": 40,
+         "dim": 5120,
+         "n_heads": 32,
+         "head_dim": 128,
+     },
+ }
+
+ COSMOS_YARN_CONFIG = {
+     "original_latent_shape": [3, 40, 64],
+     "apply_yarn": True,
+     "yarn_beta_fast": 4,
+     "yarn_beta_slow": 1,
+     "yarn_scale": 2,
+ }
+
+ # Llama3 architecture specifications for different model sizes
+ LLAMA3_ARCHITECTURES = {
+     "8b": {
+         "n_layers": 32,
+         "dim": 4096,
+         "n_heads": 32,
+         "ffn_hidden_size": 14336,
+     },
+ }
+ # Llama3.1 uses YaRN for long context support (context of 128k tokens)
+ LLAMA_YARN_CONFIG = {
+     "apply_yarn": True,
+     "yarn_scale": 8,
+     "yarn_beta_fast": 4,
+     "yarn_beta_slow": 1,
+ }
+
+ # Mistral architecture specifications for different model sizes
+ MISTRAL_ARCHITECTURES = {
+     "12b": {
+         "n_layers": 40,
+         "dim": 5120,
+         "n_heads": 32,
+         "ffn_hidden_size": 14336,
+         "head_dim": 128,
+     },
+ }
+
+ PIXTRAL_VISION_ARCHITECTURES = {
+     "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
+ }
+
+
+ def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
+     """
+     Get the model architecture specifications for the given model size, model family and pretrained status.
+
+     Args:
+         model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
+         model_family (str): Model family. Choices: "cosmos", "llama", "llama3", "llama3.1", "mistral", "pixtral".
+         pretrained (bool): Whether to load pretrained weights.
+
+     Returns:
+         dict: A dictionary containing the model architecture specifications.
+     """
+     arch_specs = copy.deepcopy(BASE_CONFIG)
+     model_size = model_size.lower()
+     if model_family.startswith("cosmos"):
+         arch_specs.update(COSMOS_ARCHITECTURES[model_size])
+     elif model_family.startswith("llama"):
+         arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
+     elif model_family in ["mistral", "pixtral"]:
+         arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
+         if model_family == "pixtral":
+             arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
+     else:
+         raise ValueError(f"Model family {model_family} is not supported.")
+
+     if pretrained:
+         if model_family == "cosmos":
+             if model_size == "12b":
+                 arch_specs.update(COSMOS_YARN_CONFIG)
+                 log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
+             else:
+                 pass
+         elif model_family in ["llama", "llama3"]:
+             pretrained_specs = {
+                 "rope_theta": 500000,
+                 "max_seq_len": 8192,
+                 "vocab_size": 128256,
+             }
+             arch_specs.update(pretrained_specs)
+         elif model_family == "llama3.1":
+             pretrained_specs = {
+                 "rope_theta": 500000,
+                 "max_seq_len": 131072,
+                 "original_seq_len": 8192,
+                 "vocab_size": 128256,
+                 **LLAMA_YARN_CONFIG,
+             }
+             arch_specs.update(pretrained_specs)
+         elif model_family == "mistral":
+             assert model_size == "12b", "We only support Mistral-Nemo-12B model."
+             pretrained_specs = {
+                 "rope_theta": 1000000,
+                 "max_seq_len": 128000,
+                 "vocab_size": 131072,
+             }
+             arch_specs.update(pretrained_specs)
+         elif model_family == "pixtral":
+             assert model_size == "12b", "We only support Pixtral 12B model."
+             pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
+             arch_specs.update(pretrained_specs)
+         else:
+             raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
+
+     return arch_specs
+
+
+ def create_text_model_config(
+     model_ckpt_path: str,
+     tokenizer_path: str,
+     model_family: str = "mistral",
+     model_size: str = "12b",
+     is_instruct_model: bool = True,
+     max_seq_len: Optional[int] = None,
+     max_batch_size: int = 1,
+     rope_dim: str = "1D",
+     add_special_tokens: bool = True,
+     pytorch_rope_version: Optional[str] = None,
+ ) -> Tuple[ModelConfig, TokenizerConfig]:
+     """Create a text model config for training or inference.
+
+     Args:
+         model_ckpt_path (str): Path to the model checkpoint.
+         tokenizer_path (str): Path to the tokenizer folder.
+         model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
+         model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
+         is_instruct_model (bool): Whether the model is an instruct model.
+         max_seq_len (Optional[int]): Maximum sequence length.
+         max_batch_size (int): Maximum batch size.
+         rope_dim (str): RoPE dimension. Choices: "1D", "3D".
+         add_special_tokens (bool): Whether to add special tokens.
+         pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation.
+     Returns:
+         Tuple[ModelConfig, TokenizerConfig]: The model config and tokenizer config, which can be used to instantiate the model object.
+     """
+     # Model size specific parameters
+     model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
+     if max_seq_len is not None:
+         # Override the max_seq_len if provided
+         model_arch_specs["max_seq_len"] = max_seq_len
+     if pytorch_rope_version is not None:
+         model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
+     model_config = ModelConfig(
+         max_batch_size=max_batch_size,
+         precision="bfloat16",
+         ckpt_path=model_ckpt_path,
+         use_qk_normalization=False,
+         rope_dim=rope_dim,
+         **model_arch_specs,
+     )
+
+     tokenizer_config = TokenizerConfig(
+         text_tokenizer=TextTokenizerConfig(
+             config=L(TextTokenizer)(
+                 model_family=model_family,
+                 is_instruct_model=is_instruct_model,
+                 local_path=tokenizer_path,
+             ),
+             data_key="text",
+             tokenizer_offset=model_config.vocab_size,
+             tokenize_here=False,
+             vocab_size=model_config.vocab_size,
+         ),
+         seq_len=model_config.max_seq_len,
+         training_type="text_only",
+         add_special_tokens=add_special_tokens,
+     )
+     return model_config, tokenizer_config
+
+
+ def create_vision_language_model_config(
+     model_ckpt_path: str,
+     tokenizer_ckpt_path: str,
+     model_family: str = "pixtral",
+     model_size: str = "12b",
+     is_instruct_model: bool = True,
+     max_batch_size: int = 1,
+     rope_dim: str = "1D",
+     add_special_tokens: bool = True,
+     max_seq_len: Optional[int] = None,
+     vision_encoder_in_channels: int = 3,
+     fuse_qkv: bool = False,
+     pytorch_rope_version: Optional[str] = None,
+ ) -> Tuple[ModelConfig, TokenizerConfig]:
+     """Create a vision-language model config for training or inference.
+
+     Args:
+         model_ckpt_path (str): Path to the model checkpoint.
+         tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
+         model_family (str): Model family. Choices: "pixtral".
+         model_size (str): Model size. Choices: "12b".
+         is_instruct_model (bool): Whether the model is an instruct model.
+         max_batch_size (int): Maximum batch size.
+         rope_dim (str): RoPE dimension. Choices: "1D".
+         add_special_tokens (bool): Whether to add special tokens.
+         max_seq_len (Optional[int]): Maximum sequence length.
+         vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Defaults to 3;
+             set it to a larger int for extra channels, e.g. 4 for 4-channel images whose last channel is a binary mask.
+         fuse_qkv (bool): Whether to fuse the QKV linear layers.
+         pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation.
+     Returns:
+         Tuple[ModelConfig, TokenizerConfig]: The model config and tokenizer config, which can be used to instantiate the model object.
+     """
+     # Model size specific parameters
+     model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
+     if max_seq_len is not None:
+         # Override the max_seq_len if provided
+         model_arch_specs["max_seq_len"] = max_seq_len
+     if pytorch_rope_version is not None:
+         model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
+
+     model_config = ModelConfig(
+         max_batch_size=max_batch_size,
+         precision="bfloat16",
+         ckpt_path=model_ckpt_path,
+         use_qk_normalization=False,
+         rope_dim=rope_dim,
+         vision_encoder_in_channels=vision_encoder_in_channels,
+         fuse_qkv=fuse_qkv,
+         **model_arch_specs,
+     )
+     # Vision-language tokenizer
+     tokenizer_config = TokenizerConfig(
+         text_tokenizer=TextTokenizerConfig(
+             config=L(ImageTextTokenizer)(
+                 model_family=model_family,
+                 is_instruct_model=is_instruct_model,
+                 image_processor_path=tokenizer_ckpt_path,
+                 tokenizer_path=tokenizer_ckpt_path,
+             ),
+             data_key="image_text_interleaved",
+             tokenizer_offset=model_config.vocab_size,
+             tokenize_here=False,
+             vocab_size=model_config.vocab_size,
+         ),
+         seq_len=model_config.max_seq_len,
+         training_type="image_text_interleaved",
+         add_special_tokens=add_special_tokens,
+     )
+     return model_config, tokenizer_config
+
+
+ def create_video2world_model_config(
+     model_ckpt_path: str,
+     tokenizer_ckpt_path: str,
+     model_family: str = "cosmos",
+     model_size: str = "4b",
+     pixel_chunk_duration: int = 9,
+     num_video_frames: int = 36,
+     compression_ratio: List[int] = [8, 16, 16],
+     original_seq_len: int = 8192,
+     num_condition_latents_t: int = 1,
+     num_tokens_to_ignore: int = -1,
+     batch_size: int = 2,
+     video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
+     rope_dim: str = "3D",
+     add_special_tokens: bool = True,
+     video_height: int = 384,
+     video_width: int = 640,
+     use_qk_normalization: bool = True,
+     insert_cross_attn: bool = False,
+     insert_cross_attn_every_k_layers: int = 1,
+     context_dim: int = 1024,
+     training_type: str = "video_to_video",
+     pad_to_multiple_of: Optional[int] = 64,
+     vocab_size: int = 64000,
+     apply_abs_pos_emb: bool = False,
+ ) -> Tuple[ModelConfig, TokenizerConfig]:
+     """Create a video-to-world model config.
+
+     Args:
+         model_ckpt_path (str): Path to the model checkpoint.
+         tokenizer_ckpt_path (str): Path to the video tokenizer checkpoint.
+         model_family (str): Model family. Defaults to "cosmos".
+         model_size (str): Model size. Choices: "4b", "12b".
+         pixel_chunk_duration (int): Number of frames in each chunk.
+         num_video_frames (int): Number of video frames.
+         compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
+         original_seq_len (int): Original sequence length.
+         num_condition_latents_t (int): Number of conditioning latent channels.
+         num_tokens_to_ignore (int): Number of tokens to ignore; takes precedence over num_condition_latents_t when non-negative.
+         batch_size (int): Batch size.
+         video_tokenizer_config_creator (Callable): Callable that takes the tokenizer checkpoint path, pixel_chunk_duration,
+             and compression_ratio as arguments and returns a video tokenizer config.
+         rope_dim (str): RoPE dimension. Choices: "1D", "3D".
+         add_special_tokens (bool): Whether to add special tokens; use False for 2D/3D RoPE.
+         video_height (int): Height of the video frame. Defaults to 384.
+         video_width (int): Width of the video frame. Defaults to 640.
+         use_qk_normalization (bool): Whether to use Query-Key normalization.
+         training_type (str): Type of training task.
+         pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
+         vocab_size (int): Vocabulary size.
+         apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
+     Returns:
+         Tuple[ModelConfig, TokenizerConfig]: The model config and tokenizer config, which can be used to instantiate the model object.
+     """
+     assert (
+         pixel_chunk_duration % compression_ratio[0] == 1
+     ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
+     latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
+     latent_height = video_height // compression_ratio[1]
+     latent_width = video_width // compression_ratio[2]
+     # Do some math to compute the video latent shape and sequence length
+     assert (
+         num_video_frames % pixel_chunk_duration == 0
+     ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
+     video_latent_shape = [
+         num_video_frames // pixel_chunk_duration * latent_chunk_duration,
+         latent_height,
+         latent_width,
+     ]
+     # product of video_latent_shape
+     num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
+     if add_special_tokens:
+         seq_len = num_token_video_latent + 3  # Sequence length per batch, max_seq_len + 3
+         seq_len = (seq_len + 63) // 64 * 64  # Round up to multiple of 64
+     # for text to video, we need to add <bov> token to indicate the start of the video
+     elif training_type == "text_to_video":
+         seq_len = num_token_video_latent + 1
+     else:
+         seq_len = num_token_video_latent
+
+     if seq_len % pad_to_multiple_of != 0:
+         # Round up to the nearest multiple of pad_to_multiple_of
+         seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+
+     # Model size specific parameters
+     model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
+
+     # Whether to skip the loss for the first chunk or not; note the first token is already skipped when computing the loss.
+     # If num_tokens_to_ignore is specified, use it.
+     # Else compute it from num_condition_latents_t
+     if num_tokens_to_ignore < 0:
+         num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
+         if not add_special_tokens and num_condition_latents_t > 0:
+             # If there are no special tokens (bov), do a -1 so that you can compute the loss
+             # from the first token of the next chunk
+             num_tokens_to_ignore -= 1
+
+     model_config = ModelConfig(
+         video_height=video_height,
+         video_width=video_width,
+         max_seq_len=seq_len,
+         max_batch_size=batch_size,
+         precision="bfloat16",
+         ckpt_path=model_ckpt_path,
+         use_qk_normalization=use_qk_normalization,
+         vocab_size=64000,
+         original_seq_len=original_seq_len,
+         video_latent_shape=video_latent_shape,
+         num_video_frames=num_video_frames,
+         rope_dim=rope_dim,
+         pad_to_multiple_of=pad_to_multiple_of,
+         insert_cross_attn=insert_cross_attn,
+         insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
+         context_dim=context_dim,
+         apply_abs_pos_emb=apply_abs_pos_emb,
+         **model_arch_specs,
+     )
+
+     video_tokenizer_config = video_tokenizer_config_creator(
+         tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
+     )
+     tokenizer_config = TokenizerConfig(
+         text_tokenizer=None,
+         video_tokenizer=VideoTokenizerConfig(
+             config=video_tokenizer_config,
+             data_key="video",
+             tokenizer_offset=0,  # There are no text embeddings in the model. Note this only applies when the model is trained from scratch; if we use a text-pretrained model, the offset will be the vocab_size of the text tokens.
+             tokenize_here=True,
+             max_seq_len=num_token_video_latent,
+             vocab_size=vocab_size,
+         ),
+         seq_len=seq_len,
+         training_type=training_type,
+         add_special_tokens=add_special_tokens,
+         pad_to_multiple_of=pad_to_multiple_of,
+     )
+     return model_config, tokenizer_config
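
Note (illustrative, not part of this commit): a worked example of the sequence-length arithmetic inside `create_video2world_model_config`, using the function's own defaults (pixel_chunk_duration=9, num_video_frames=36, compression_ratio=[8, 16, 16], video_height=384, video_width=640, add_special_tokens=True).

```python
pixel_chunk_duration, num_video_frames = 9, 36
compression_ratio = [8, 16, 16]
video_height, video_width = 384, 640

latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1  # 2
latent_height = video_height // compression_ratio[1]                            # 24
latent_width = video_width // compression_ratio[2]                              # 40
video_latent_shape = [
    num_video_frames // pixel_chunk_duration * latent_chunk_duration,  # 8
    latent_height,
    latent_width,
]
num_token_video_latent = 8 * 24 * 40                     # 7680 video tokens
seq_len = (num_token_video_latent + 3 + 63) // 64 * 64   # 7744 after adding 3 special tokens and rounding up to 64
```

So the defaults yield a video latent shape of [8, 24, 40] and a model `max_seq_len` of 7744.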
ar_config_base_tokenizer.py ADDED
@@ -0,0 +1,137 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional
+
+ import attrs
+
+ from .ar_tokenizer_discrete_video import DiscreteVideoFSQStateDictTokenizer
+ from .ar_tokenizer_networks import CausalDiscreteVideoTokenizer
+ from .lazy_config_init import LazyCall as L
+ from .lazy_config_init import LazyDict
+
+
+ def create_discrete_video_fsq_tokenizer_state_dict_config(
+     ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
+ ) -> LazyDict:
+     CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
+         # The new causal discrete tokenizer, which is at least 2x more efficient in memory and runtime:
+         # - It relies on a fully 3D discrete wavelet transform
+         # - Uses a layer norm instead of a group norm
+         # - Factorizes full convolutions into spatial and temporal convolutions
+         # - Factorizes full attention into spatial and temporal attention
+         # - Strictly causal, with flexible temporal length at inference.
+         attn_resolutions=[32],
+         channels=128,
+         channels_mult=[2, 4, 4],
+         dropout=0.0,
+         in_channels=3,
+         num_res_blocks=2,
+         out_channels=3,
+         resolution=1024,
+         patch_size=4,
+         patch_method="haar",
+         z_channels=16,
+         z_factor=1,
+         num_groups=1,
+         legacy_mode=False,
+         spatial_compression=16,
+         temporal_compression=8,
+         embedding_dim=6,
+         levels=[8, 8, 8, 5, 5, 5],
+         name="CausalDiscreteFactorizedVideoTokenizer",
+     )
+
+     return L(DiscreteVideoFSQStateDictTokenizer)(
+         enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
+         dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
+         tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
+         name="discrete_video_fsq",
+         latent_ch=6,
+         is_bf16=True,
+         pixel_chunk_duration=pixel_chunk_duration,
+         latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
+         max_enc_batch_size=8,
+         max_dec_batch_size=4,
+         levels=[8, 8, 8, 5, 5, 5],
+         compression_ratio=compression_ratio,
+     )
+
+
+ @attrs.define(slots=False)
+ class TextTokenizerConfig:
+     """
+     Text tokenizer config
+
+     Args:
+         config: Config file to define the text tokenizer class.
+         data_key (str): The input key from data_dict that will be passed to the text tokenizer.
+         tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
+         tokenizer_offset (int): Offset that is added to the tokens.
+         vocab_size (int): Vocabulary size of the tokenizer.
+     """
+
+     config: LazyDict
+     data_key: str = ""
+     tokenize_here: bool = False
+     tokenizer_offset: int = 0
+     vocab_size: int = 0
+
+
+ @attrs.define(slots=False)
+ class VideoTokenizerConfig:
+     """
+     Video tokenizer config
+
+     Args:
+         config: Config file to define the video tokenizer class.
+         data_key (str): The input key from data_dict that will be passed to the video tokenizer.
+         tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
+         tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
+             add an offset to make sure that video tokens and text tokens don't overlap.
+         vocab_size (int): Vocabulary size of the tokenizer.
+         max_seq_len (int): Maximum token length for an input video.
+     """
+
+     config: LazyDict
+     data_key: str = ""
+     tokenize_here: bool = True
+     tokenizer_offset: int = 0
+     vocab_size: int = 0
+     max_seq_len: int = -1
+
+
+ @attrs.define(slots=False)
+ class TokenizerConfig:
+     """
+     Joint tokenizer config
+
+     Args:
+         text_tokenizer (TextTokenizerConfig): Text tokenizer config file
+         video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
+         seq_len (int): Final token sequence length
+         training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
+         add_special_tokens (bool): Whether to add special tokens to the output tokens
+         pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
+     """
+
+     text_tokenizer: Optional[TextTokenizerConfig] = None
+     video_tokenizer: Optional[VideoTokenizerConfig] = None
+     seq_len: int = 4096
+     training_type: Optional[str] = None
+     add_special_tokens: bool = True
+     pad_to_multiple_of: Optional[int] = 64
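
Note (illustrative, not part of this commit): a sketch of how the video tokenizer config above is typically assembled into a `TokenizerConfig`. The checkpoint path is hypothetical; the only requirement visible in the code is that it points at an `ema.jit` file, from which the encoder/decoder paths are derived.

```python
# Hypothetical checkpoint location for a Cosmos discrete video tokenizer.
video_tokenizer = create_discrete_video_fsq_tokenizer_state_dict_config(
    ckpt_path="checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
    pixel_chunk_duration=33,
    compression_ratio=[8, 16, 16],
)
tokenizer_config = TokenizerConfig(
    text_tokenizer=None,
    video_tokenizer=VideoTokenizerConfig(
        config=video_tokenizer,
        data_key="video",
        tokenize_here=True,
        vocab_size=64000,
    ),
    seq_len=8192,
    training_type="video_to_video",
)
```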
ar_config_inference_inference_config.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List, Union
17
+
18
+ import attrs
19
+
20
+ from .ar_config_base_model import ModelConfig, TokenizerConfig
21
+
22
+
23
+ @attrs.define(slots=False)
24
+ class DataShapeConfig:
25
+ latent_shape: list = []
26
+ num_video_frames: Union[None, int] = None
27
+ height: Union[None, int] = None
28
+ width: Union[None, int] = None
29
+
30
+
31
+ @attrs.define(slots=False)
32
+ class SamplingConfig:
33
+ """
34
+ Sampling config
35
+ Args:
36
+ temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
37
+ top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
38
+ logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
39
+ echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
40
+
41
+ """
42
+
43
+ temperature: float = 0.6
44
+ top_k: int = None
45
+ top_p: float = 0.9
46
+ compile_prefill: bool = False
47
+ compile_sampling: bool = True
48
+ logprobs: bool = False
49
+ echo: bool = False
50
+
51
+
52
+ @attrs.define(slots=False)
53
+ class DiffusionDecoderSamplingConfig:
54
+ """
55
+ Diffusion decoder sampling config
56
+ Args:
57
+ guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 1.8.
58
+ sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
59
+ sigma (float): Initial noise level for the diffusion process. Defaults to 8.
60
+ num_steps (int): Number of denoising steps to perform. Defaults to 15.
61
+ overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
62
+ continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
63
+ continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
64
+ dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
65
+ """
66
+
67
+ guidance: float = 1.8
68
+ sigma_min: float = 0.02
69
+ sigma: float = 8
70
+ num_steps: int = 15
71
+ overlap: int = 2
72
+ continuous_tokenizer_channel: int = 16
73
+ continuous_tokenizer_spatial_compression_ratio: int = 8
74
+ dd_train_num_video_frames: int = 57
75
+ max_iter: int = 99
76
+ fps: int = 24
77
+
78
+
79
+ @attrs.define(slots=False)
80
+ class InferenceConfig:
81
+ """
82
+ Inference config
83
+ Args:
84
+ model_config (ModelConfig): Model config
85
+ tokenizer_config (TokenizerConfig): Tokenizer config
86
+ ckpt_path (str): Path to the checkpoint
87
+ latent_shape (list): Shape of the latent
88
+ """
89
+
90
+ model_config: ModelConfig = None
91
+ tokenizer_config: TokenizerConfig = None
92
+ ckpt_path: str = ""
93
+ data_shape_config: DataShapeConfig = None
94
+
95
+ defaults: List[Any] = attrs.field(
96
+ factory=lambda: [
97
+ "_self_",
98
+ {"data_val": None},
99
+ {"data_shape_config": "video_shape_as_model_config"},
100
+ {"eval_job": None},
101
+ ]
102
+ )
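As a usage note, the two sampling configs above are plain attrs classes, so overriding a knob is simple keyword construction (the values shown are just the defaults, repeated for illustration):

sampling_config = SamplingConfig(temperature=0.6, top_p=0.9, logprobs=False, echo=False)
dd_sampling_config = DiffusionDecoderSamplingConfig(guidance=1.8, num_steps=15, overlap=2)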
ar_diffusion_decoder_config_base_conditioner.py ADDED
@@ -0,0 +1,61 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Dict, Optional
18
+
19
+ import torch
20
+
21
+ from .df_conditioner import BaseVideoCondition, GeneralConditioner
22
+ from .df_config_base_conditioner import (
23
+ FPSConfig,
24
+ ImageSizeConfig,
25
+ LatentConditionConfig,
26
+ LatentConditionSigmaConfig,
27
+ NumFramesConfig,
28
+ PaddingMaskConfig,
29
+ TextConfig,
30
+ )
31
+ from .lazy_config_init import LazyCall as L
32
+ from .lazy_config_init import LazyDict
33
+
34
+
35
+ @dataclass
36
+ class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
37
+ # latent_condition will concat to the input of network, along channel dim;
38
+ # cfg will make latent_condition all zero padding.
39
+ latent_condition: Optional[torch.Tensor] = None
40
+ latent_condition_sigma: Optional[torch.Tensor] = None
41
+
42
+
43
+ class VideoDiffusionDecoderConditioner(GeneralConditioner):
44
+ def forward(
45
+ self,
46
+ batch: Dict,
47
+ override_dropout_rate: Optional[Dict[str, float]] = None,
48
+ ) -> VideoLatentDiffusionDecoderCondition:
49
+ output = super()._forward(batch, override_dropout_rate)
50
+ return VideoLatentDiffusionDecoderCondition(**output)
51
+
52
+
53
+ VideoLatentDiffusionDecoderConditionerConfig: LazyDict = L(VideoDiffusionDecoderConditioner)(
54
+ text=TextConfig(),
55
+ fps=FPSConfig(),
56
+ num_frames=NumFramesConfig(),
57
+ image_size=ImageSizeConfig(),
58
+ padding_mask=PaddingMaskConfig(),
59
+ latent_condition=LatentConditionConfig(),
60
+ latent_condition_sigma=LatentConditionSigmaConfig(),
61
+ )
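A minimal usage sketch for the conditioner defined above. It assumes GeneralConditioner is an nn.Module, that `conditioner` was instantiated from VideoLatentDiffusionDecoderConditionerConfig via the lazy-config helper, and that `data_batch` follows the layout built in ar_diffusion_decoder_inference.py; the override key name is an assumption mirroring the config field.

condition = conditioner(data_batch, override_dropout_rate={"latent_condition": 0.0})  # key name assumed
assert isinstance(condition, VideoLatentDiffusionDecoderCondition)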
ar_diffusion_decoder_config_config_latent_diffusion_decoder.py ADDED
@@ -0,0 +1,62 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List
17
+
18
+ import attrs
19
+
20
+ from .ar_diffusion_decoder_config_registry import register_configs as register_dd_configs
21
+ from .df_config_base_model import LatentDiffusionDecoderModelConfig
22
+ from .df_config_registry import register_configs
23
+ from .config import Config as ori_Config
24
+ from .config_helper import import_all_modules_from_package
25
+
26
+ from .ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b import LazyDict
27
+
28
+ @attrs.define(slots=False)
29
+ class Config(ori_Config):
30
+ # default config groups that will be used unless overwritten
31
+ # see config groups in registry.py
32
+ defaults: List[Any] = attrs.field(
33
+ factory=lambda: [
34
+ "_self_",
35
+ {"net": None},
36
+ {"conditioner": "basic"},
37
+ {"tokenizer": "tokenizer"},
38
+ {"tokenizer_corruptor": None},
39
+ {"latent_corruptor": None},
40
+ {"pixel_corruptor": None},
41
+ {"experiment": None},
42
+ ]
43
+ )
44
+
45
+
46
+ def make_config():
47
+ c = Config(model=LatentDiffusionDecoderModelConfig())
48
+
49
+ # Specifying values through instances of attrs
50
+ c.job.project = "cosmos_video4"
51
+ c.job.group = "debug"
52
+ c.job.name = "delete_${now:%Y-%m-%d}_${now:%H-%M-%S}"
53
+
54
+ # # Call this function to register config groups for advanced overriding.
55
+ register_configs()
56
+ register_dd_configs()
57
+
58
+ # # experiment config are defined in the experiment folder
59
+ # # call import_all_modules_from_package to register them
60
+ # import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True)
61
+ # import_all_modules_from_package("cosmos1.models.autoregressive.diffusion_decoder.config.inference", reload=True)
62
+ return c
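A minimal sketch of calling make_config() directly; the field access below assumes the base Config exposes the `model` and `job` attributes used inside make_config().

cfg = make_config()                 # also registers the conditioner/corruptor/tokenizer config groups
print(type(cfg.model).__name__)     # LatentDiffusionDecoderModelConfig
print(cfg.job.name)                 # still the "delete_${now:...}" template until Hydra/OmegaConf resolves it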
ar_diffusion_decoder_config_inference_cosmos_diffusiondecoder_7b.py ADDED
@@ -0,0 +1,85 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from hydra.core.config_store import ConfigStore
17
+
18
+ from .ar_diffusion_decoder_network import DiffusionDecoderGeneralDIT
19
+ from .lazy_config_init import LazyCall as L
20
+ from .lazy_config_init import LazyDict
21
+
22
+ num_frames = 57
23
+ Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY: LazyDict = LazyDict(
24
+ dict(
25
+ defaults=[
26
+ {"override /net": "faditv2_7b"},
27
+ {"override /tokenizer": "cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624"},
28
+ {"override /conditioner": "video_latent_diffusion_decoder_cond"},
29
+ {"override /tokenizer_corruptor": "cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224"},
30
+ "_self_",
31
+ ],
32
+ job=dict(
33
+ group="diffusion_deocder_FT_7Bv1_001",
34
+ name="DD_FT_7Bv1_003_002_tokenizer888_spatch2_discrete_cond_on_token",
35
+ ),
36
+ model=dict(
37
+ diffusion_decoder_cond_sigma_low=0.0,
38
+ diffusion_decoder_cond_sigma_high=0.0,
39
+ diffusion_decoder_corrupt_prob=0.0,
40
+ condition_on_tokenizer_corruptor_token=True,
41
+ latent_shape=[
42
+ 16,
43
+ num_frames,
44
+ 88,
45
+ 160,
46
+ ],
47
+ tokenizer_corruptor=dict(
48
+ pixel_chunk_duration=num_frames,
49
+ latent_chunk_duration=1 + (num_frames - 1) // 8,
50
+ ),
51
+ net=L(DiffusionDecoderGeneralDIT)(
52
+ diffusion_decoder_condition_on_sigma=False,
53
+ max_img_h=240,
54
+ max_img_w=240,
55
+ rope_h_extrapolation_ratio=1.5,
56
+ rope_w_extrapolation_ratio=1.5,
57
+ rope_t_extrapolation_ratio=1,
58
+ block_x_format="THWBD",
59
+ is_diffusion_decoder=True,
60
+ patch_spatial=2,
61
+ diffusion_decoder_condition_on_token=True,
62
+ diffusion_decoder_token_condition_voc_size=64000,
63
+ diffusion_decoder_token_condition_dim=32,
64
+ ),
65
+ tokenizer=dict(
66
+ video_vae=dict(
67
+ pixel_chunk_duration=num_frames,
68
+ )
69
+ ),
70
+ conditioner=dict(
71
+ latent_condition=dict(
72
+ dropout_rate=0.2,
73
+ )
74
+ ),
75
+ ),
76
+ )
77
+ )
78
+
79
+ cs = ConfigStore.instance()
80
+ cs.store(
81
+ group="experiment",
82
+ package="_global_",
83
+ name=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY["job"]["name"],
84
+ node=Cosmos_DiffusionDecoder_7B_INFERENCE_ONLY,
85
+ )
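A quick, purely illustrative check of the temporal arithmetic used for `latent_chunk_duration` above:

num_frames = 57
latent_chunk_duration = 1 + (num_frames - 1) // 8   # = 8 latent frames for a 57-frame pixel chunk (8x temporal compression)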
ar_diffusion_decoder_config_registry.py ADDED
@@ -0,0 +1,118 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from hydra.core.config_store import ConfigStore
17
+
18
+ from .ar_diffusion_decoder_config_base_conditioner import (
19
+ VideoLatentDiffusionDecoderConditionerConfig,
20
+ )
21
+ from .ar_tokenizer_discrete_video import DiscreteVideoFSQJITTokenizer
22
+ from .df_module_pretrained_vae import JITVAE, JointImageVideoSharedJITTokenizer, VideoJITTokenizer
23
+ from .lazy_config_init import LazyCall as L
24
+
25
+
26
+ def get_cosmos_video_discrete_tokenizer_comp8x16x16(
27
+ resolution: str,
28
+ chunk_duration: int,
29
+ checkpoint_path: str,
30
+ ):
31
+ assert resolution in ["720"]
32
+
33
+ pixel_chunk_duration = chunk_duration
34
+ temporal_compression_factor = 8
35
+ spatial_compression_factor = 16
36
+
37
+ return L(DiscreteVideoFSQJITTokenizer)(
38
+ enc_fp=checkpoint_path.replace(".jit", "encoder.jit"),
39
+ dec_fp=checkpoint_path.replace(".jit", "decoder.jit"),
40
+ name="discrete_video_fsq",
41
+ latent_ch=6,
42
+ is_bf16=True,
43
+ pixel_chunk_duration=pixel_chunk_duration,
44
+ latent_chunk_duration=1 + (pixel_chunk_duration - 1) // temporal_compression_factor,
45
+ max_enc_batch_size=8,
46
+ max_dec_batch_size=4,
47
+ levels=[8, 8, 8, 5, 5, 5],
48
+ compression_ratio=[temporal_compression_factor, spatial_compression_factor, spatial_compression_factor],
49
+ )
50
+
51
+
52
+ def get_cosmos_video_tokenizer_comp8x8x8(resolution: str, chunk_duration: int, checkpoint_path=None):
53
+ pixel_chunk_duration = chunk_duration
54
+ temporal_compression_factor = 8
55
+ spatial_compression_factor = 8
56
+
57
+ return L(JointImageVideoSharedJITTokenizer)(
58
+ video_vae=L(VideoJITTokenizer)(
59
+ name="cosmos_1_0_diffusion_tokenizer",
60
+ latent_ch=16,
61
+ is_bf16=True,
62
+ pixel_chunk_duration=pixel_chunk_duration,
63
+ temporal_compression_factor=temporal_compression_factor,
64
+ spatial_compression_factor=spatial_compression_factor,
65
+ spatial_resolution=resolution,
66
+ ),
67
+ image_vae=L(JITVAE)(
68
+ name="cosmos_1_0_diffusion_tokenizer",
69
+ latent_ch=16,
70
+ is_image=False,
71
+ is_bf16=True,
72
+ ),
73
+ name="cosmos_diffusion_tokenizer_res720_comp8x8x8_t121_ver092624",
74
+ latent_ch=16,
75
+ )
76
+
77
+
78
+ def register_tokenizer(cs):
79
+ cs.store(
80
+ group="tokenizer",
81
+ package="model.tokenizer",
82
+ name="cosmos_video_tokenizer_res720_comp8x8x8_t121_ver092624",
83
+ node=get_cosmos_video_tokenizer_comp8x8x8(
84
+ resolution="720",
85
+ chunk_duration=121,
86
+ checkpoint_path="checkpoints/Cosmos-1.0-Tokenizer-CV8x8x8/.jit",
87
+ ),
88
+ )
89
+
90
+
91
+ def register_corruptor(cs):
92
+ cs.store(
93
+ group="tokenizer_corruptor",
94
+ package="model.tokenizer_corruptor",
95
+ name="cosmos_video_discrete_tokenizer_res720_comp8x16x16_t49_ver110224",
96
+ node=get_cosmos_video_discrete_tokenizer_comp8x16x16(
97
+ resolution="720",
98
+ chunk_duration=49,
99
+ checkpoint_path="checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/.jit",
100
+ ),
101
+ )
102
+
103
+
104
+ def register_conditioner(cs):
105
+ cs.store(
106
+ group="conditioner",
107
+ package="model.conditioner",
108
+ name="video_latent_diffusion_decoder_cond",
109
+ node=VideoLatentDiffusionDecoderConditionerConfig,
110
+ )
111
+
112
+
113
+ def register_configs():
114
+ cs = ConfigStore.instance()
115
+
116
+ register_conditioner(cs)
117
+ register_corruptor(cs)
118
+ register_tokenizer(cs)
ar_diffusion_decoder_inference.py ADDED
@@ -0,0 +1,120 @@
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ import gc
18
+ from typing import List
19
+
20
+ import torch
21
+
22
+ from .ar_config_inference_inference_config import DiffusionDecoderSamplingConfig
23
+ from .ar_diffusion_decoder_model import LatentDiffusionDecoderModel
24
+ from .ar_diffusion_decoder_utils import linear_blend_video_list, split_with_overlap
25
+ from .log import log
26
+
27
+
28
+ def diffusion_decoder_process_tokens(
29
+ model: LatentDiffusionDecoderModel,
30
+ indices_tensor: List[torch.Tensor],
31
+ dd_sampling_config: DiffusionDecoderSamplingConfig = None,
32
+ original_video_example: torch.Tensor = None,
33
+ t5_emb_batch: List[torch.Tensor] = None,
34
+ ):
35
+ _, T, H, W = original_video_example.shape
36
+ if dd_sampling_config is None:
37
+ dd_sampling_config = DiffusionDecoderSamplingConfig()
38
+ # indices_tensor is assumed to be a list of tensors with shape 1LHW
39
+ data_batch_list = []
40
+ for sample_num, token_CTHW in enumerate(indices_tensor):
41
+ token_BCTHW = token_CTHW.unsqueeze(0).unsqueeze(1)
42
+ token_BCTHW = split_with_overlap(
43
+ token_BCTHW,
44
+ (dd_sampling_config.dd_train_num_video_frames - 1) // 8 + 1,
45
+ overlap=dd_sampling_config.overlap,
46
+ tobf16=False,
47
+ )
48
+ data_batch_list.append(
49
+ {
50
+ "token_chunks": token_BCTHW,
51
+ "t5_text_embeddings": t5_emb_batch[sample_num].to(torch.bfloat16),
52
+ "t5_text_mask": torch.ones(1, 512, dtype=torch.bfloat16).cuda(),
53
+ # other conditions
54
+ "image_size": torch.tensor([[H, W, H, W]] * 1, dtype=torch.bfloat16).cuda(),
55
+ "fps": torch.tensor([dd_sampling_config.fps] * 1, dtype=torch.bfloat16).cuda(),
56
+ "num_frames": torch.tensor(
57
+ [dd_sampling_config.dd_train_num_video_frames] * 1, dtype=torch.bfloat16
58
+ ).cuda(),
59
+ "padding_mask": torch.zeros((1, 1, H, W), dtype=torch.bfloat16).cuda(),
60
+ }
61
+ )
62
+
63
+ out_videos_batch = []
64
+
65
+ for idx, data_batch_template in enumerate(data_batch_list):
66
+ full_length_sample = []
67
+ iterations = min(len(data_batch_template["token_chunks"]), dd_sampling_config.max_iter)
68
+ for iter in range(iterations):
69
+ gc.collect()
70
+ torch.cuda.empty_cache()
71
+
72
+ data_batch = copy.deepcopy(data_batch_template)
73
+ data_batch["video"] = data_batch_template["token_chunks"][iter].cuda().to("cuda")
74
+
75
+ log.debug(f"Run iter {iter} for video # {idx} at length {data_batch['video'].shape[2]}")
76
+ # org_video,
77
+ with torch.no_grad():
78
+ samples_latent = model.generate_samples_from_batch(
79
+ data_batch,
80
+ guidance=dd_sampling_config.guidance,
81
+ sigma_min=dd_sampling_config.sigma_min,
82
+ state_shape=[
83
+ dd_sampling_config.continuous_tokenizer_channel,
84
+ dd_sampling_config.continuous_tokenizer_spatial_compression_ratio,
85
+ H // 8,
86
+ W // 8,
87
+ ],
88
+ apply_corruptor=False,
89
+ return_recon_x=False,
90
+ # corrupt_sigma=dd_sampling_config.sigma,
91
+ preencode_condition=True, # We are using discrete model, so the input is already pre-encoded
92
+ num_steps=dd_sampling_config.num_steps,
93
+ )
94
+ log.debug(f"Current sample shape {samples_latent.shape} for video # {idx} ")
95
+ full_length_sample.append(samples_latent.detach())
96
+
97
+ # Turn off because we remove CP
98
+ # distributed.barrier()
99
+ del data_batch
100
+
101
+ torch.cuda.empty_cache()
102
+
103
+ gc.collect()
104
+ torch.cuda.empty_cache()
105
+
106
+ # Decode full-length samples and free GPU memory
107
+ full_length_sample_pixs = [model.decode(item).clamp(-1, 1).cpu() for item in full_length_sample]
108
+ torch.cuda.empty_cache()
109
+
110
+ # Blend pixel samples
111
+ if len(full_length_sample_pixs) > 1:
112
+ full_length_sample_pixel_blend = linear_blend_video_list(
113
+ full_length_sample_pixs, dd_sampling_config.overlap
114
+ )[:, :, :T]
115
+ else:
116
+ full_length_sample_pixel_blend = full_length_sample_pixs[0][:, :, :T]
117
+
118
+ # Batch size of full_length_sample_pixel_blend is always 1
119
+ out_videos_batch.append((1 + full_length_sample_pixel_blend[0].cpu()) / 2)
120
+ return out_videos_batch
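For intuition, a small self-contained check of how many chunks the loop above visits with the default DiffusionDecoderSamplingConfig; this mirrors the sliding-window arithmetic of split_with_overlap, and latent_T is an illustrative value.

latent_T = 33                                 # illustrative latent video length
chunk_len = (57 - 1) // 8 + 1                 # 8 latent frames per chunk (dd_train_num_video_frames=57)
overlap = 2
starts = list(range(0, latent_T - overlap, chunk_len - overlap))  # [0, 6, 12, 18, 24, 30] -> 6 chunks
# the last chunk (start=30) overruns latent_T and is reflect-padded inside split_with_overlap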
ar_diffusion_decoder_model.py ADDED
@@ -0,0 +1,231 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Callable, Dict, Optional, Tuple
18
+
19
+ import torch
20
+ from torch import Tensor
21
+
22
+ from .df_conditioner import BaseVideoCondition
23
+ from .df_df_functional_batch_ops import batch_mul
24
+ from .df_df_module_res_sampler import COMMON_SOLVER_OPTIONS
25
+ from .df_model_model_t2w import DiffusionT2WModel as VideoDiffusionModel
26
+ from .lazy_config_init import instantiate as lazy_instantiate
27
+
28
+
29
+ @dataclass
30
+ class VideoLatentDiffusionDecoderCondition(BaseVideoCondition):
31
+ # latent_condition will concat to the input of network, along channel dim;
32
+ # cfg will make latent_condition all zero padding.
33
+ latent_condition: Optional[torch.Tensor] = None
34
+ latent_condition_sigma: Optional[torch.Tensor] = None
35
+
36
+
37
+ class LatentDiffusionDecoderModel(VideoDiffusionModel):
38
+ def __init__(self, config):
39
+ super().__init__(config)
40
+ """
41
+ latent_corruptor: the corruption module used to corrupt the latents. It adds Gaussian noise to the latents.
42
+ pixel_corruptor: the corruption module used to corrupt the pixels. It applies a Gaussian blur kernel to the pixels in a temporally consistent way.
43
+ tokenizer_corruptor: the corruption module is used to simulate tokenizer reconstruction errors.
44
+
45
+ diffusion decoder noise augmentation pipeline for continuous token condition model:
46
+ condition: GT_video [T, H, W]
47
+ -> tokenizer_corruptor~(8x8x8) encode -> latent_corruptor -> tokenizer_corruptor~(8x8x8) decode
48
+ -> pixel corruptor
49
+ -> tokenizer~(1x8x8) encode -> condition [T, H/8, W/8]
50
+ GT: GT_video [T, H, W] -> tokenizer~(1x8x8) -> x_t [T, H/8, W/8].
51
+
52
+ diffusion decoder noise augmentation pipeline for discrete token condition model:
53
+ condition: GT_video [T, H, W]
54
+ -> pixel corruptor
55
+ -> discrete tokenizer encode -> condition [T, T/8, H/16, W/16]
56
+ GT: GT_video [T, H, W] -> tokenizer~(8x8x8) -> x_t [T, T/8, H/8, W/8].
57
+
58
+ """
59
+ self.latent_corruptor = lazy_instantiate(config.latent_corruptor)
60
+ self.pixel_corruptor = lazy_instantiate(config.pixel_corruptor)
61
+ self.tokenizer_corruptor = lazy_instantiate(config.tokenizer_corruptor)
62
+
63
+ if self.latent_corruptor:
64
+ self.latent_corruptor.to(**self.tensor_kwargs)
65
+ if self.pixel_corruptor:
66
+ self.pixel_corruptor.to(**self.tensor_kwargs)
67
+
68
+ if self.tokenizer_corruptor:
69
+ if hasattr(self.tokenizer_corruptor, "reset_dtype"):
70
+ self.tokenizer_corruptor.reset_dtype()
71
+ else:
72
+ assert self.pixel_corruptor is not None
73
+
74
+ self.diffusion_decoder_cond_sigma_low = config.diffusion_decoder_cond_sigma_low
75
+ self.diffusion_decoder_cond_sigma_high = config.diffusion_decoder_cond_sigma_high
76
+ self.diffusion_decoder_corrupt_prob = config.diffusion_decoder_corrupt_prob
77
+ if hasattr(config, "condition_on_tokenizer_corruptor_token"):
78
+ self.condition_on_tokenizer_corruptor_token = config.condition_on_tokenizer_corruptor_token
79
+ else:
80
+ self.condition_on_tokenizer_corruptor_token = False
81
+
82
+ def is_image_batch(self, data_batch: dict[str, Tensor]) -> bool:
83
+ """We hanlde two types of data_batch. One comes from a joint_dataloader where "dataset_name" can be used to differenciate image_batch and video_batch.
84
+ Another comes from a dataloader which we by default assume to contain video data for video model training.
85
+ """
86
+ is_image = self.input_image_key in data_batch
87
+ is_video = self.input_data_key in data_batch
88
+ assert (
89
+ is_image != is_video
90
+ ), "Only one of the input_image_key or input_data_key should be present in the data_batch."
91
+ return is_image
92
+
93
+ def get_x0_fn_from_batch(
94
+ self,
95
+ data_batch: Dict,
96
+ guidance: float = 1.5,
97
+ is_negative_prompt: bool = False,
98
+ apply_corruptor: bool = True,
99
+ corrupt_sigma: float = 1.5,
100
+ preencode_condition: bool = False,
101
+ ) -> Callable:
102
+ """
103
+ Generates a callable function `x0_fn` based on the provided data batch and guidance factor.
104
+
105
+ This function first processes the input data batch through a conditioning workflow (`conditioner`) to obtain conditioned and unconditioned states. It then defines a nested function `x0_fn` which applies a denoising operation on an input `noise_x` at a given noise level `sigma` using both the conditioned and unconditioned states.
106
+
107
+ Args:
108
+ - data_batch (Dict): A batch of data used for conditioning. The format and content of this dictionary should align with the expectations of the `self.conditioner`
109
+ - guidance (float, optional): A scalar value that modulates the influence of the conditioned state relative to the unconditioned state in the output. Defaults to 1.5.
110
+ - is_negative_prompt (bool): use negative prompt t5 in uncondition if true
111
+
112
+ Returns:
113
+ - Callable: A function `x0_fn(noise_x, sigma)` that takes two arguments, `noise_x` and `sigma`, and returns an x0 prediction
114
+
115
+ The returned function is suitable for use in scenarios where a denoised state is required based on both conditioned and unconditioned inputs, with an adjustable level of guidance influence.
116
+ """
117
+ input_key = self.input_data_key # by default it is video key
118
+ # Latent state
119
+ raw_state = data_batch[input_key]
120
+
121
+ if self.condition_on_tokenizer_corruptor_token:
122
+ if preencode_condition:
123
+ latent_condition = raw_state.to(torch.int32).contiguous()
124
+ corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition[:, 0])
125
+ else:
126
+ corrupted_pixel = (
127
+ self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state
128
+ )
129
+ latent_condition = self.tokenizer_corruptor.encode(corrupted_pixel)
130
+ latent_condition = latent_condition[1] if isinstance(latent_condition, tuple) else latent_condition
131
+ corrupted_pixel = self.tokenizer_corruptor.decode(latent_condition)
132
+ latent_condition = latent_condition.unsqueeze(1)
133
+ else:
134
+ if preencode_condition:
135
+ latent_condition = raw_state
136
+ corrupted_pixel = self.decode(latent_condition)
137
+ else:
138
+ corrupted_pixel = (
139
+ self.pixel_corruptor(raw_state) if apply_corruptor and self.pixel_corruptor else raw_state
140
+ )
141
+ latent_condition = self.encode(corrupted_pixel).contiguous()
142
+
143
+ sigma = (
144
+ torch.rand((latent_condition.shape[0],)).to(**self.tensor_kwargs) * corrupt_sigma
145
+ ) # small value to indicate clean video
146
+ _, _, _, c_noise_cond = self.scaling(sigma=sigma)
147
+ if corrupt_sigma != self.diffusion_decoder_cond_sigma_low and self.diffusion_decoder_corrupt_prob > 0:
148
+ noise = batch_mul(sigma, torch.randn_like(latent_condition))
149
+ latent_condition = latent_condition + noise
150
+ data_batch["latent_condition_sigma"] = batch_mul(torch.ones_like(latent_condition[:, 0:1, ::]), c_noise_cond)
151
+ data_batch["latent_condition"] = latent_condition
152
+ if is_negative_prompt:
153
+ condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
154
+ else:
155
+ condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
156
+
157
+ def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
158
+ cond_x0 = self.denoise(noise_x, sigma, condition).x0
159
+ uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
160
+ return cond_x0 + guidance * (cond_x0 - uncond_x0)
161
+
162
+ return x0_fn, corrupted_pixel
163
+
164
+ def generate_samples_from_batch(
165
+ self,
166
+ data_batch: Dict,
167
+ guidance: float = 1.5,
168
+ seed: int = 1,
169
+ state_shape: Tuple | None = None,
170
+ n_sample: int | None = None,
171
+ is_negative_prompt: bool = False,
172
+ num_steps: int = 35,
173
+ solver_option: COMMON_SOLVER_OPTIONS = "2ab",
174
+ sigma_min: float = 0.02,
175
+ apply_corruptor: bool = False,
176
+ return_recon_x: bool = False,
177
+ corrupt_sigma: float = 0.01,
178
+ preencode_condition: bool = False,
179
+ ) -> Tensor:
180
+ """
181
+ Generate samples from the batch. Based on given batch, it will automatically determine whether to generate image or video samples.
182
+ Args:
183
+ data_batch (dict): raw data batch drawn from the training data loader.
184
185
+ guidance (float): guidance weights
186
+ seed (int): random seed
187
+ state_shape (tuple): shape of the state, default to self.state_shape if not provided
188
+ n_sample (int): number of samples to generate
189
+ is_negative_prompt (bool): use negative prompt t5 in uncondition if true
190
+ num_steps (int): number of steps for the diffusion process
191
+ solver_option (str): differential equation solver option; defaults to "2ab" (multistep solver)
192
+ preencode_condition (bool): use a pre-computed condition if True, saving tokenizer inference time and memory
193
+ """
194
+ if not preencode_condition:
195
+ self._normalize_video_databatch_inplace(data_batch)
196
+ self._augment_image_dim_inplace(data_batch)
197
+ is_image_batch = False
198
+ if n_sample is None:
199
+ input_key = self.input_image_key if is_image_batch else self.input_data_key
200
+ n_sample = data_batch[input_key].shape[0]
201
+ if state_shape is None:
202
+ if is_image_batch:
203
+ state_shape = (self.state_shape[0], 1, *self.state_shape[2:]) # C,T,H,W
204
+
205
+ x0_fn, recon_x = self.get_x0_fn_from_batch(
206
+ data_batch,
207
+ guidance,
208
+ is_negative_prompt=is_negative_prompt,
209
+ apply_corruptor=apply_corruptor,
210
+ corrupt_sigma=corrupt_sigma,
211
+ preencode_condition=preencode_condition,
212
+ )
213
+ generator = torch.Generator(device=self.tensor_kwargs["device"])
214
+ generator.manual_seed(seed)
215
+ x_sigma_max = (
216
+ torch.randn(n_sample, *state_shape, **self.tensor_kwargs, generator=generator) * self.sde.sigma_max
217
+ )
218
+
219
+ samples = self.sampler(
220
+ x0_fn,
221
+ x_sigma_max,
222
+ num_steps=num_steps,
223
+ sigma_min=sigma_min,
224
+ sigma_max=self.sde.sigma_max,
225
+ solver_option=solver_option,
226
+ )
227
+
228
+ if return_recon_x:
229
+ return samples, recon_x
230
+ else:
231
+ return samples
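A numeric illustration (not taken from the source) of the guidance rule returned by x0_fn above:

import torch
cond_x0, uncond_x0, guidance = torch.tensor(1.0), torch.tensor(0.2), 1.8
x0 = cond_x0 + guidance * (cond_x0 - uncond_x0)   # 1.0 + 1.8 * 0.8 = 2.44; the conditioned prediction is pushed away from the unconditioned one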
ar_diffusion_decoder_network.py ADDED
@@ -0,0 +1,163 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+ from einops import rearrange
20
+ from torch import nn
21
+ from torchvision import transforms
22
+
23
+ from .df_module_blocks import PatchEmbed
24
+ from .df_network_general_dit import GeneralDIT
25
+
26
+
27
+ class DiffusionDecoderGeneralDIT(GeneralDIT):
28
+ def __init__(
29
+ self,
30
+ *args,
31
+ is_diffusion_decoder: bool = True,
32
+ diffusion_decoder_condition_on_sigma: bool = False,
33
+ diffusion_decoder_condition_on_token: bool = False,
34
+ diffusion_decoder_token_condition_voc_size: int = 64000,
35
+ diffusion_decoder_token_condition_dim: int = 32,
36
+ **kwargs,
37
+ ):
38
+ # diffusion decoder setting
39
+ self.is_diffusion_decoder = is_diffusion_decoder
40
+ self.diffusion_decoder_condition_on_sigma = diffusion_decoder_condition_on_sigma
41
+ self.diffusion_decoder_condition_on_token = diffusion_decoder_condition_on_token
42
+ self.diffusion_decoder_token_condition_voc_size = diffusion_decoder_token_condition_voc_size
43
+ self.diffusion_decoder_token_condition_dim = diffusion_decoder_token_condition_dim
44
+ super().__init__(*args, **kwargs)
45
+
46
+ def initialize_weights(self):
47
+ # Initialize transformer layers:
48
+ super().initialize_weights()
49
+ if self.diffusion_decoder_condition_on_token:
50
+ nn.init.constant_(self.token_embedder.weight, 0)
51
+
52
+ def build_patch_embed(self):
53
+ (
54
+ concat_padding_mask,
55
+ in_channels,
56
+ patch_spatial,
57
+ patch_temporal,
58
+ model_channels,
59
+ is_diffusion_decoder,
60
+ diffusion_decoder_token_condition_dim,
61
+ diffusion_decoder_condition_on_sigma,
62
+ ) = (
63
+ self.concat_padding_mask,
64
+ self.in_channels,
65
+ self.patch_spatial,
66
+ self.patch_temporal,
67
+ self.model_channels,
68
+ self.is_diffusion_decoder,
69
+ self.diffusion_decoder_token_condition_dim,
70
+ self.diffusion_decoder_condition_on_sigma,
71
+ )
72
+ in_channels = (
73
+ in_channels + in_channels
74
+ if (is_diffusion_decoder and not self.diffusion_decoder_condition_on_token)
75
+ else in_channels
76
+ )
77
+ in_channels = in_channels + 1 if diffusion_decoder_condition_on_sigma else in_channels
78
+ in_channels = (
79
+ in_channels + self.diffusion_decoder_token_condition_dim
80
+ if self.diffusion_decoder_condition_on_token
81
+ else in_channels
82
+ )
83
+ in_channels = in_channels + 1 if concat_padding_mask else in_channels
84
+
85
+ self.x_embedder = PatchEmbed(
86
+ spatial_patch_size=patch_spatial,
87
+ temporal_patch_size=patch_temporal,
88
+ in_channels=in_channels,
89
+ out_channels=model_channels,
90
+ bias=False,
91
+ )
92
+
93
+ if self.diffusion_decoder_condition_on_token:
94
+ self.token_embedder = nn.Embedding(
95
+ self.diffusion_decoder_token_condition_voc_size, self.diffusion_decoder_token_condition_dim
96
+ )
97
+
98
+ def prepare_embedded_sequence(
99
+ self,
100
+ x_B_C_T_H_W: torch.Tensor,
101
+ fps: Optional[torch.Tensor] = None,
102
+ padding_mask: Optional[torch.Tensor] = None,
103
+ latent_condition: Optional[torch.Tensor] = None,
104
+ latent_condition_sigma: Optional[torch.Tensor] = None,
105
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
106
+ """
107
+ Prepares an embedded sequence tensor by applying positional embeddings and handling padding masks.
108
+
109
+ Args:
110
+ x_B_C_T_H_W (torch.Tensor): video
111
+ fps (Optional[torch.Tensor]): Frames per second tensor to be used for positional embedding when required.
112
+ If None, a default value (`self.base_fps`) will be used.
113
+ padding_mask (Optional[torch.Tensor]): currently it is not used
114
+
115
+ Returns:
116
+ Tuple[torch.Tensor, Optional[torch.Tensor]]:
117
+ - A tensor of shape (B, T, H, W, D) with the embedded sequence.
118
+ - An optional positional embedding tensor, returned only if the positional embedding class
119
+ (`self.pos_emb_cls`) includes 'rope'. Otherwise, None.
120
+
121
+ Notes:
122
+ - If `self.concat_padding_mask` is True, a padding mask channel is concatenated to the input tensor.
123
+ - The method of applying positional embeddings depends on the value of `self.pos_emb_cls`.
124
+ - If 'rope' is in `self.pos_emb_cls` (case insensitive), the positional embeddings are generated using
125
+ the `self.pos_embedder` with the shape [T, H, W].
126
+ - If "fps_aware" is in `self.pos_emb_cls`, the positional embeddings are generated using the `self.pos_embedder`
127
+ with the fps tensor.
128
+ - Otherwise, the positional embeddings are generated without considering fps.
129
+ """
130
+ if self.diffusion_decoder_condition_on_token:
131
+ latent_condition = self.token_embedder(latent_condition)
132
+ B, _, T, H, W, _ = latent_condition.shape
133
+ latent_condition = rearrange(latent_condition, "B 1 T H W D -> (B T) (1 D) H W")
134
+
135
+ latent_condition = transforms.functional.resize(
136
+ latent_condition, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.BILINEAR
137
+ )
138
+ latent_condition = rearrange(latent_condition, "(B T) D H W -> B D T H W ", B=B, T=T)
139
+ x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition], dim=1)
140
+ if self.diffusion_decoder_condition_on_sigma:
141
+ x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, latent_condition_sigma], dim=1)
142
+ if self.concat_padding_mask:
143
+ padding_mask = transforms.functional.resize(
144
+ padding_mask, list(x_B_C_T_H_W.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
145
+ )
146
+ x_B_C_T_H_W = torch.cat(
147
+ [x_B_C_T_H_W, padding_mask.unsqueeze(1).repeat(1, 1, x_B_C_T_H_W.shape[2], 1, 1)], dim=1
148
+ )
149
+ x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
150
+
151
+ if self.extra_per_block_abs_pos_emb:
152
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps)
153
+ else:
154
+ extra_pos_emb = None
155
+
156
+ if "rope" in self.pos_emb_cls.lower():
157
+ return x_B_T_H_W_D, self.pos_embedder(x_B_T_H_W_D, fps=fps), extra_pos_emb
158
+
159
+ if "fps_aware" in self.pos_emb_cls:
160
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D, fps=fps) # [B, T, H, W, D]
161
+ else:
162
+ x_B_T_H_W_D = x_B_T_H_W_D + self.pos_embedder(x_B_T_H_W_D) # [B, T, H, W, D]
163
+ return x_B_T_H_W_D, None, extra_pos_emb
ar_diffusion_decoder_utils.py ADDED
@@ -0,0 +1,119 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn.functional as F
18
+
19
+
20
+ def split_with_overlap(video_BCTHW, num_video_frames, overlap=2, tobf16=True):
21
+ """
22
+ Splits the video tensor into chunks of num_video_frames with a specified overlap.
23
+
24
+ Args:
25
+ - video_BCTHW (torch.Tensor): Input tensor with shape [Batch, Channels, Time, Height, Width].
26
+ - num_video_frames (int): Number of frames per chunk.
27
+ - overlap (int): Number of overlapping frames between chunks.
+ - tobf16 (bool): If True, cast each chunk to bfloat16 before returning.
28
+
29
+ Returns:
30
+ - List of torch.Tensors: List of video chunks with overlap.
31
+ """
32
+ # Get the dimensions of the input tensor
33
+ B, C, T, H, W = video_BCTHW.shape
34
+
35
+ # Ensure overlap is less than num_video_frames
36
+ assert overlap < num_video_frames, "Overlap should be less than num_video_frames."
37
+
38
+ # List to store the chunks
39
+ chunks = []
40
+
41
+ # Step size for the sliding window
42
+ step = num_video_frames - overlap
43
+
44
+ # Loop through the time dimension (T) with the sliding window
45
+ for start in range(0, T - overlap, step):
46
+ end = start + num_video_frames
47
+ # Handle the case when the last chunk might go out of bounds
48
+ if end > T:
49
+ # Get the last available frame
50
+ num_padding_frames = end - T
51
+ chunk = F.pad(video_BCTHW[:, :, start:T, :, :], (0, 0, 0, 0, 0, num_padding_frames), mode="reflect")
52
+ else:
53
+ # Regular case: no padding needed
54
+ chunk = video_BCTHW[:, :, start:end, :, :]
55
+ if tobf16:
56
+ chunks.append(chunk.to(torch.bfloat16))
57
+ else:
58
+ chunks.append(chunk)
59
+ return chunks
60
+
61
+
62
+ def linear_blend_video_list(videos, D):
63
+ """
64
+ Linearly blends a list of videos along the time dimension with overlap length D.
65
+
66
+ Parameters:
67
+ - videos: list of video tensors, each of shape [b, c, t, h, w]
68
+ - D: int, overlap length
69
+
70
+ Returns:
71
+ - output_video: blended video tensor of shape [b, c, L, h, w]
72
+ """
73
+ assert len(videos) >= 2, "At least two videos are required."
74
+ b, c, t, h, w = videos[0].shape
75
+ N = len(videos)
76
+
77
+ # Ensure all videos have the same shape
78
+ for video in videos:
79
+ assert video.shape == (b, c, t, h, w), "All videos must have the same shape."
80
+
81
+ # Calculate total output length
82
+ L = N * t - D * (N - 1)
83
+ output_video = torch.zeros((b, c, L, h, w), device=videos[0].device)
84
+
85
+ output_index = 0 # Current index in the output video
86
+
87
+ for i in range(N):
88
+ if i == 0:
89
+ # Copy frames from the first video up to t - D
90
+ output_video[:, :, output_index : output_index + t - D, :, :] = videos[i][:, :, : t - D, :, :]
91
+ output_index += t - D
92
+ else:
93
+ # Blend overlapping frames between videos[i-1] and videos[i]
94
+ blend_weights = torch.linspace(0, 1, steps=D, device=videos[0].device)
95
+
96
+ for j in range(D):
97
+ w1 = 1 - blend_weights[j]
98
+ w2 = blend_weights[j]
99
+ frame_from_prev = videos[i - 1][:, :, t - D + j, :, :]
100
+ frame_from_curr = videos[i][:, :, j, :, :]
101
+ output_frame = w1 * frame_from_prev + w2 * frame_from_curr
102
+ output_video[:, :, output_index, :, :] = output_frame
103
+ output_index += 1
104
+
105
+ if i < N - 1:
106
+ # Copy non-overlapping frames from current video up to t - D
107
+ frames_to_copy = t - 2 * D
108
+ if frames_to_copy > 0:
109
+ output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][
110
+ :, :, D : t - D, :, :
111
+ ]
112
+ output_index += frames_to_copy
113
+ else:
114
+ # For the last video, copy frames from D to t
115
+ frames_to_copy = t - D
116
+ output_video[:, :, output_index : output_index + frames_to_copy, :, :] = videos[i][:, :, D:, :, :]
117
+ output_index += frames_to_copy
118
+
119
+ return output_video
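A tiny sanity check of the helpers above, confirming that the blended length matches L = N * t - D * (N - 1); the shapes are illustrative.

import torch
videos = [torch.zeros(1, 3, 8, 4, 4) for _ in range(3)]   # N=3 chunks of t=8 frames each
blended = linear_blend_video_list(videos, D=2)
assert blended.shape[2] == 3 * 8 - 2 * (3 - 1)             # 20 output frames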
ar_model.py ADDED
@@ -0,0 +1,596 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import json
17
+ import os
18
+ import time
19
+ from pathlib import Path
20
+ from typing import Any, Dict, List, Optional, Set
21
+
22
+ from .misc import misc, Color, timer
23
+ import torch
24
+ from safetensors.torch import load_file
25
+ from torch.nn.modules.module import _IncompatibleKeys
26
+
27
+ from .ar_config_base_model import ModelConfig
28
+ from .ar_config_base_tokenizer import TokenizerConfig
29
+ from .ar_module_mm_projector import MultimodalProjector
30
+ from .ar_network_transformer import Transformer
31
+ from .ar_network_vit import VisionTransformer, get_vit_config
32
+ from .ar_tokenizer_tokenizer import DiscreteMultimodalTokenizer, update_vocab_size
33
+ from .ar_utils_checkpoint import (
34
+ get_partial_state_dict,
35
+ process_state_dict,
36
+ substrings_to_ignore,
37
+ )
38
+ from .ar_utils_sampling import decode_n_tokens, decode_one_token, prefill
39
+ from .log import log
40
+
41
+
42
+ class AutoRegressiveModel(torch.nn.Module):
43
+ """
44
+ A class to build and use an AutoRegressiveModel for text generation.
45
+
46
+ Methods:
47
+ build: Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.
48
+ generate: Generate text sequences based on provided prompts using the language generation model.
49
+ """
50
+
51
+ def __init__(
52
+ self,
53
+ model: Transformer = None,
54
+ tokenizer: DiscreteMultimodalTokenizer = None,
55
+ config: ModelConfig = None,
56
+ vision_encoder: VisionTransformer = None,
57
+ mm_projector: MultimodalProjector = None,
58
+ ):
59
+ """
60
+ Initialize the AutoRegressiveModel instance with a model and tokenizer.
61
+
62
+ Args:
63
+ model (Transformer): The Transformer model for text generation.
64
+ tokenizer (Tokenizer): The tokenizer for encoding and decoding text.
65
+ config (Config): The configuration for the AutoRegressiveModel model.
66
+ vision_encoder (VisionTransformer): The vision encoder for the AutoRegressiveModel model.
67
+ mm_projector (MultimodalProjector): The multi-modal projector for the AutoRegressiveModel model.
68
+ """
69
+ super().__init__()
70
+ self.model = model
71
+ self.tokenizer = tokenizer
72
+ self.config = config
73
+
74
+ self.vision_encoder = vision_encoder
75
+ self.mm_projector = mm_projector
76
+
77
+ @property
78
+ def precision(self):
79
+ return self.model.precision
80
+
81
+ def get_num_params(
82
+ self,
83
+ ) -> int:
84
+ """
85
+ Return the number of parameters in the model.
86
+ """
87
+ n_params = sum(p.numel() for p in self.parameters())
88
+ return n_params
89
+
90
+ def load_ar_model(
91
+ self,
92
+ tokenizer_config,
93
+ ):
94
+ """
95
+ Load the AR model.
96
+ """
97
+ model_config = self.config
98
+ ckpt_path = model_config.ckpt_path
99
+ with timer(f"loading checkpoint from {ckpt_path}"):
100
+ if ckpt_path.endswith("safetensors"):
101
+ # Load with safetensors API
102
+ checkpoint = load_file(ckpt_path, device="cpu")
103
+ else:
104
+ # The pytorch version
105
+ checkpoint = torch.load(
106
+ ckpt_path,
107
+ map_location="cpu",
108
+ mmap=True, # load the checkpoint in memory-mapped mode
109
+ weights_only=True,
110
+ )
111
+ llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
112
+ orig_precision = torch.get_default_dtype()
113
+ precision = getattr(torch, model_config.precision)
114
+ torch.set_default_dtype(precision)
115
+ log.debug(f"Setting torch default dtype to {precision}")
116
+
117
+ model = Transformer(
118
+ params=model_config,
119
+ tokenizer_config=tokenizer_config,
120
+ )
121
+ log.debug(
122
+ f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size}"
123
+ )
124
+ vocab_size = update_vocab_size(
125
+ existing_vocab_size=0,
126
+ to_be_added_vocab_size=tokenizer_config.video_tokenizer.vocab_size,
127
+ training_type=tokenizer_config.training_type,
128
+ add_special_tokens=False,
129
+ )
130
+ log.debug(
131
+ f"tokenizer tokenizer_config.video_tokenizer.vocab_size {tokenizer_config.video_tokenizer.vocab_size} vocab_size {vocab_size}"
132
+ )
133
+ # Perform vocab expansion
134
+ if vocab_size > model.vocab_size:
135
+ log.debug(f"Expanding vocab size to {vocab_size}")
136
+ # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
137
+ expand_output_layer = not (tokenizer_config.training_type == "text_to_video")
138
+ model.expand_vocab(
139
+ vocab_size,
140
+ init_method="gaussian",
141
+ expand_output_layer=expand_output_layer,
142
+ )
143
+ # Remove the "model." prefix in the state_dict
144
+ llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
145
+ with timer("loading state_dict into model"):
146
+ missing_keys, _ = model.load_state_dict(llm_checkpoint, strict=True)
147
+ # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
148
+ missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
149
+ assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
150
+
151
+ self.model = model.to(precision).to("cuda")
152
+ torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
153
+
154
+ def load_tokenizer(self, tokenizer_config):
155
+ """
156
+ Load the tokenizer.
157
+ """
158
+ self.tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
159
+
160
+ @staticmethod
161
+ def build(
162
+ model_config: ModelConfig = ModelConfig(),
163
+ tokenizer_config: TokenizerConfig = None,
164
+ ) -> "AutoRegressiveModel":
165
+ """
166
+ Build a AutoRegressiveModel instance by initializing and loading a model checkpoint.
167
+
168
+ Args:
169
+ model_config (ModelConfig, optional): The model configuration for the AutoRegressiveModel instance. Defaults to ModelConfig().
170
+ tokenizer_config (TokenizerConfig, optional): The tokenizer configuration for the AutoRegressiveModel instance. Defaults to None.
171
172
+ Returns:
173
+ AutoRegressiveModel: An instance of the AutoRegressiveModel class with the loaded model and tokenizer.
174
+
175
+ Raises:
176
+ AssertionError: If there are no checkpoint files in the specified directory.
177
+
178
+ Note:
179
+ This method sets the device to CUDA and loads the pre-trained model and tokenizer.
180
+ """
181
+ # Initialize model configuration parameters
182
+ config_params = {}
183
+
184
+ # Load checkpoint and model parameters
185
+
186
+ if model_config.ckpt_path is None:
187
+ # If ckpt_path is not provided, we assume the model checkpoint is saved in the ckpt_dir
188
+ ckpt_dir = model_config.ckpt_dir
189
+
190
+ # We prioritize safetensors version over the pytorch version, since the former is
191
+ # much faster for checkpoint loading.
192
+ checkpoints = sorted(Path(ckpt_dir).glob("*.safetensors"))
193
+ if len(checkpoints) == 0:
194
+ checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
195
+
196
+ assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
197
+ assert (
198
+ len(checkpoints) == 1
199
+ ), f"multiple checkpoint files found in {ckpt_dir} (currently only one is supported)"
200
+ ckpt_path = str(checkpoints[0]) # Assuming single checkpoint for non-parallel case
201
+
202
+ if os.path.exists(Path(ckpt_dir) / "config.json"):
203
+ with open(Path(ckpt_dir) / "config.json", "r") as f:
204
+ config_params = json.loads(f.read())
205
+ else:
206
+ log.info(
207
+ f"No params.json found in the checkpoint directory ({ckpt_dir}). " f"Using default model config."
208
+ )
209
+
210
+ else:
211
+ # If ckpt_path is provided, we load the model from the specified path,
212
+ # and use the default model configuration
213
+ ckpt_path = model_config.ckpt_path
214
+
215
+ for key, value in config_params.items():
216
+ if hasattr(model_config, key):
217
+ # Override the default model configuration with the parameters from the checkpoint
218
+ setattr(model_config, key, value)
219
+
220
+ with timer(f"loading checkpoint from {ckpt_path}"):
221
+ if ckpt_path.endswith("safetensors"):
222
+ # Load with safetensors API
223
+ checkpoint = load_file(ckpt_path, device="cpu")
224
+ else:
225
+ # The pytorch version
226
+ checkpoint = torch.load(
227
+ ckpt_path,
228
+ map_location="cpu",
229
+ mmap=True, # load the checkpoint in memory-mapped mode
230
+ weights_only=True,
231
+ )
232
+ llm_checkpoint = checkpoint["model"] if "model" in checkpoint else checkpoint
233
+
234
+ if model_config.vision_encoder is not None:
235
+ # Take the LLM weights (starting with "model.") from the VLM checkpoint
236
+ llm_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="model.")
237
+ if model_config.vision_encoder is not None:
238
+ # For vanilla VLM ckpt before fine-tuning, `checkpoint['model']` only contains LLM weights, and `checkpoint['vision_encoder']`
239
+ # and `checkpoint['mm_projector']` hold the vision encoder and mm projector weights, respectively
240
+ # For fine-tuned VLM ckpt, `checkpoint['model']` contains all LLM, mm_projector and vision_encoder weights
241
+ if "vision_encoder" in checkpoint:
242
+ log.debug("Using pretrained vision_encoder")
243
+ vit_checkpoint = checkpoint["vision_encoder"]
244
+ else:
245
+ log.debug("Using fine-tuned vision_encoder")
246
+ vit_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="vision_encoder.")
247
+ vit_checkpoint = process_state_dict(vit_checkpoint, prefix_to_remove="vision_encoder.")
248
+ if "mm_projector" in checkpoint:
249
+ log.debug("Using pretrained mm_projector")
250
+ projector_checkpoint = checkpoint["mm_projector"]
251
+ else:
252
+ log.debug("Using fine-tuned mm_projector")
253
+ projector_checkpoint = get_partial_state_dict(llm_checkpoint, prefix="mm_projector.")
254
+ projector_checkpoint = process_state_dict(projector_checkpoint, prefix_to_remove="mm_projector.")
255
+ assert (
256
+ len(vit_checkpoint) > 0 and len(projector_checkpoint) > 0
257
+ ), "vit_checkpoint and projector_checkpoint cannot be empty. We do not support random initialization for vision_encoder and mm_projector."
258
+
259
+ tokenizer = DiscreteMultimodalTokenizer(tokenizer_config)
260
+ orig_precision = torch.get_default_dtype()
261
+ precision = getattr(torch, model_config.precision)
262
+ torch.set_default_dtype(precision)
263
+ log.debug(f"Setting torch default dtype to {precision}")
264
+
265
+ model = Transformer(
266
+ params=model_config,
267
+ tokenizer_config=tokenizer_config,
268
+ )
269
+ model_kwargs = {}
270
+
271
+ if model_config.vision_encoder is not None:
272
+ assert model_config.mm_projector is not None, "mm_projector must be provided if vision_encoder is provided."
273
+ vit_config = get_vit_config(model_config.vision_encoder)
274
+ vision_encoder = VisionTransformer.build(
275
+ vit_config,
276
+ )
277
+
278
+ mm_projector = MultimodalProjector(
279
+ mm_projector_type=model_config.mm_projector, in_dim=vit_config["dim"], out_dim=model_config["dim"]
280
+ )
281
+ model_kwargs.update({"vision_encoder": vision_encoder, "mm_projector": mm_projector})
282
+
283
+ # Perform vocab expansion
284
+ if tokenizer.vocab_size > model.vocab_size:
285
+ log.debug(f"Expanding vocab size to {tokenizer.vocab_size}")
286
+ # For text-to-video training, we only expand the embedding layer but not the output (unembedding) layer,
287
+ expand_output_layer = not (tokenizer.training_type == "text_to_video")
288
+ model.expand_vocab(
289
+ tokenizer.vocab_size,
290
+ init_method="gaussian",
291
+ expand_output_layer=expand_output_layer,
292
+ )
293
+
294
+ # Remove the "model." prefix in the state_dict
295
+ llm_checkpoint = process_state_dict(llm_checkpoint, prefix_to_remove="model.")
296
+ with timer("loading state_dict into model"):
297
+ missing_keys, unexpected_keys = model.load_state_dict(llm_checkpoint, strict=True)
298
+ # Remove keys with "_extra_state" suffix in missing_keys (defined by TransformerEngine for FP8 usage)
299
+ missing_keys = [k for k in missing_keys if not k.endswith("_extra_state")]
300
+ assert len(missing_keys) == 0, f"Missing keys: {missing_keys}"
301
+
302
+ if model_config.vision_encoder is not None:
303
+ vision_encoder.load_state_dict(vit_checkpoint)
304
+ mm_projector.load_state_dict(projector_checkpoint)
305
+ if model_config.vision_encoder_in_channels != 3:
306
+ vision_encoder.expand_in_channels(model_config.vision_encoder_in_channels)
307
+
308
+ model = model.to(precision) # ensure model parameters are in the correct precision
309
+ log.debug(f"Model config: {model_config}")
310
+
311
+ model_class = AutoRegressiveModel
312
+
313
+ torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
314
+
315
+ return model_class(model, tokenizer, model_config, **model_kwargs)
316
+
317
+ @torch.no_grad()
318
+ def generate(
319
+ self,
320
+ prompt_tokens: List[List[int]] | torch.Tensor,
321
+ max_gen_len: int,
322
+ temperature: float = 1.0,
323
+ top_k: Optional[int] = None,
324
+ top_p: Optional[float] = None,
325
+ num_gen_seq: int = 1,
326
+ logprobs: bool = False,
327
+ echo: bool = False,
328
+ seed: int = None,
329
+ context: Optional[torch.Tensor] = None,
330
+ context_mask: Optional[torch.Tensor] = None,
331
+ compile_sampling: bool = True,
332
+ compile_prefill: bool = False,
333
+ verbose: bool = True,
334
+ stop_tokens: Optional[Set[int]] = None,
335
+ images: Optional[torch.Tensor] = None,
336
+ ):
337
+ """
338
+ Autoregressive generation built upon the gpt-fast implementation (https://github.com/pytorch-labs/gpt-fast).
339
+
340
+ Args:
341
+ prompt_tokens (List[List[int]] | torch.Tensor): Prompt token IDs, given as a list of token-ID lists or a tensor of shape (batch_size, seq_len).
342
+ max_gen_len (int): Maximum length of the generated text sequence.
343
+ temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 1.0.
344
+ top_k (int, optional): Top-k value for top-k sampling. Defaults to None.
345
+ top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to None.
346
+ num_gen_seq (int, optional): Number of outputs to generate given the same prompt. Defaults to 1. When temperature == 0, num_gen_seq must be 1 because the generation is deterministic.
347
+ echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
348
+ stop_tokens (Set[int], optional): Set of token IDs that terminate generation. Defaults to the tokenizer's stop tokens.
349
+ seed (int, optional): Random seed for reproducibility. Defaults to None.
350
+ compile_sampling (bool, optional): Flag indicating whether to compile the decoding function. Defaults to True.
351
+ compile_prefill (bool, optional): Flag indicating whether to compile the prefill function. Defaults to False.
352
+ verbose (bool, optional): Flag indicating whether to log prefill and decode timing. Defaults to True.
353
+ """
354
+ assert top_k is None or top_p is None, f"Only one of top_k ({top_k}) or top_p ({top_p}) should be specified."
355
+ if temperature == 0:
356
+ top_p, top_k = None, None
357
+ log.debug("Setting top_p and top_k to None because temperature is 0")
358
+ if top_p is not None:
359
+ log.debug(f"Using top-p sampling with p={top_p} and temperature={temperature}")
360
+ elif top_k is not None:
361
+ log.debug(f"Using top-k sampling with k={top_k} and temperature={temperature}")
362
+ else:
363
+ log.debug("Not applying top-k or top-p sampling. Will use top-k sampling with k=None")
364
+
365
+ orig_precision = torch.get_default_dtype()
366
+ torch.set_default_dtype(self.precision)
367
+
368
+ torch._inductor.config.coordinate_descent_tuning = True
369
+ torch._inductor.config.triton.unique_kernel_names = True
370
+ # Experimental features to reduce compilation times, will be on by default in future
371
+ torch._inductor.config.fx_graph_cache = True
372
+
373
+ if seed is not None:
374
+ misc.set_random_seed(seed)
375
+
376
+ assert not logprobs, "logprobs are not supported for fast_generate yet"
377
+ # Check whether the prefill and decode_one_token functions have been compiled yet. If not, compile them based on the flags
378
+ if compile_sampling and not getattr(self, "inference_decode_compiled", False):
379
+ self.decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
380
+ self.inference_decode_compiled = True
381
+ log.info("Compiled AR sampling function. Note: the first run will be slower due to compilation")
382
+ if compile_prefill and not getattr(self, "inference_prefill_compiled", False):
383
+ self.prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
384
+ self.inference_prefill_compiled = True
385
+ log.info("Compiled prefill function. Note: the first run will be slower due to compilation")
386
+
387
+ if not hasattr(self, "decode_one_token"):
388
+ self.decode_one_token = decode_one_token
389
+ if not hasattr(self, "prefill"):
390
+ self.prefill = prefill
391
+
392
+ # Initialization and Assertions
393
+ if isinstance(self.model.params, list):
394
+ # During training, model.params is a list
395
+ log.debug(
396
+ f"Find self.model.params is a list, use self.config instead. Get max_batch_size={self.config.max_batch_size}, max_seq_len={self.config.max_seq_len}"
397
+ )
398
+ params = self.config
399
+ else:
400
+ params = self.model.params
401
+ if isinstance(prompt_tokens, list):
402
+ prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cuda")
403
+ if prompt_tokens.ndim == 1:
404
+ prompt_tokens = prompt_tokens.view(1, -1)
405
+ else:
406
+ assert prompt_tokens.ndim == 2, f"prompt_tokens has shape {prompt_tokens.shape}"
407
+ batch_size, prompt_len = prompt_tokens.shape
408
+ total_len = min(params.max_seq_len, max_gen_len + prompt_len)
409
+ if max_gen_len + prompt_len > params.max_seq_len:
410
+ log.warning(
411
+ f"max_gen_len + prompt_len={max_gen_len + prompt_len} exceeds max_seq_len={params.max_seq_len}, truncate max_gen_len to {params.max_seq_len - prompt_len}"
412
+ )
413
+ max_gen_len = params.max_seq_len - prompt_len
414
+
415
+ if context_mask is not None:
416
+ context_mask = context_mask.to(dtype=torch.bool)
417
+ if context_mask.ndim == 2:
418
+ assert (
419
+ context_mask.shape[0] == batch_size
420
+ ), f"batch_size mismatch: {context_mask.shape[0]} != {batch_size}"
421
+ # Unsqueeze it to make it of shape [batch_size, 1, 1, context_seq_len]
422
+ context_mask = context_mask.view(batch_size, 1, 1, -1)
423
+
424
+ if num_gen_seq > 1:
425
+ assert (
426
+ batch_size == 1
427
+ ), f"num_gen_seq > 1 is only supported for a single prompt, got {len(prompt_tokens)} prompts"
428
+ log.debug(f"Generating {num_gen_seq} sequences with the same prompt")
429
+ assert (
430
+ num_gen_seq <= params.max_batch_size
431
+ ), f"num_gen_seq={num_gen_seq} exceeds max_batch_size={params.max_batch_size}"
432
+ # repeat the prompt tokens for num_gen_seq times
433
+ prompt_tokens = prompt_tokens.repeat(num_gen_seq, 1)
434
+ assert prompt_tokens.shape == (
435
+ num_gen_seq,
436
+ prompt_len,
437
+ ), f"prompt_tokens must be of shape (num_gen_seq, seq_len), got {prompt_tokens.shape}"
438
+ batch_size = len(prompt_tokens)
439
+
440
+ # create an empty tensor of the expected final shape and fill in the current tokens
441
+ empty = torch.empty(batch_size, total_len, dtype=prompt_tokens.dtype, device=prompt_tokens.device)
442
+ empty[:, :prompt_len] = prompt_tokens
443
+ seq = empty
444
+ input_pos = torch.arange(0, prompt_len, device="cuda")
445
+
446
+ if verbose:
447
+ prefill_start = time.time()
448
+
449
+ if images is not None:
450
+ images = images.to(device=prompt_tokens.device, dtype=torch.bfloat16)
451
+ prompt_token_embeddings = self.embed_vision_language_features(prompt_tokens, images)
452
+ else:
453
+ prompt_token_embeddings = None
454
+
455
+ if context is not None:
456
+ context = context.to(device=prompt_tokens.device, dtype=self.precision)
457
+
458
+ # Prefill stage
459
+ next_token = self.prefill(
460
+ self.model,
461
+ input_pos=input_pos,
462
+ tokens=prompt_tokens if prompt_token_embeddings is None else None,
463
+ token_embeddings=prompt_token_embeddings,
464
+ temperature=temperature,
465
+ top_k=top_k,
466
+ top_p=top_p,
467
+ context=context,
468
+ context_mask=context_mask,
469
+ )
470
+ if verbose:
471
+ prefill_time = time.time() - prefill_start
472
+
473
+ seq[:, [prompt_len]] = next_token.to(dtype=seq.dtype)
474
+ input_pos = torch.tensor([prompt_len], dtype=torch.long, device="cuda")
475
+ stop_tokens = self.tokenizer.stop_tokens if stop_tokens is None else stop_tokens
476
+ stop_tokens = torch.tensor(list(stop_tokens), dtype=torch.long, device="cuda")
477
+
478
+ if verbose:
479
+ decode_start = time.time()
480
+ # Decode stage
481
+ generated_tokens = decode_n_tokens(
482
+ self.model,
483
+ next_token.view(batch_size, -1),
484
+ input_pos,
485
+ max_gen_len - 1,
486
+ temperature=temperature,
487
+ top_k=top_k,
488
+ top_p=top_p,
489
+ stop_tokens=stop_tokens,
490
+ decode_one_token_function=self.decode_one_token,
491
+ context=context,
492
+ context_mask=context_mask,
493
+ )
494
+ gen_len = len(generated_tokens)
495
+ if verbose:
496
+ decode_time = time.time() - decode_start
497
+ prefill_throughput = prompt_len / prefill_time
498
+ decode_throughput = gen_len / decode_time
499
+ log.debug(f"[Prefill] Time: {prefill_time:.2f}s; Throughput: {prefill_throughput:.2f} tokens/s")
500
+ log.debug(f"[Decode] Time: {decode_time:.2f}s; Throughput: {decode_throughput:.2f} tokens/s")
501
+
502
+ generated_tokens = torch.cat(generated_tokens, dim=1)
503
+
504
+ log.debug(f"generated_tokens: {generated_tokens.shape}")
505
+ seq = seq[:, : prompt_len + 1 + gen_len]
506
+ seq[:, prompt_len + 1 :] = generated_tokens
507
+ if not echo:
508
+ seq = seq[:, prompt_len:]
509
+
510
+ torch.set_default_dtype(orig_precision) # Reset the default dtype to the original value
511
+
512
+ return seq, None
513
+
514
+ def embed_vision_language_features(self, input_ids: torch.Tensor, images: torch.Tensor) -> torch.Tensor:
515
+ """
516
+ Embed vision and language features into a combined representation.
517
+
518
+ Args:
519
+ input_ids (torch.Tensor): Input token IDs.
520
+ images (torch.Tensor): Input images.
521
+
522
+ Returns:
523
+ torch.Tensor: Combined vision-language features.
524
+
525
+ Raises:
526
+ AssertionError: If vision encoder or mm projector is not initialized,
527
+ or if dimensions mismatch.
528
+ """
529
+ # Ensure vision encoder and mm projector are initialized
530
+ assert self.vision_encoder is not None
531
+ assert self.mm_projector is not None
532
+
533
+ # Get image token ID and validate it
534
+ image_token_id = self.vision_encoder.image_token_id
535
+ assert isinstance(image_token_id, int) and image_token_id >= 0, f"Invalid image_token_id: {image_token_id}"
536
+
537
+ # Identify text and image locations in the input
538
+ text_locations = input_ids != image_token_id
539
+ image_locations = input_ids == image_token_id
540
+
541
+ # Process text features
542
+ text_features = self.model.tok_embeddings(input_ids[text_locations])
543
+
544
+ # Process image features
545
+ images = images.to(device=text_features.device, dtype=text_features.dtype)
546
+ vit_outputs = self.vision_encoder(images)
547
+ image_features = self.mm_projector(vit_outputs)
548
+
549
+ # Get dimensions
550
+ B, seq_len = input_ids.shape
551
+ N_total = B * seq_len
552
+ N_txt, D_txt = text_features.shape
553
+ N_img, N_patch, D_img = image_features.shape
554
+
555
+ # Reshape image features
556
+ image_features = image_features.reshape(N_img * N_patch, D_img)
557
+
558
+ # Validate dimensions
559
+ assert D_txt == D_img, f"Text features dim {D_txt} should be equal to image features dim {D_img}"
560
+ assert (
561
+ N_total == N_txt + N_img * N_patch
562
+ ), f"seq_len {seq_len} should be equal to N_txt + N_img*N_Patch {(N_txt, N_img * N_patch, image_locations.sum().item())}"
563
+
564
+ # Combine text and image features
565
+ combined_features = torch.empty(
566
+ (B, seq_len, D_txt),
567
+ dtype=text_features.dtype,
568
+ device=text_features.device,
569
+ )
570
+ combined_features[text_locations, :] = text_features
571
+ combined_features[image_locations, :] = image_features
572
+
573
+ return combined_features
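The boolean-mask scatter above is the core of the text/image interleaving: text embeddings fill the positions whose token is not the image token, and ViT patch embeddings fill the remaining positions, in order. A minimal standalone sketch of the same pattern; the image token ID and all dimensions below are made-up illustrative values, not the model's configuration:

import torch

IMAGE_TOKEN_ID = 255                                                 # hypothetical placeholder ID
input_ids = torch.tensor([[3, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7]])   # (B=1, seq_len=4)
D = 8
text_locations = input_ids != IMAGE_TOKEN_ID
image_locations = input_ids == IMAGE_TOKEN_ID

text_features = torch.randn(int(text_locations.sum()), D)    # one row per text token
image_features = torch.randn(int(image_locations.sum()), D)  # one row per image patch token

combined = torch.empty(1, 4, D)
combined[text_locations, :] = text_features    # boolean indexing fills the selected rows in order
combined[image_locations, :] = image_features
print(combined.shape)                          # torch.Size([1, 4, 8])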
574
+
575
+ def state_dict(self, *args, **kwargs):
576
+ """
577
+ Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
578
+ """
579
+ state_dict = super().state_dict(*args, **kwargs)
580
+ return process_state_dict(state_dict)
581
+
582
+ def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
583
+ """
584
+ Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by
585
+ TransformerEngine for FP8).
586
+ """
587
+ state_dict = process_state_dict(state_dict)
588
+ missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
589
+ actual_missing_keys = []
590
+ for key in missing_keys:
591
+ if not any(substring in key for substring in substrings_to_ignore):
592
+ actual_missing_keys.append(key)
593
+ if strict:
594
+ if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
595
+ raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
596
+ return _IncompatibleKeys(actual_missing_keys, unexpected_keys)
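For reference, a minimal sketch of how the `generate` API above is typically driven once a model has been built with the loading code earlier in this file. The variable name `ar_model` and all prompt/sampling values are illustrative assumptions, not values shipped with the repo:

import torch

prompt = torch.tensor([[1, 5, 42, 7]], dtype=torch.long, device="cuda")  # (1, seq_len)
tokens, _ = ar_model.generate(
    prompt_tokens=prompt,
    max_gen_len=64,        # truncated automatically if prompt_len + max_gen_len > max_seq_len
    temperature=1.0,
    top_p=0.9,             # only one of top_k / top_p may be set
    num_gen_seq=2,         # repeat the single prompt and sample two continuations
    echo=False,            # return only the newly generated tokens
)
print(tokens.shape)        # (num_gen_seq, generated_len)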
ar_module_attention.py ADDED
@@ -0,0 +1,262 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import Optional, Union
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from .ar_module_embedding import RotaryPositionEmbedding
23
+ from .ar_module_normalization import create_norm
24
+
25
+
26
+ class Attention(nn.Module):
27
+ """
28
+ Attention layer with KV cache.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ n_heads: int,
34
+ n_kv_heads: Union[int, None],
35
+ dim: int,
36
+ max_batch_size: int,
37
+ max_seq_len: int,
38
+ context_dim: Optional[int] = None,
39
+ use_qk_normalization: bool = False,
40
+ norm_type: str = "rmsnorm",
41
+ norm_eps: float = 1e-5,
42
+ causal_mask: Optional[bool] = True,
43
+ head_dim: Optional[int] = None,
44
+ fuse_qkv: bool = False,
45
+ precision: str = "bfloat16",
46
+ attn_type: str = "self",
47
+ ):
48
+ """
49
+ Initializes the GQA module.
50
+
51
+ Args:
52
+ n_heads (int): The number of attention heads.
53
+ n_kv_heads (int, optional): The number of key-value attention heads. None defaults to n_heads.
54
+ dim (int): The dimensionality of the input and output.
55
+ max_batch_size (int): The maximum batch size.
56
+ max_seq_len (int): The maximum sequence length.
57
+ context_dim (int, optional): The dimensionality of the context for cross-attn. Defaults to None.
58
+ use_qk_normalization (bool, optional): Whether to apply QK normalization. Defaults to False.
59
+ norm_type (str, optional): The type of normalization layer. Defaults to "rmsnorm".
60
+ norm_eps (float, optional): The epsilon value for normalization. Defaults to 1e-5.
61
+ causal_mask (bool, optional): Whether to use causal mask. Defaults to True.
62
+ head_dim (int, optional): The dimensionality of each attention head. If None, defaults to dim // n_heads.
63
+ fuse_qkv (bool, optional): Whether to fuse QKV. Defaults to False.
64
+ precision (str, optional): The precision of the module. Defaults to "bfloat16".
65
+ attn_type (str, optional): The type of attention. Defaults to "self".
66
+ """
67
+ super().__init__()
68
+ assert attn_type in ["self", "cross", "full"], f"Invalid attention type: {attn_type}"
69
+ self.attn_type = attn_type
70
+ context_dim = dim if context_dim is None else context_dim
71
+
72
+ self.dim = dim
73
+ self.context_dim = context_dim
74
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
75
+ self.n_local_kv_heads = self.n_kv_heads
76
+ self.n_local_heads = n_heads
77
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
78
+ self.head_dim = dim // n_heads if head_dim is None else head_dim
79
+ self.causal_mask = causal_mask
80
+ self.fuse_qkv = fuse_qkv
81
+ self.precision = precision
82
+
83
+ if fuse_qkv:
84
+ assert context_dim == dim, f"Fuse QKV requires context_dim ({context_dim}) to be equal to dim ({dim})"
85
+ self.total_local_head_dim = (self.n_local_heads + 2 * self.n_local_kv_heads) * self.head_dim
86
+ self.wqkv = nn.Linear(dim, self.total_local_head_dim, bias=False)
87
+ # Register hook to load fused QKV weights
88
+ self._register_load_state_dict_pre_hook(self.load_hook)
89
+ else:
90
+ self.wq = nn.Linear(dim, self.n_local_heads * self.head_dim, bias=False)
91
+ self.wk = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
92
+ self.wv = nn.Linear(context_dim, self.n_local_kv_heads * self.head_dim, bias=False)
93
+ self.wo = nn.Linear(self.n_local_heads * self.head_dim, dim, bias=False)
94
+
95
+ self.max_batch_size = max_batch_size
96
+ self.max_seq_len = max_seq_len
97
+
98
+ if self.attn_type == "self":
99
+ # Cache for key and value tensors
100
+ self.init_kv_cache()
101
+
102
+ # QK normalization layers
103
+ if use_qk_normalization:
104
+ self.q_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
105
+ self.k_norm = create_norm(norm_type, dim=self.head_dim, eps=norm_eps)
106
+
107
+ self.use_qk_normalization = use_qk_normalization
108
+
109
+ self.to(dtype=getattr(torch, self.precision))
110
+
111
+ def load_hook(self, state_dict, prefix, *args):
112
+ if prefix + "wq.weight" in state_dict:
113
+ wq = state_dict.pop(prefix + "wq.weight")
114
+ wk = state_dict.pop(prefix + "wk.weight")
115
+ wv = state_dict.pop(prefix + "wv.weight")
116
+ state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
117
+
118
+ def init_kv_cache(self, dtype=None):
119
+ cache_shape = (self.max_batch_size, self.n_local_kv_heads, self.max_seq_len, self.head_dim)
120
+ if dtype is None:
121
+ dtype = getattr(torch, self.precision)
122
+ if self.attn_type == "self":
123
+ self.cache_k = torch.zeros(cache_shape, dtype=dtype).cuda()
124
+ self.cache_v = torch.zeros(cache_shape, dtype=dtype).cuda()
125
+
126
+ def forward(
127
+ self,
128
+ x: torch.Tensor,
129
+ rope: RotaryPositionEmbedding,
130
+ input_pos: torch.Tensor,
131
+ mask: Optional[torch.Tensor] = None,
132
+ context: Optional[torch.Tensor] = None,
133
+ ):
134
+ """
135
+ Forward pass of GQA.
136
+
137
+ Args:
138
+ x: The input tensor of shape (batch_size, seq_len, dim).
139
+ rope: The rotary positional embedding module.
140
+ input_pos: The starting position of the current sequence.
141
+ mask: The attention mask tensor.
142
+ context: The context tensor of shape (batch_size, context_len, dim).
143
+
144
+ Returns:
145
+ The output tensor after applying GQA.
146
+ """
147
+ bsz, seqlen, _ = x.shape
148
+
149
+ # Use one single module to handle both self-attn and cross-attn
150
+ context = x if context is None else context
151
+ context_len = seqlen if context is None else context.shape[1]
152
+
153
+ if self.fuse_qkv:
154
+ q_size = self.n_local_heads * self.head_dim
155
+ kv_size = self.n_local_kv_heads * self.head_dim
156
+ xq, xk, xv = self.wqkv(x).split([q_size, kv_size, kv_size], dim=-1)
157
+ else:
158
+ # Compute query, key, and value projections
159
+ xq, xk, xv = self.wq(x), self.wk(context), self.wv(context)
160
+
161
+ # Reshape projections
162
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
163
+ xk = xk.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
164
+ xv = xv.view(bsz, context_len, self.n_local_kv_heads, self.head_dim)
165
+
166
+ # QK normalization
167
+ if self.use_qk_normalization:
168
+ xq = self.q_norm(xq)
169
+ xk = self.k_norm(xk)
170
+
171
+ # Apply rotary positional embeddings to queries and keys
172
+ # Only apply RoPE to self-attention!
173
+ if self.attn_type in ["self", "full"]:
174
+ xq, xk = rope(xq, xk, input_pos, seqlen)
175
+
176
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
177
+ # xq: (bs, n_local_heads, seqlen, head_dim)
178
+ # xk: (bs, n_kv_heads, cache_len + context_len, head_dim)
179
+ # xv: (bs, n_kv_heads, cache_len + context_len, head_dim)
180
+ if self.attn_type == "self":
181
+ # Update cache with current key and value tensors
182
+ assert input_pos is not None
183
+ self.cache_k[:bsz, :, input_pos] = xk
184
+ self.cache_v[:bsz, :, input_pos] = xv
185
+ keys, values = (
186
+ self.cache_k[:bsz, :, :],
187
+ self.cache_v[:bsz, :, :],
188
+ )
189
+ else:
190
+ keys, values = xk, xv
191
+
192
+ # Repeat keys and values if necessary
193
+ keys = keys.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
194
+ values = values.repeat_interleave(self.n_rep, dim=1) # (bs, n_local_heads, cache_len + context_len, head_dim)
195
+
196
+ # For self-attention, `is_causal` should be set to False when KV cache is pre-computed and used,
197
+ # since the masking is handled outside this attention module.
198
+ # For cross-attention, it's always full-attn without causal mask
199
+ is_causal = False
200
+ output = scaled_dot_product_attention(
201
+ xq,
202
+ keys,
203
+ values,
204
+ head_dim=self.head_dim,
205
+ mask=mask,
206
+ is_causal=is_causal,
207
+ dropout_p=0.0,
208
+ )
209
+ output = output.view(bsz, seqlen, -1)
210
+ output = self.wo(output)
211
+ return output
212
+
213
+
214
+ def scaled_dot_product_attention(
215
+ q: torch.Tensor,
216
+ k: torch.Tensor,
217
+ v: torch.Tensor,
218
+ head_dim: int,
219
+ mask: Optional[torch.Tensor] = None,
220
+ is_causal: Optional[bool] = None,
221
+ dropout_p: float = 0.0,
222
+ ) -> torch.Tensor:
223
+ """
224
+ PyTorch's native implementation of Flash Attention 2.
225
+
226
+ If `is_causal` is given, then the causal attention mask is applied accordingly:
227
+ - If `is_causal` is True, the standard upper-left causal attention masking is applied.
228
+ - If `is_causal` is False, no attention mask is applied, unless an explicit mask tensor is
229
+ provided (i.e., `mask is not None`).
230
+
231
+ If `is_causal` is not given (i.e., `is_causal is None`), then the attention mask is applied
232
+ based on the provided mask tensor:
233
+ - If no explicit attention mask is given (i.e., `mask is None`), `is_causal` is set to True,
234
+ leading to the standard upper-left causal attention masking.
235
+ - If an attention mask is given (i.e., `mask is not None`), the provided mask is used,
236
+ and `is_causal` is set to False.
237
+
238
+ Args:
239
+ q (torch.Tensor): Query tensor
240
+ k (torch.Tensor): Key tensor
241
+ v (torch.Tensor): Value tensor
242
+ head_dim (int): Dimension of each attention head
243
+ mask (Optional[torch.Tensor], optional): Attention mask. Defaults to None.
244
+ is_causal (Optional[bool], optional): Whether to apply causal attention mask. Defaults to None.
245
+ dropout_p (float, optional): Dropout rate. Defaults to 0.0.
246
+
247
+ Returns:
248
+ torch.Tensor: Output tensor after applying scaled dot-product attention
249
+ """
250
+ scale = 1.0 / math.sqrt(head_dim)
251
+ if is_causal is None:
252
+ is_causal = mask is None
253
+ y = torch.nn.functional.scaled_dot_product_attention(
254
+ q,
255
+ k,
256
+ v,
257
+ attn_mask=mask,
258
+ dropout_p=dropout_p,
259
+ scale=scale,
260
+ is_causal=is_causal,
261
+ )
262
+ return y.transpose(1, 2).contiguous()
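To make the grouped-query head expansion and the `is_causal` defaulting above concrete, here is a small self-contained sketch; the head counts and sequence length are arbitrary examples, not the model's actual configuration:

import torch

B, n_heads, n_kv_heads, T, head_dim = 2, 8, 2, 16, 32
q = torch.randn(B, n_heads, T, head_dim)
k = torch.randn(B, n_kv_heads, T, head_dim)
v = torch.randn(B, n_kv_heads, T, head_dim)

# Grouped-query attention: each KV head serves n_heads // n_kv_heads query heads.
n_rep = n_heads // n_kv_heads
k = k.repeat_interleave(n_rep, dim=1)
v = v.repeat_interleave(n_rep, dim=1)

# With no explicit mask, the helper above falls back to causal masking (is_causal=True).
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)  # torch.Size([2, 8, 16, 32])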
ar_module_embedding.py ADDED
@@ -0,0 +1,491 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import math
17
+ from typing import List, Optional, Tuple
18
+
19
+ import numpy as np
20
+ import torch
21
+ from einops import rearrange, repeat
22
+
23
+
24
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
25
+ """
26
+ embed_dim: output dimension for each position
27
+ pos: a list of positions to be encoded: size (M,)
28
+ out: (M, D)
29
+ """
30
+ assert embed_dim % 2 == 0
31
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
32
+ omega /= embed_dim / 2.0
33
+ omega = 1.0 / 10000**omega # (D/2,)
34
+
35
+ pos = pos.reshape(-1) # (M,)
36
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
37
+
38
+ emb_sin = np.sin(out) # (M, D/2)
39
+ emb_cos = np.cos(out) # (M, D/2)
40
+
41
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
42
+ return emb
43
+
44
+
45
+ def _rotate_half_te(x: torch.Tensor) -> torch.Tensor:
46
+ """
47
+ change sign so the last dimension becomes [-odd, +even].
48
+ Adopted from TransformerEngine.
49
+ Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
50
+ """
51
+ x = x.view(x.shape[:-1] + torch.Size((2, x.shape[-1] // 2)))
52
+ x1, x2 = x.unbind(dim=-2)
53
+ return torch.cat((-x2, x1), dim=-1)
54
+
55
+
56
+ def _apply_rotary_pos_emb_te(
57
+ t: torch.Tensor,
58
+ cos_freqs: torch.Tensor,
59
+ sin_freqs: torch.Tensor,
60
+ ) -> torch.Tensor:
61
+ """
62
+ Apply rotary positional embedding tensor to the input tensor.
63
+ Adopted from TransformerEngine.
64
+ Source: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py
65
+
66
+ Parameters
67
+ ----------
68
+ t: torch.Tensor
69
+ Input tensor of shape `[b, s, h, d]`, on which
70
+ rotary positional embedding will be applied.
71
+ cos_freqs: torch.Tensor
72
+ Cosine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
73
+ sin_freqs: torch.Tensor
74
+ Sine component of rotary positional embedding tensor of shape `[s, 1, 1, d]` and dtype 'float',
75
+ """
76
+ rot_dim = cos_freqs.shape[-1]
77
+ # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
78
+ t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
79
+ # first part is cosine component
80
+ # second part is sine component, need to change signs with _rotate_half method
81
+ t = (t * cos_freqs) + (_rotate_half_te(t) * sin_freqs)
82
+ output = torch.cat((t, t_pass), dim=-1)
83
+ return output
84
+
85
+
86
+ class RotaryPositionEmbedding(torch.nn.Module):
87
+ """
88
+ Rotary Position Embedding module as described in the paper:
89
+ https://arxiv.org/abs/2104.09864
90
+
91
+ This module implements rotary positional embeddings, which are used to
92
+ enhance the performance of transformer models.
93
+
94
+ Args:
95
+ dim (int): Dimensionality of the input tensor.
96
+ max_position_embeddings (Optional[int]): Maximum position embeddings.
97
+ original_max_position_embeddings (Optional[int]): Original maximum position embeddings.
98
+ rope_theta (Optional[float]): Base for the frequency calculation.
99
+ apply_yarn (Optional[bool]): Whether to apply YaRN (Yet another Rotary).
100
+ scale (Optional[int]): Scaling factor for the frequency calculation.
101
+ extrapolation_factor (Optional[int]): Extrapolation factor for the frequency extension.
102
+ attn_factor (Optional[int]): Attention factor for the frequency calculation.
103
+ beta_fast (Optional[int]): Fast beta value for the YaRN frequency calculation.
104
+ beta_slow (Optional[int]): Slow beta value for the YaRN frequency calculation.
105
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
106
+ latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
107
+ original_latent_shape (Optional[List[int]]): Original shape of the latent tensor for video or image inputs.
108
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
109
+ """
110
+
111
+ def __init__(
112
+ self,
113
+ dim: int,
114
+ max_position_embeddings: Optional[int] = None,
115
+ original_max_position_embeddings: Optional[int] = None,
116
+ rope_theta: Optional[float] = 10000.0,
117
+ apply_yarn: Optional[bool] = False,
118
+ scale: Optional[int] = None,
119
+ extrapolation_factor: Optional[int] = 1,
120
+ attn_factor: Optional[int] = 1,
121
+ beta_fast: Optional[int] = 32,
122
+ beta_slow: Optional[int] = 1,
123
+ rope_dim: Optional[str] = "1D",
124
+ latent_shape: Optional[List[int]] = None,
125
+ original_latent_shape: Optional[List[int]] = None,
126
+ pad_to_multiple_of: Optional[int] = None,
127
+ ):
128
+ super().__init__()
129
+
130
+ self.dim = dim
131
+ self.max_position_embeddings = max_position_embeddings
132
+ self.original_max_position_embeddings = original_max_position_embeddings
133
+ self.rope_theta = rope_theta
134
+ self.apply_yarn = apply_yarn
135
+ self.scale = scale
136
+ self.extrapolation_factor = extrapolation_factor
137
+ self.attn_factor = attn_factor
138
+ self.beta_fast = beta_fast
139
+ self.beta_slow = beta_slow
140
+ self.mscale = 1.0
141
+ self.rope_dim = rope_dim
142
+ self.latent_shape = latent_shape
143
+ self.original_latent_shape = original_latent_shape
144
+ self.pad_to_multiple_of = pad_to_multiple_of
145
+ self.get_inv_freq(torch.cuda.current_device())
146
+
147
+ def get_mscale(self, scale: float = 1.0) -> float:
148
+ """Get the magnitude scaling factor for YaRN."""
149
+ if scale <= 1:
150
+ return 1.0
151
+ return 0.1 * math.log(scale) + 1.0
152
+
153
+ def forward(self, seq_len: Optional[int] = None) -> torch.Tensor:
154
+ """
155
+ Forward pass for the rotary position embedding.
156
+
157
+ Args:
158
+ seq_len (Optional[int]): Length of the sequence.
159
+
160
+ Returns:
161
+ torch.Tensor: The computed frequencies for positional embedding.
162
+ """
163
+
164
+ if self.apply_yarn and seq_len > self.max_seq_len_cached:
165
+ self.max_seq_len_cached = seq_len
166
+ self.freqs = self.compute_freqs()
167
+
168
+ return self.freqs
169
+
170
+ def compute_freqs(
171
+ self,
172
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
173
+ """Compute the spatial frequencies for the latent tensor."""
174
+ self.seq = torch.arange(self.max_seq_len_cached, dtype=torch.float).cuda()
175
+ if self.rope_dim == "1D":
176
+ emb = torch.einsum("i,j->ij", self.seq, self.inv_freq)
177
+
178
+ elif self.rope_dim == "2D":
179
+ H, W = self.latent_shape
180
+ half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
181
+ half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
182
+ emb = torch.cat(
183
+ [
184
+ repeat(half_emb_h, "h d -> h w d", w=W),
185
+ repeat(half_emb_w, "w d -> h w d", h=H),
186
+ ]
187
+ * 2,
188
+ dim=-1,
189
+ )
190
+ emb = rearrange(emb, "h w d -> (h w) 1 1 d").float()
191
+
192
+ elif self.rope_dim == "3D":
193
+ T, H, W = self.latent_shape
194
+ half_emb_t = torch.outer(self.seq[:T], self.temporal_inv_freq)
195
+ half_emb_h = torch.outer(self.seq[:H], self.spatial_inv_freq)
196
+ half_emb_w = torch.outer(self.seq[:W], self.spatial_inv_freq)
197
+ emb = torch.cat(
198
+ [
199
+ repeat(half_emb_t, "t d -> t h w d", h=H, w=W),
200
+ repeat(half_emb_h, "h d -> t h w d", t=T, w=W),
201
+ repeat(half_emb_w, "w d -> t h w d", t=T, h=H),
202
+ ]
203
+ * 2,
204
+ dim=-1,
205
+ )
206
+ emb = rearrange(emb, "t h w d -> (t h w) 1 1 d").float()
207
+ else:
208
+ raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
209
+ return emb
210
+
211
+ def get_scale_factors(self, inv_freq: torch.Tensor, original_seq_len: int) -> torch.Tensor:
212
+ """Get the scale factors for YaRN."""
213
+ # Calculate the high and low frequency cutoffs for YaRN. Note: `beta_fast` and `beta_slow` are called
214
+ # `high_freq_factor` and `low_freq_factor` in the Llama 3.1 RoPE scaling code.
215
+ high_freq_cutoff = 2 * math.pi * self.beta_fast / original_seq_len
216
+ low_freq_cutoff = 2 * math.pi * self.beta_slow / original_seq_len
217
+ # Obtain a smooth mask that has a value of 0 for low frequencies and 1 for high frequencies, with linear
218
+ # interpolation in between.
219
+ smooth_mask = torch.clamp((inv_freq - low_freq_cutoff) / (high_freq_cutoff - low_freq_cutoff), min=0, max=1)
220
+ # For low frequencies, we scale the frequency by 1/self.scale. For high frequencies, we keep the frequency.
221
+ scale_factors = (1 - smooth_mask) / self.scale + smooth_mask
222
+ return scale_factors
223
+
224
+ def get_inv_freq(self, device: torch.device) -> None:
225
+ """Get the inverse frequency."""
226
+ if self.rope_dim == "1D":
227
+ assert self.max_position_embeddings is not None, "Max position embeddings required."
228
+ inv_freq = 1.0 / (
229
+ self.rope_theta ** (torch.arange(0, self.dim, 2, dtype=torch.float32, device=device) / self.dim)
230
+ )
231
+ if self.apply_yarn:
232
+ assert self.original_max_position_embeddings is not None, "Original max position embeddings required."
233
+ assert self.beta_slow is not None, "Beta slow value required."
234
+ assert self.beta_fast is not None, "Beta fast value required."
235
+
236
+ scale_factors = self.get_scale_factors(inv_freq, self.original_max_position_embeddings)
237
+ # Apply the scaling factors to inv_freq.
238
+ inv_freq = inv_freq * scale_factors
239
+ # Set the magnitude scaling factor.
240
+ self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
241
+ self.max_seq_len_cached = self.max_position_embeddings
242
+ self.inv_freq = inv_freq
243
+
244
+ elif self.rope_dim == "2D":
245
+ assert self.latent_shape is not None, "Latent shape required."
246
+ dim_h = self.dim // 2
247
+ spatial_inv_freq = 1.0 / (
248
+ self.rope_theta ** torch.arange(0, dim_h, 2, dtype=torch.float32, device=device) / dim_h
249
+ )
250
+ if self.apply_yarn:
251
+ assert self.original_latent_shape is not None, "Original latent shape required."
252
+ assert self.beta_slow is not None, "Beta slow value required."
253
+ assert self.beta_fast is not None, "Beta fast value required."
254
+
255
+ scale_factors = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[0])
256
+ spatial_inv_freq = spatial_inv_freq * scale_factors
257
+ self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
258
+ self.spatial_inv_freq = spatial_inv_freq
259
+ self.max_seq_len_cached = max(self.latent_shape)
260
+
261
+ elif self.rope_dim == "3D":
262
+ assert self.latent_shape is not None, "Latent shape required."
263
+ dim_h = self.dim // 6 * 2
264
+ dim_t = self.dim - 2 * dim_h
265
+ self.dim_spatial_range = torch.arange(0, dim_h, 2)[: (dim_h // 2)].float().to(device) / dim_h
266
+ spatial_inv_freq = 1.0 / (self.rope_theta**self.dim_spatial_range)
267
+ self.dim_temporal_range = torch.arange(0, dim_t, 2)[: (dim_t // 2)].float().to(device) / dim_t
268
+ temporal_inv_freq = 1.0 / (self.rope_theta**self.dim_temporal_range)
269
+ if self.apply_yarn:
270
+ assert self.original_latent_shape is not None, "Original latent shape required."
271
+ assert self.beta_slow is not None, "Beta slow value required."
272
+ assert self.beta_fast is not None, "Beta fast value required."
273
+ scale_factors_spatial = self.get_scale_factors(spatial_inv_freq, self.original_latent_shape[1])
274
+ spatial_inv_freq = spatial_inv_freq * scale_factors_spatial
275
+ scale_factors_temporal = self.get_scale_factors(temporal_inv_freq, self.original_latent_shape[0])
276
+ temporal_inv_freq = temporal_inv_freq * scale_factors_temporal
277
+ self.mscale = float(self.get_mscale(self.scale) * self.attn_factor)
278
+ self.spatial_inv_freq = spatial_inv_freq
279
+ self.temporal_inv_freq = temporal_inv_freq
280
+ self.max_seq_len_cached = max(self.latent_shape)
281
+ else:
282
+ raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
283
+
284
+ self.freqs = self.compute_freqs()
285
+
286
+
287
+ class RotaryPositionEmbeddingPytorchV2(RotaryPositionEmbedding):
288
+ """
289
+ Rotary Position Embedding that works in the same way as the TransformerEngine RoPE
290
+ (https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py)
291
+
292
+ """
293
+
294
+ def __init__(
295
+ self,
296
+ seq_len: int,
297
+ training_type: str = None,
298
+ **kwargs,
299
+ ):
300
+ super().__init__(
301
+ **kwargs,
302
+ )
303
+ emb = self.create_rope_freqs(seq_len=seq_len, training_type=training_type)
304
+ emb = emb.transpose(0, 1).contiguous() # [seq, 1, 1, dim] -> [1, seq, 1, dim]
305
+ assert emb.shape[0] == 1 and emb.shape[2] == 1, f"emb shape: {emb.shape}"
306
+ # cos/sin first then dtype conversion for better precision
307
+ self.register_buffer("cos_cached", torch.cos(emb), persistent=False)
308
+ self.register_buffer("sin_cached", torch.sin(emb), persistent=False)
309
+
310
+ def create_rope_freqs(self, seq_len: int, training_type: str = None) -> torch.Tensor:
311
+ """
312
+ Create rotary position embedding frequencies.
313
+
314
+ Args:
315
+ seq_len (int): Sequence length of a sample.
316
+
317
+ Returns:
318
+ torch.Tensor: The computed positional embeddings.
319
+ """
320
+ if self.rope_dim == "1D":
321
+ freqs = super().forward(seq_len=seq_len)
322
+ emb = torch.cat((freqs, freqs), dim=-1)
323
+ emb = emb.reshape(emb.size(0), 1, 1, emb.size(1))
324
+
325
+ elif self.rope_dim in ["2D", "3D"]:
326
+ emb = super().forward(seq_len=seq_len)
327
+ if training_type == "text_to_video":
328
+ # since we added <bov> token at the beginning of the video for text2world, we also extend the position embedding by one token in the beginning
329
+ bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device)
330
+ emb = torch.cat((bov_pe, emb), dim=0)
331
+ else:
332
+ raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
333
+ if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
334
+ # Round up to the nearest multiple of pad_to_multiple_of
335
+ pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
336
+ emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device)), dim=0)
337
+
338
+ return emb
339
+
340
+ def forward(
341
+ self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
342
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
343
+ if q.dtype != self.cos_cached.dtype:
344
+ self.cos_cached = self.cos_cached.to(q.dtype)
345
+ self.sin_cached = self.sin_cached.to(q.dtype)
346
+
347
+ cos_emb = self.cos_cached
348
+ sin_emb = self.sin_cached
349
+ if input_pos is not None:
350
+ cos_emb = cos_emb[:, input_pos, :, :]
351
+ sin_emb = sin_emb[:, input_pos, :, :]
352
+ elif seq_len is not None:
353
+ cos_emb = cos_emb[:, :seq_len, :, :]
354
+ sin_emb = sin_emb[:, :seq_len, :, :]
355
+ q = _apply_rotary_pos_emb_te(q, cos_emb, sin_emb)
356
+ k = _apply_rotary_pos_emb_te(k, cos_emb, sin_emb)
357
+ return q, k
358
+
359
+
360
+ class RotaryPositionEmbeddingPytorchV1(RotaryPositionEmbedding):
361
+ """
362
+ Rotary Position Embedding that works in the same way as
363
+ mistral_inference (https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/rope.py)
364
+ or llama3 (https://github.com/meta-llama/llama3/blob/main/llama/model.py)
365
+
366
+ """
367
+
368
+ def __init__(
369
+ self,
370
+ **kwargs,
371
+ ):
372
+ super().__init__(
373
+ **kwargs,
374
+ )
375
+ if self.rope_dim == "1D":
376
+ emb = torch.stack((self.freqs, self.freqs), dim=-1).reshape(*self.freqs.shape[:-1], -1)
377
+ elif self.rope_dim in ["2D", "3D"]:
378
+ emb = rearrange(self.freqs, "s 1 1 d -> s d").float()
379
+ self.register_buffer("cos_cached", (emb.cos() * self.mscale)[None, :, None, :], persistent=False)
380
+ self.register_buffer("sin_cached", (emb.sin() * self.mscale)[None, :, None, :], persistent=False)
381
+
382
+ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
383
+ """Rotate half the hidden dimensions of the input tensor."""
384
+ x_reshaped = x.reshape(*x.shape[:-1], -1, 2)
385
+ x1 = x_reshaped[..., 0]
386
+ x2 = x_reshaped[..., 1]
387
+ output = torch.stack((-x2, x1), dim=-1).reshape(*x.shape)
388
+ return output
389
+
390
+ def forward(
391
+ self, q: torch.Tensor, k: torch.Tensor, input_pos: Optional[torch.Tensor] = None, seq_len: Optional[int] = None
392
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
393
+ """
394
+ Forward pass for the rotary position embedding.
395
+
396
+ Args:
397
+ q (torch.Tensor): Query tensor.
398
+ k (torch.Tensor): Key tensor.
399
+ input_pos (Optional[torch.Tensor]): Starting position for the sequence.
400
+ seq_len (Optional[int]): Length of the sequence.
401
+
402
+ Returns:
403
+ Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
404
+ """
405
+ if self.apply_yarn and seq_len > self.max_seq_len_cached:
406
+ freqs = super().forward(seq_len)
407
+ if self.rope_dim == "1D":
408
+ emb = torch.stack((freqs, freqs), dim=-1).reshape(*freqs.shape[:-1], -1)
409
+ elif self.rope_dim in ["2D", "3D"]:
410
+ emb = rearrange(freqs, "s 1 1 d -> s d").float()
411
+ else:
412
+ raise ValueError(f"Invalid RoPE dimensionality: {self.rope_dim}")
413
+ self.register_buffer(
414
+ "cos_cached", (emb.cos() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
415
+ )
416
+ self.register_buffer(
417
+ "sin_cached", (emb.sin() * self.mscale)[None, :, None, :].to(q.dtype), persistent=False
418
+ )
419
+
420
+ if input_pos is not None:
421
+ cos_cached = self.cos_cached[:, input_pos]
422
+ sin_cached = self.sin_cached[:, input_pos]
423
+ else:
424
+ assert (
425
+ self.cos_cached.shape[1] >= seq_len
426
+ ), f"Invalid sequence length; cos_cached.shape {self.cos_cached.shape}, seq_len {seq_len}."
427
+ cos_cached = self.cos_cached[:, :seq_len, ...]
428
+ sin_cached = self.sin_cached[:, :seq_len, ...]
429
+ xq = q * cos_cached + self.rotate_half(q) * sin_cached
430
+ xk = k * cos_cached + self.rotate_half(k) * sin_cached
431
+
432
+ return xq.type_as(q), xk.type_as(k)
433
+
434
+
435
+ class SinCosPosEmbAxisTE(torch.nn.Module):
436
+ def __init__(
437
+ self,
438
+ dim: int,
439
+ latent_shape: Optional[List[int]] = None,
440
+ pad_to_multiple_of: Optional[int] = None,
441
+ dtype: torch.dtype = torch.bfloat16,
442
+ **kwargs,
443
+ ):
444
+ """
445
+ Args:
446
+ dim (int): Dimensionality of the input tensor.
447
+ latent_shape (Optional[List[int]]): Shape of the latent tensor for video or image inputs.
448
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
449
+ dtype (torch.dtype): Data type of the position embedding tensor.
450
+ """
451
+ super().__init__()
452
+ dim_h = dim // 6 * 2
453
+ dim_w = dim_h
454
+ dim_t = dim - 2 * dim_h
455
+ assert dim == dim_h + dim_w + dim_t, f"bad dim: {dim} != {dim_h} + {dim_w} + {dim_t}"
456
+ self.latent_shape = latent_shape
457
+ T, H, W = latent_shape
458
+ emb_h = get_1d_sincos_pos_embed_from_grid(dim_h, pos=np.arange(H))
459
+ emb_w = get_1d_sincos_pos_embed_from_grid(dim_w, pos=np.arange(W))
460
+ emb_t = get_1d_sincos_pos_embed_from_grid(dim_t, pos=np.arange(T))
461
+
462
+ self.register_buffer("pos_emb_h", torch.from_numpy(emb_h).to(dtype=dtype, device="cuda"), persistent=False)
463
+ self.register_buffer("pos_emb_w", torch.from_numpy(emb_w).to(dtype=dtype, device="cuda"), persistent=False)
464
+ self.register_buffer("pos_emb_t", torch.from_numpy(emb_t).to(dtype=dtype, device="cuda"), persistent=False)
465
+ self.pad_to_multiple_of = pad_to_multiple_of
466
+
467
+ def forward(
468
+ self,
469
+ training_type: str = None,
470
+ ) -> torch.Tensor:
471
+ T, H, W = self.latent_shape
472
+ emb = torch.cat(
473
+ [
474
+ repeat(self.pos_emb_t, "t d-> t h w d", h=H, w=W),
475
+ repeat(self.pos_emb_h, "h d-> t h w d", t=T, w=W),
476
+ repeat(self.pos_emb_w, "w d-> t h w d", t=T, h=H),
477
+ ],
478
+ dim=-1,
479
+ )
480
+ # Flatten the T,H,W dimensions
481
+ emb = rearrange(emb, "t h w d -> (t h w) d")
482
+
483
+ if training_type == "text_to_video":
484
+ bov_pe = torch.zeros((1, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)
485
+ emb = torch.cat((bov_pe, emb), dim=0)
486
+ if self.pad_to_multiple_of is not None and emb.shape[0] % self.pad_to_multiple_of != 0:
487
+ pad_len = self.pad_to_multiple_of - emb.shape[0] % self.pad_to_multiple_of
488
+ emb = torch.cat((emb, torch.zeros((pad_len, *emb.shape[1:]), device=emb.device, dtype=emb.dtype)), dim=0)
489
+ seq_len, dim = emb.shape
490
+ emb = emb.reshape(1, seq_len, dim)
491
+ return emb
ar_module_mlp.py ADDED
@@ -0,0 +1,50 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+
21
+ class MLP(nn.Module):
22
+ def __init__(
23
+ self,
24
+ dim: int,
25
+ hidden_dim: int,
26
+ ):
27
+ """
28
+ Initializes the multilayer perceptron (MLP) module.
29
+
30
+ Args:
31
+ dim: The input and output dimensionality.
32
+ hidden_dim: The dimensionality of the hidden layer.
33
+ """
34
+ super().__init__()
35
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
36
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
37
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
38
+
39
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
40
+ """
41
+ Performs the forward pass of the MLP module.
42
+
43
+ Args:
44
+ x: The input tensor of shape (batch_size, dim).
45
+
46
+ Returns:
47
+ The output tensor of shape (batch_size, dim).
48
+ """
49
+ output = self.w2(F.silu(self.w1(x)) * self.w3(x))
50
+ return output
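The block above is the gated-SiLU ("SwiGLU") feed-forward used in Llama-style transformers, computing w2(silu(w1(x)) * w3(x)). A quick shape-level sketch with toy sizes:

import torch

mlp = MLP(dim=64, hidden_dim=256)   # MLP is the class defined above
x = torch.randn(2, 10, 64)          # the linear layers act on the last dimension
print(mlp(x).shape)                 # torch.Size([2, 10, 64])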
ar_module_mm_projector.py ADDED
@@ -0,0 +1,109 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Multimodal projector to connect vision encoder / tokenizer with the LLM."""
17
+
18
+ from typing import Any, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+
23
+
24
+ class DownSampleBlock(nn.Module):
25
+ """Downsample block."""
26
+
27
+ def __init__(self):
28
+ super().__init__()
29
+
30
+ def forward(self, x):
31
+ """
32
+ Performs the forward pass of the downsample block.
33
+
34
+ Args:
35
+ x (torch.Tensor): The input tensor from ViT's output of a sequence of embeddings.
36
+ Shape: (b, seq_len, c).
37
+
38
+ Returns:
39
+ torch.Tensor: The output tensor. Shape: (b, seq_len/4, c*4).
40
+ """
41
+ vit_embeds = x
42
+ # Get h and w as the sqrt of seq length. This assumes that the input is square-shaped.
43
+ h = w = int(vit_embeds.shape[1] ** 0.5)
44
+ b = vit_embeds.shape[0]
45
+ vit_embeds = vit_embeds.reshape(b, h, w, -1)
46
+ vit_embeds = self.flat_square(vit_embeds)
47
+ vit_embeds = vit_embeds.reshape(b, -1, vit_embeds.shape[-1])
48
+ return vit_embeds
49
+
50
+ def flat_square(self, x: torch.Tensor) -> torch.Tensor:
51
+ """
52
+ Performs spatial downsampling while increasing the number of channels.
53
+
54
+ Args:
55
+ x (torch.Tensor): The input tensor reshaped to a 2D grid.
56
+ Shape: (b, h, w, c)
57
+
58
+ Returns:
59
+ torch.Tensor: The output tensor after the spatial downsampling.
60
+ Shape: (b, h/2, w/2, c*4)
61
+ """
62
+ b, h, w, c = x.size()
63
+ # If w or h is odd, pad a column or a row of zeros.
64
+ if h % 2 == 1:
65
+ x = torch.concat([x, torch.zeros((b, 1, w, c), dtype=x.dtype).to(x.device)], dim=1).contiguous()
66
+ b, h, w, c = x.size()
67
+ if w % 2 == 1:
68
+ x = torch.concat([x, torch.zeros((b, h, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous()
69
+ b, h, w, c = x.size()
70
+ # 2x spatial downsampling, 4x channel increasing.
71
+ x = x.view(b, h, int(w / 2), int(c * 2))
72
+ x = x.permute(0, 2, 1, 3).contiguous()
73
+ x = x.view(b, int(h / 2), int(w / 2), int(c * 4))
74
+ x = x.permute(0, 2, 1, 3).contiguous()
75
+ return x
76
+
77
+
78
+ class MultimodalProjector(nn.Module):
79
+ """Multimodal projector."""
80
+
81
+ def __init__(
82
+ self,
83
+ mm_projector_type: str,
84
+ in_dim: int,
85
+ out_dim: Optional[int] = None,
86
+ **kwargs: Any,
87
+ ):
88
+ super().__init__()
89
+ if out_dim is None:
90
+ out_dim = in_dim
91
+ if mm_projector_type == "identity":
92
+ self.projector = nn.Identity()
93
+ elif mm_projector_type == "linear":
94
+ self.projector = nn.Linear(in_dim, out_dim)
95
+ elif mm_projector_type == "mlp":
96
+ self.projector = nn.Sequential(nn.Linear(in_dim, out_dim), nn.GELU(), nn.Linear(out_dim, out_dim))
97
+ elif mm_projector_type == "mlp_downsample":
98
+ self.projector = nn.Sequential(
99
+ DownSampleBlock(),
100
+ nn.LayerNorm(in_dim * 4),
101
+ nn.Linear(in_dim * 4, out_dim),
102
+ nn.GELU(),
103
+ nn.Linear(out_dim, out_dim),
104
+ )
105
+ else:
106
+ raise ValueError(f"Unknown projector type: {mm_projector_type}")
107
+
108
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
109
+ return self.projector(x)
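A shape-level sketch of the `mlp_downsample` path above, which merges each 2x2 patch neighborhood so the token count drops by 4x while the channel width is projected to the LLM dimension; the sizes below are illustrative, not the shipped configuration:

import torch

vit_dim, llm_dim, num_patches = 1024, 4096, 256    # num_patches must be a perfect square
projector = MultimodalProjector("mlp_downsample", in_dim=vit_dim, out_dim=llm_dim)
vit_embeds = torch.randn(2, num_patches, vit_dim)  # ViT output: (b, seq_len, c)
print(projector(vit_embeds).shape)                 # torch.Size([2, 64, 4096])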
ar_module_normalization.py ADDED
@@ -0,0 +1,88 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+
19
+
20
+ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
21
+ """
22
+ Creates the specified normalization layer based on the norm_type.
23
+ Adopted from TorchTriton: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
24
+
25
+ Args:
26
+ norm_type (str): The type of normalization layer to create.
27
+ Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
28
+ dim (int): The dimension of the normalization layer.
29
+ eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
30
+
31
+ Returns:
32
+ The created normalization layer.
33
+
34
+ Raises:
35
+ NotImplementedError: If an unknown norm_type is provided.
36
+ """
37
+ norm_type = norm_type.lower() # Normalize to lowercase
38
+
39
+ if norm_type == "layernorm":
40
+ return nn.LayerNorm(dim, eps=eps, bias=False)
41
+ elif norm_type == "np_layernorm":
42
+ return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
43
+ elif norm_type == "rmsnorm":
44
+ return RMSNorm(dim, eps=eps, compile=False)
45
+ elif norm_type == "compiled_rmsnorm":
46
+ return RMSNorm(dim, eps=eps, compile=True)
47
+ elif norm_type == "fused_rmsnorm":
48
+ raise NotImplementedError("Fused RMSNorm is not supported yet.")
49
+ else:
50
+ raise NotImplementedError(f"Unknown norm_type: '{norm_type}'")
51
+
52
+
53
+ class RMSNorm(nn.Module):
54
+ """
55
+ Initialize the RMSNorm normalization layer.
56
+ Reference implementation: https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/norms.py
57
+
58
+ Args:
59
+ dim (int): The dimension of the input tensor.
60
+ eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
61
+ compile (bool, optional): Whether to compile the forward function. Default is False.
62
+
63
+ Attributes:
64
+ eps (float): A small value added to the denominator for numerical stability.
65
+ weight (nn.Parameter): Learnable scaling parameter.
66
+
67
+ """
68
+
69
+ def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
70
+ super().__init__()
71
+ self.eps = eps
72
+ self.weight = nn.Parameter(torch.ones(dim))
73
+ self.rmsnorm_fn = torch.compile(self.compute_rmsnorm, fullgraph=True) if compile else self.compute_rmsnorm
74
+
75
+ @staticmethod
76
+ def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
77
+ def _norm(x, eps):
78
+ # Computes the root-mean-square norm of the input tensor.
79
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
80
+
81
+ output = _norm(x.float(), eps).type_as(x)
82
+ return output * weight
83
+
84
+ def forward(self, x: torch.Tensor):
85
+ return self.rmsnorm_fn(x, self.weight, self.eps)
86
+
87
+ def reset_parameters(self):
88
+ torch.nn.init.ones_(self.weight)
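RMSNorm above rescales by the root mean square of the feature vector (no mean subtraction, unlike LayerNorm) and then applies the learned per-channel weight. A minimal numeric check with toy shapes:

import torch

norm = RMSNorm(dim=8)   # weight is initialized to ones
x = torch.randn(2, 4, 8)
expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + norm.eps)
assert torch.allclose(norm(x), expected, atol=1e-6)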
ar_network_transformer.py ADDED
@@ -0,0 +1,461 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.nn.modules.module import _IncompatibleKeys
21
+
22
+ from .ar_module_attention import Attention
23
+ from .ar_module_embedding import (
24
+ RotaryPositionEmbeddingPytorchV1,
25
+ RotaryPositionEmbeddingPytorchV2,
26
+ SinCosPosEmbAxisTE,
27
+ )
28
+ from .ar_module_mlp import MLP
29
+ from .ar_module_normalization import create_norm
30
+ from .ar_utils_checkpoint import process_state_dict, substrings_to_ignore
31
+ from .ar_utils_misc import maybe_convert_to_namespace
32
+ from .log import log
33
+
34
+
35
+ class TransformerBlock(nn.Module):
36
+ """
37
+ A single transformer block consisting of an attention layer and a feed-forward layer.
38
+ """
39
+
40
+ def __init__(self, layer_id: int, args=None):
41
+ """
42
+ Initializes the TransformerBlock module.
43
+
44
+ Args:
45
+ layer_id: The ID of the transformer block.
46
+ args: The model arguments containing hyperparameters.
47
+ """
48
+ super().__init__()
49
+ args = maybe_convert_to_namespace(args)
50
+ attention_args = {
51
+ "n_heads": args["n_heads"],
52
+ "n_kv_heads": args["n_kv_heads"],
53
+ "dim": args["dim"],
54
+ "context_dim": None,
55
+ "max_batch_size": args["max_batch_size"],
56
+ "max_seq_len": args["max_seq_len"],
57
+ "use_qk_normalization": args["use_qk_normalization"],
58
+ "causal_mask": args["causal_mask"],
59
+ "head_dim": args["head_dim"],
60
+ "fuse_qkv": getattr(args, "fuse_qkv", False),
61
+ "precision": getattr(args, "precision", "bfloat16"),
62
+ "attn_type": getattr(args, "attn_type", "self"),
63
+ }
64
+ self.attention = Attention(**attention_args)
65
+
66
+ self.has_cross_attention = False
67
+ self.cross_attention, self.cross_attention_norm = None, None
68
+
69
+ if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0:
70
+ self.has_cross_attention = True
71
+ cross_attention_args = attention_args.copy()
72
+ cross_attention_args.update({"context_dim": args["context_dim"], "fuse_qkv": False, "attn_type": "cross"})
73
+ self.cross_attention = Attention(**cross_attention_args)
74
+ self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
75
+
76
+ self.feed_forward = MLP(
77
+ dim=args["dim"],
78
+ hidden_dim=args["ffn_hidden_size"],
79
+ )
80
+ self.layer_id = layer_id
81
+ self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
82
+ self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
83
+
84
+ def forward(
85
+ self,
86
+ x: torch.Tensor,
87
+ rope: RotaryPositionEmbeddingPytorchV2,
88
+ input_pos: Optional[torch.Tensor] = None,
89
+ mask: Optional[torch.Tensor] = None,
90
+ context: Optional[torch.Tensor] = None,
91
+ context_mask: Optional[torch.Tensor] = None,
92
+ ) -> torch.Tensor:
93
+ """
94
+ Performs the forward pass of the TransformerBlock module.
95
+
96
+ Args:
97
+ x: The input tensor.
98
+ input_pos: The position of the current sequence. Used in inference (with KV cache) only.
99
+ rope: The rotary position embedding module used to encode token positions in attention.
100
+ mask: The attention mask tensor.
101
+ context (Optional[torch.Tensor]): The context tensor added via cross-attn.
102
+ context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
103
+
104
+ Returns:
105
+ The output tensor after applying the transformer block.
106
+ """
107
+ # Apply attention and residual connection
108
+ h = x + self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask)
109
+
110
+ # If insert cross-attention, apply CA and residual connection
111
+ if self.has_cross_attention:
112
+ h = h + self.cross_attention(
113
+ self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context
114
+ )
115
+
116
+ # Apply feed-forward network and residual connection
117
+ out = h + self.feed_forward(self.ffn_norm(h))
118
+ return out
119
+
120
+ def init_weights(self):
121
+ """
122
+ Initializes the weights of the transformer block.
123
+ """
124
+ for norm in (self.attention_norm, self.ffn_norm):
125
+ norm.reset_parameters()
126
+ self.attention.init_weights(self.weight_init_std)
127
+ self.feed_forward.init_weights(self.weight_init_std)
128
+
129
+ if self.has_cross_attention:
130
+ self.cross_attention_norm.reset_parameters()
131
+ self.cross_attention.init_weights(self.weight_init_std)
132
+ # zero-init the final output layer of cross-attention
133
+ # nn.init.zeros_(self.cross_attention.wo.weight)
134
+
135
+
136
+ class Transformer(nn.Module):
137
+ """
138
+ The Transformer network consisting of transformer blocks.
139
+ """
140
+
141
+ def __init__(self, params, tokenizer_config=None, init_weights: bool = True):
142
+ """
143
+ Initializes the Transformer module.
144
+
145
+ Args:
146
+ params: The model parameters containing hyperparameters.
147
+ tokenizer_config: The model tokenizer configuration.
148
+ init_weights (bool): Whether to initialize the weights of the transformer following
149
+ TorchTitan's Llama3 initialization scheme.
150
+ """
151
+ super().__init__()
152
+ # Check if self.params is an OmegaConf DictConfig instance
153
+ self.params = maybe_convert_to_namespace(params)
154
+ self.vocab_size = params["vocab_size"]
155
+ self.n_layers = params["n_layers"]
156
+ self.precision = getattr(torch, params["precision"])
157
+ self.tokenizer_config = tokenizer_config
158
+ self.num_video_frames = params["num_video_frames"]
159
+
160
+ # Token embeddings
161
+ self.tok_embeddings = self._create_token_embeddings()
162
+ self.rope_config = self._create_rope_config()
163
+
164
+ # Transformer layers
165
+ self.layers = nn.ModuleList(
166
+ [TransformerBlock(layer_id, self.params).to(self.precision) for layer_id in range(self.n_layers)]
167
+ )
168
+
169
+ # Final layer normalization
170
+ self.norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"]).to(
171
+ self.precision
172
+ )
173
+ if self.params["pytorch_rope_version"] == "v1":
174
+ self.rope = RotaryPositionEmbeddingPytorchV1(**self.rope_config)
175
+ elif self.params["pytorch_rope_version"] == "v2":
176
+ # Rotary position embeddings
177
+ training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None
178
+ self.rope = RotaryPositionEmbeddingPytorchV2(
179
+ seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config
180
+ )
181
+ else:
182
+ raise ValueError(f"Invalid PyTorch RoPE version: {self.params['pytorch_rope_version']}")
183
+ # Causal mask
184
+ self.causal_mask = torch.tril(
185
+ torch.ones(self.params["max_seq_len"], self.params["max_seq_len"], dtype=torch.bool)
186
+ ).cuda()
187
+
188
+ # Output projection
189
+ self.output = self._create_output_projection()
190
+
191
+ # Freeze network parameters for finetuning w/ cross-attention
192
+ self.has_cross_attention = getattr(params, "insert_cross_attn", False)
193
+
194
+ # Absolute position embeddings
195
+ if self.params["apply_abs_pos_emb"]:
196
+ self.pos_emb_config = self._create_abs_pos_emb_config()
197
+ self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb()
198
+
199
+ def _create_rope_config(self) -> Dict:
200
+ shape_map = {
201
+ "3D": self.params["video_latent_shape"],
202
+ "1D": None,
203
+ }
204
+ latent_shape = shape_map.get(self.params["rope_dim"], None)
205
+ head_dim = self.params["head_dim"]
206
+ if head_dim is None:
207
+ head_dim = self.params["dim"] // self.params["n_heads"]
208
+ return {
209
+ "dim": head_dim,
210
+ "max_position_embeddings": self.params["max_seq_len"],
211
+ "original_max_position_embeddings": self.params["original_seq_len"],
212
+ "rope_theta": self.params["rope_theta"],
213
+ "apply_yarn": self.params["apply_yarn"],
214
+ "scale": self.params["yarn_scale"],
215
+ "beta_fast": self.params["yarn_beta_fast"],
216
+ "beta_slow": self.params["yarn_beta_slow"],
217
+ "rope_dim": self.params["rope_dim"],
218
+ "latent_shape": latent_shape,
219
+ "original_latent_shape": self.params["original_latent_shape"],
220
+ "pad_to_multiple_of": self.params["pad_to_multiple_of"],
221
+ }
222
+
223
+ def _create_abs_pos_emb_config(self):
224
+ shape_map = {
225
+ "3D": self.params["video_latent_shape"],
226
+ "1D": None,
227
+ }
228
+ latent_shape = shape_map.get(self.params["rope_dim"], None)
229
+ return {
230
+ "dim": self.params["dim"],
231
+ "latent_shape": latent_shape,
232
+ "pad_to_multiple_of": self.params["pad_to_multiple_of"],
233
+ }
234
+
235
+ def _create_token_embeddings(self, vocab_size: int = None):
236
+ """
237
+ Create token embeddings.
238
+
239
+ Returns:
240
+ nn.Module: Token embeddings module.
241
+ """
242
+ if vocab_size is None:
243
+ vocab_size = self.params["vocab_size"]
244
+ return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision)
245
+
246
+ def _create_output_projection(self, vocab_size: int = None):
247
+ """
248
+ Create the output projection layer.
249
+
250
+ Args:
251
+ vocab_size (int): Vocabulary size (to override the default vocab size).
252
+ Returns:
253
+ nn.Linear: Output projection layer.
254
+ """
255
+ if vocab_size is None:
256
+ vocab_size = self.params["vocab_size"]
257
+ return nn.Linear(self.params["dim"], vocab_size, bias=False).to(self.precision)
258
+
259
+ def _initialize_abs_pos_emb(self):
260
+ pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config)
261
+ training_type = self.tokenizer_config.training_type if self.tokenizer_config is not None else None
262
+ abs_pos_emb = pos_emb.forward(training_type=training_type)
263
+ return pos_emb, abs_pos_emb
264
+
265
+ def forward(
266
+ self,
267
+ tokens: Optional[torch.Tensor] = None,
268
+ input_pos: Optional[torch.Tensor] = None,
269
+ token_embeddings: Optional[torch.Tensor] = None,
270
+ context: Optional[torch.Tensor] = None,
271
+ context_mask: Optional[torch.Tensor] = None,
272
+ ) -> torch.Tensor:
273
+ """
274
+ Performs the forward pass of the Transformer module.
275
+
276
+ Args:
277
+ tokens (torch.Tensor, optional): The input tensor of token IDs.
278
+ input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache.
279
+ token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None.
280
+ context (Optional[torch.Tensor]): The context tensor added via cross-attn.
281
+ context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
282
+ Returns:
283
+ The output tensor after applying the transformer layers.
284
+ """
285
+ # Token embeddings
286
+ assert (
287
+ tokens is None or token_embeddings is None
288
+ ), "Either tokens or token_embeddings should be provided, not both."
289
+
290
+ if token_embeddings is None:
291
+ seq_len = tokens.shape[1]
292
+ h = self.tok_embeddings(tokens)
293
+ else:
294
+ seq_len = token_embeddings.shape[1]
295
+ h = token_embeddings
296
+
297
+ # Create attention mask
298
+ mask = self._create_attention_mask(input_pos=input_pos)
299
+
300
+ # Prepare layer arguments
301
+ layer_kwargs = self._prepare_layer_kwargs(
302
+ input_pos=input_pos,
303
+ mask=mask,
304
+ context=context,
305
+ context_mask=context_mask,
306
+ )
307
+
308
+ # Apply transformer layers
309
+ for layer in self.layers:
310
+ if self.params["apply_abs_pos_emb"]:
311
+ h = self.apply_abs_pos_emb(h, input_pos=input_pos)
312
+ h = layer(h, **layer_kwargs)
313
+
314
+ # Apply final layer normalization
315
+ h = self.norm(h)
316
+
317
+ # Output linear projection
318
+ output = self.output(h)
319
+ return output
320
+
321
+ def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
322
+ """
323
+ Creates an attention mask for the transformer layers.
324
+
325
+ Args:
326
+ input_pos (torch.Tensor): The position of the input sequence (used for inference only).
327
+
328
+ Returns:
329
+ Optional[torch.Tensor]: The attention mask, or None for causal mask.
330
+ """
331
+
332
+ assert input_pos is not None, "input_pos must be provided for inference"
333
+ mask = self.causal_mask[input_pos]
334
+ return mask
335
+
336
+ def _prepare_layer_kwargs(
337
+ self,
338
+ input_pos: Optional[torch.Tensor],
339
+ mask: Optional[torch.Tensor],
340
+ context: Optional[torch.Tensor],
341
+ context_mask: Optional[torch.Tensor],
342
+ ) -> Dict[str, Any]:
343
+ """
344
+ Prepares the keyword arguments for transformer layers.
345
+
346
+ Args:
347
+ input_pos (Optional[torch.Tensor]): The position of the current sequence.
348
+ mask (Optional[torch.Tensor]): The attention mask.
349
+ context (Optional[torch.Tensor]): The context tensor added via cross-attn.
350
+ context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.
351
+
352
+ Returns:
353
+ Dict[str, Any]: A dictionary of keyword arguments for the transformer layers.
354
+ """
355
+ if context is not None:
356
+ context = context.to(self.precision)
357
+
358
+ if isinstance(mask, torch.Tensor) and mask.ndim == 2:
359
+ mask = mask[None, None, :, :]
360
+ if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2:
361
+ context_mask = context_mask[None, None, :, :]
362
+
363
+ layer_kwargs = {
364
+ "mask": mask,
365
+ "context": context,
366
+ "context_mask": context_mask,
367
+ }
368
+
369
+ layer_kwargs["input_pos"] = input_pos
370
+ layer_kwargs["rope"] = self.rope
371
+
372
+ return layer_kwargs
373
+
374
+ def apply_abs_pos_emb(self, x: torch.Tensor, input_pos: int = None) -> torch.Tensor:
375
+ """
376
+ Applies the absolute position embeddings to the input tensor.
377
+ """
378
+ abs_pos_emb = self.abs_pos_emb
379
+ abs_pos_emb = abs_pos_emb[:, input_pos, :] if input_pos is not None else abs_pos_emb
380
+ return x + abs_pos_emb
381
+
382
+ @torch.no_grad()
383
+ def expand_vocab(
384
+ self, new_vocab_size: int, init_method: str = "gaussian", multiple_of=64, expand_output_layer=True
385
+ ):
386
+ """
387
+ Expands the vocabulary of the model to the new size.
388
+
389
+ Args:
390
+ new_vocab_size (int): The new vocabulary size.
391
+ init_method (str): The initialization method for new embeddings.
392
+ Can be "zero" or "gaussian". Default is "gaussian".
393
+ multiple_of (int): The new vocabulary size must be a multiple of this value. Defaults to 64 to fully
394
+ leverage the power of NVIDIA TensorCore (source 1: https://x.com/karpathy/status/1621578354024677377,
395
+ source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc)
396
+ expand_output_layer (bool): Whether to also expand the output layer. Defaults to True.
397
+
398
+ Returns:
399
+ None
400
+ """
401
+ if new_vocab_size <= self.vocab_size:
402
+ raise ValueError(
403
+ f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})"
404
+ )
405
+ if new_vocab_size % multiple_of != 0:
406
+ log.debug(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.")
407
+ new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of
408
+ log.debug(f"Rounded vocabulary size to {new_vocab_size}.")
409
+ # Resize token embeddings
410
+ old_embeddings = self.tok_embeddings
411
+ tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype}
412
+ self.tok_embeddings = self._create_token_embeddings(vocab_size=new_vocab_size).to(**tensor_kwargs)
413
+ # Initialize new embeddings
414
+ if init_method not in ["zero", "gaussian"]:
415
+ raise ValueError(f"Unknown initialization method: {init_method}")
416
+ # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything
417
+ # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings.
418
+ if init_method == "zero":
419
+ self.tok_embeddings.weight.data[self.vocab_size :].zero_()
420
+
421
+ # Copy old embeddings
422
+ log.debug(
423
+ f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}"
424
+ )
425
+ self.tok_embeddings.weight.data[: self.vocab_size] = old_embeddings.weight.data
426
+ # Resize output layer
427
+ old_output = self.output
428
+ self.output = self._create_output_projection(vocab_size=new_vocab_size if expand_output_layer else None)
429
+
430
+ # Initialize new output weights
431
+ self.output.weight.data[self.vocab_size :].zero_()
432
+ # Copy old output weights
433
+ self.output.weight.data[: self.vocab_size] = old_output.weight.data
434
+
435
+ # Update vocab size
436
+ self.vocab_size = new_vocab_size
437
+ log.debug(f"Expanded vocabulary size to {new_vocab_size}")
438
+
439
+ def state_dict(self, *args, **kwargs):
440
+ """
441
+ Process the state dict (e.g., remove "_extra_state" keys imposed by TransformerEngine for FP8).
442
+ """
443
+ state_dict = super().state_dict(*args, **kwargs)
444
+ return process_state_dict(state_dict)
445
+
446
+ def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
447
+ """
448
+ Ignore the missing keys with substrings matching `substring_to_ignore` (e.g., "_extra_state" keys imposed by
449
+ TransformerEngine for FP8).
450
+ """
451
+ state_dict = process_state_dict(state_dict)
452
+ missing_keys, unexpected_keys = super().load_state_dict(state_dict, strict=False, assign=assign)
453
+ if strict:
454
+ actual_missing_keys = []
455
+ for key in missing_keys:
456
+ if not any(substring in key for substring in substrings_to_ignore):
457
+ actual_missing_keys.append(key)
458
+ if len(actual_missing_keys) > 0 or len(unexpected_keys) > 0:
459
+ raise ValueError(f"Missing keys: {actual_missing_keys}\n\nUnexpected keys: {unexpected_keys}")
460
+ missing_keys = actual_missing_keys
461
+ return _IncompatibleKeys(missing_keys, unexpected_keys)
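
A small sketch of the round-up rule used by Transformer.expand_vocab above when the requested vocabulary size is not a multiple of multiple_of (default 64); the helper name is an illustrative assumption:

# Illustrative sketch of the vocabulary round-up rule in expand_vocab.
def round_up_vocab_size(new_vocab_size: int, multiple_of: int = 64) -> int:
    if new_vocab_size % multiple_of != 0:
        new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of
    return new_vocab_size

assert round_up_vocab_size(64000) == 64000  # already a multiple of 64, unchanged
assert round_up_vocab_size(64001) == 64064  # rounded up to the next multiple of 64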
ar_network_vit.py ADDED
@@ -0,0 +1,410 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This module implements a Vision Transformer (ViT) with 2D Rotary Position Embeddings,
18
+ designed for processing image inputs in vision-language models.
19
+
20
+ This module follows Mistral's vision encoder implementation (for their Pixtral-12B VLM):
21
+ https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/vision_encoder.py
22
+ """
23
+ from functools import partial
24
+ from typing import Any, Callable, Mapping, Optional, Tuple
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+
29
+ from .ar_module_normalization import create_norm
30
+ from .ar_network_transformer import TransformerBlock
31
+ from .log import log
32
+
33
+
34
+ def get_vit_config(model_name: str) -> Mapping[str, Any]:
35
+ """
36
+ Get the ViT configuration for a given model name.
37
+ """
38
+ if model_name == "pixtral-12b-vit":
39
+ # The 400M ViT of Pixtral 12B VLM
40
+ return dict(
41
+ dim=1024,
42
+ num_channels=3,
43
+ image_size=1024,
44
+ patch_size=16,
45
+ rope_theta=10000,
46
+ ffn_hidden_size=4096,
47
+ n_layers=24,
48
+ n_heads=16,
49
+ n_kv_heads=16,
50
+ norm_type="rmsnorm",
51
+ norm_eps=1e-5,
52
+ image_token_id=10,
53
+ )
54
+ else:
55
+ raise ValueError(f"Unknown model name: {model_name}")
56
+
57
+
58
+ def precompute_freqs_cis_2d(
59
+ dim: int,
60
+ height: int,
61
+ width: int,
62
+ theta: float,
63
+ ) -> torch.Tensor:
64
+ """
65
+ Precompute 2D complex tensor for rotary position embedding.
66
+
67
+ This function generates a 2D complex tensor used for rotary position embeddings,
68
+ which helps the model understand spatial relationships in the input image.
69
+
70
+ Args:
71
+ dim (int): Dimension of the model (typically the hidden size divided by number of heads).
72
+ height (int): Height of the image in patches.
73
+ width (int): Width of the image in patches.
74
+ theta (float): Base value for the angle calculation, controls the frequency range.
75
+
76
+ Returns:
77
+ torch.Tensor: 2D complex tensor of shape (height, width, dim // 2).
78
+ """
79
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
80
+
81
+ h = torch.arange(height, device=freqs.device)
82
+ w = torch.arange(width, device=freqs.device)
83
+
84
+ freqs_h = torch.outer(h, freqs[::2]).float()
85
+ freqs_w = torch.outer(w, freqs[1::2]).float()
86
+ freqs_2d = torch.cat(
87
+ [
88
+ freqs_h[:, None, :].repeat(1, width, 1),
89
+ freqs_w[None, :, :].repeat(height, 1, 1),
90
+ ],
91
+ dim=-1,
92
+ )
93
+ return torch.polar(torch.ones_like(freqs_2d), freqs_2d)
94
+
95
+
96
+ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
97
+ """
98
+ Reshape frequency tensor for broadcasting with input tensor.
99
+
100
+ This function ensures that the frequency tensor can be properly broadcast
101
+ with the input tensor during the rotary embedding process.
102
+
103
+ Args:
104
+ freqs_cis (torch.Tensor): Frequency tensor from precompute_freqs_cis_2d.
105
+ x (torch.Tensor): Input tensor to be embedded.
106
+
107
+ Returns:
108
+ torch.Tensor: Reshaped frequency tensor ready for broadcasting.
109
+ """
110
+ ndim = x.ndim
111
+ assert 0 <= 1 < ndim, f"ndim is {ndim} but index is {1}"
112
+ assert freqs_cis.shape == (
113
+ x.shape[1],
114
+ x.shape[-1],
115
+ ), f"freqs_cis shape is {freqs_cis.shape} but x shape is {x.shape}"
116
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
117
+ return freqs_cis.view(*shape)
118
+
119
+
120
+ def apply_rotary_emb(
121
+ xq: torch.Tensor,
122
+ xk: torch.Tensor,
123
+ *args,
124
+ freqs_cis: torch.Tensor,
125
+ **kwargs,
126
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
127
+ """
128
+ Apply rotary positional embeddings to input tensors.
129
+
130
+ This function applies the rotary positional embeddings to the query and key tensors,
131
+ which helps the model understand spatial relationships in the input.
132
+
133
+ Args:
134
+ xq (torch.Tensor): Query tensor.
135
+ xk (torch.Tensor): Key tensor.
136
+ freqs_cis (torch.Tensor): Precomputed frequencies from precompute_freqs_cis_2d.
137
+ *args: Variable length argument list (unused).
138
+ **kwargs: Arbitrary keyword arguments (unused).
139
+
140
+ Returns:
141
+ Tuple[torch.Tensor, torch.Tensor]: Rotated query and key tensors.
142
+ """
143
+ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
144
+ xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
145
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
146
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
147
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
148
+ return xq_out.type_as(xq), xk_out.type_as(xk)
149
+
150
+
151
+ class VisionTransformer(nn.Module):
152
+ """
153
+ Vision Transformer model for image processing.
154
+
155
+ This class implements a Vision Transformer that processes images using a patch-based approach
156
+ and applies transformer layers with rotary position embeddings.
157
+
158
+ Args:
159
+ dim (int): Dimension of the model (hidden size).
160
+ num_channels (int): Number of input image channels (e.g., 3 for RGB).
161
+ patch_size (int): Size of each image patch (e.g., 16x16 pixels).
162
+ n_layers (int): Number of transformer layers.
163
+ n_heads (int): Number of attention heads.
164
+ ffn_hidden_size (int): Hidden size of the feed-forward network in transformer blocks.
165
+ norm_type (str): Type of normalization to use (e.g., "rmsnorm").
166
+ norm_eps (float): Epsilon value for normalization layers.
167
+ image_size (int): Size of the input image (assumed square).
168
+ rope_theta (float): Base value for rotary position embedding calculation.
169
+ attention_dropout (float): Dropout rate for attention layers.
170
+ hidden_dropout (float): Dropout rate for hidden layers.
171
+ image_token_id (int): Token ID for the image token (if present).
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ dim: int = 1024,
177
+ num_channels: int = 3,
178
+ patch_size: int = 16,
179
+ n_layers: int = 24,
180
+ n_heads: int = 16,
181
+ n_kv_heads: int = None,
182
+ ffn_hidden_size: int = 4096,
183
+ norm_type: str = "rmsnorm",
184
+ norm_eps: float = 1e-5,
185
+ image_size: int = 1024,
186
+ rope_theta: float = 1000000.0,
187
+ image_token_id: int = None,
188
+ ):
189
+ super().__init__()
190
+ self.patch_conv = nn.Conv2d(
191
+ in_channels=num_channels,
192
+ out_channels=dim,
193
+ kernel_size=patch_size,
194
+ stride=patch_size,
195
+ bias=False,
196
+ )
197
+ self.ln_pre = create_norm(norm_type=norm_type, dim=dim, eps=norm_eps)
198
+ if n_kv_heads is None:
199
+ n_kv_heads = n_heads
200
+ layer_args = dict(
201
+ n_layers=n_layers,
202
+ n_heads=n_heads,
203
+ n_kv_heads=n_kv_heads,
204
+ dim=dim,
205
+ use_qk_normalization=False,
206
+ max_seq_len=None,
207
+ max_batch_size=None,
208
+ ffn_hidden_size=ffn_hidden_size,
209
+ norm_type=norm_type,
210
+ norm_eps=norm_eps,
211
+ causal_mask=False, # Full attention in ViT
212
+ head_dim=None,
213
+ insert_cross_attn=False,
214
+ attn_type="full",
215
+ )
216
+
217
+ self.transformer = VisionTransformerBlocks(n_layers=n_layers, args=layer_args)
218
+
219
+ head_dim = dim // n_heads
220
+ assert head_dim % 2 == 0, "ROPE requires even head_dim"
221
+
222
+ self.dim = dim
223
+ self.n_heads = n_heads
224
+ self.max_patches_per_side = image_size // patch_size
225
+ self.image_size = image_size
226
+ self.patch_size = patch_size
227
+ self.rope_theta = rope_theta
228
+ self._freqs_cis: Optional[torch.Tensor] = None
229
+ self.image_token_id = image_token_id
230
+
231
+ num_params = self.get_num_params()
232
+ log.debug(f"Number of model parameters: {round(num_params / 1e6, 3)}M")
233
+
234
+ @classmethod
235
+ def build(
236
+ cls,
237
+ config: Mapping[str, Any],
238
+ ) -> "VisionTransformer":
239
+ """
240
+ Create a Vision Transformer from a configuration dictionary.
241
+
242
+ This class method creates a Vision Transformer from a configuration dictionary,
243
+ which is typically loaded from a JSON file or other configuration source.
244
+
245
+ Args:
246
+ config (Mapping[str, Any]): Configuration dictionary for the Vision Transformer.
247
+
248
+ Returns:
249
+ VisionTransformer: Vision Transformer model instance.
250
+ """
251
+ necessary_keys = ["dim", "num_channels", "patch_size", "n_layers", "n_heads", "ffn_hidden_size", "rope_theta"]
252
+ missing_keys = [k for k in necessary_keys if k not in config]
253
+ assert len(missing_keys) == 0, f"Missing keys in config: {missing_keys}"
254
+ return cls(
255
+ **config,
256
+ )
257
+
258
+ def expand_in_channels(self, new_in_channels: int):
259
+ """
260
+ Expand the input channels of the patch convolution layer.
261
+ This is useful when the input is non-standard, e.g. a 4-channel image with the last channel as the alpha channel.
262
+ Note that you should only call this method after the weights have been loaded.
263
+ """
264
+ assert (
265
+ new_in_channels > self.patch_conv.in_channels
266
+ ), "Cannot expand the input channels of the patch convolution layer to be less than the original number of channels."
267
+ log.debug(
268
+ f"Vision encoder in_channels is {self.patch_conv.in_channels}, but you have specified {new_in_channels}. We will expand it to {new_in_channels} channels, initializing the extra {new_in_channels - self.patch_conv.in_channels} channels to zero."
269
+ )
270
+ new_conv = nn.Conv2d(
271
+ in_channels=new_in_channels,
272
+ out_channels=self.patch_conv.out_channels,
273
+ kernel_size=self.patch_conv.kernel_size,
274
+ stride=self.patch_conv.stride,
275
+ bias=False,
276
+ )
277
+ new_conv.weight.data[:, : self.patch_conv.in_channels].copy_(self.patch_conv.weight.data)
278
+ new_conv.weight.data[
279
+ :, self.patch_conv.in_channels :
280
+ ].zero_() # zeroize, such that initially it has no effect to output
281
+ self.patch_conv = new_conv
282
+
283
+ @property
284
+ def device(self) -> torch.device:
285
+ """Get the device of the model."""
286
+ return next(self.parameters()).device
287
+
288
+ @property
289
+ def freqs_cis(self) -> torch.Tensor:
290
+ """
291
+ Get or compute the frequency tensor for rotary position embedding.
292
+
293
+ This property lazily initializes and caches the frequency tensor used for
294
+ rotary position embeddings, ensuring it's on the correct device.
295
+
296
+ Returns:
297
+ torch.Tensor: The frequency tensor for rotary position embeddings.
298
+ """
299
+ if self._freqs_cis is None:
300
+ self._freqs_cis = precompute_freqs_cis_2d(
301
+ dim=self.dim // self.n_heads,
302
+ height=self.max_patches_per_side,
303
+ width=self.max_patches_per_side,
304
+ theta=self.rope_theta,
305
+ )
306
+
307
+ if self._freqs_cis.device != self.device:
308
+ self._freqs_cis = self._freqs_cis.to(device=self.device)
309
+
310
+ return self._freqs_cis
311
+
312
+ def forward(
313
+ self,
314
+ x: torch.Tensor,
315
+ ) -> torch.Tensor:
316
+ """
317
+ Forward pass of the Vision Transformer.
318
+
319
+ This method processes the input image through the Vision Transformer,
320
+ including patch embedding, position embedding, and transformer layers.
321
+
322
+ Args:
323
+ x (torch.Tensor): Input tensor of shape (B, C, H, W), where B is batch size,
324
+ C is number of channels, and H, W are height and width.
325
+
326
+ Returns:
327
+ torch.Tensor: Output features of shape (B, N, D), where N is the number of patches
328
+ and D is the embedding dimension.
329
+ """
330
+
331
+ patch_embeds = self.patch_conv(x) # (B, D, Hp, Wp)
332
+ _, _, Hp, Wp = patch_embeds.shape # Patch embeds dim
333
+ patch_embeds = patch_embeds.flatten(2) # (B, D, Hp*Wp)
334
+ patch_embeds = patch_embeds.transpose(1, 2) # (B, Hp*Wp, D)
335
+ patch_embeds = self.ln_pre(patch_embeds) # (B, Hp*Wp, D)
336
+ positions = torch.stack(
337
+ torch.meshgrid(
338
+ torch.arange(Hp),
339
+ torch.arange(Wp),
340
+ indexing="ij",
341
+ ),
342
+ dim=-1,
343
+ ).reshape(-1, 2)
344
+
345
+ freqs_cis = self.freqs_cis[positions[:, 0], positions[:, 1]]
346
+ rope = partial(apply_rotary_emb, freqs_cis=freqs_cis)
347
+ out = self.transformer(patch_embeds, rope=rope)
348
+
349
+ return out
350
+
351
+ def get_num_params(
352
+ self,
353
+ ) -> int:
354
+ """
355
+ Return the number of parameters in the model.
356
+ """
357
+ n_params = sum(p.numel() for p in self.parameters())
358
+ return n_params
359
+
360
+
361
+ class VisionTransformerBlocks(nn.Module):
362
+ """
363
+ Vision Transformer Blocks.
364
+
365
+ This class implements a stack of Transformer blocks used in the Vision Transformer.
366
+
367
+ Args:
368
+ n_layers (int): Number of transformer layers.
369
+ args (Mapping[str, Any]): Arguments for each transformer block, including dimensions,
370
+ """
371
+
372
+ def __init__(
373
+ self,
374
+ n_layers: int,
375
+ args: Mapping[str, Any],
376
+ ):
377
+ super().__init__()
378
+ self.layers = torch.nn.ModuleList()
379
+
380
+ for layer_id in range(n_layers):
381
+ self.layers.append(
382
+ TransformerBlock(
383
+ layer_id=layer_id,
384
+ args=args,
385
+ )
386
+ )
387
+
388
+ def forward(
389
+ self,
390
+ x: torch.Tensor,
391
+ rope: Callable,
392
+ ) -> torch.Tensor:
393
+ """
394
+ Forward pass through the Vision Transformer Blocks.
395
+
396
+ This method applies a series of Transformer blocks to the input tensor,
397
+ using the provided rotary position embedding function.
398
+
399
+ Args:
400
+ x (torch.Tensor): Input tensor of shape (B, N, D), where B is batch size,
401
+ N is the number of patches, and D is the embedding dimension.
402
+ rope (Callable): Rotary position embedding function to be applied in each layer.
403
+
404
+ Returns:
405
+ torch.Tensor: Output tensor after passing through all transformer layers,
406
+ with the same shape as the input.
407
+ """
408
+ for layer in self.layers:
409
+ x = layer(x, input_pos=None, mask=None, rope=rope)
410
+ return x
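
A shape-level sketch of the 2D rotary-embedding path used in VisionTransformer.forward above; the import path, batch size, head count, and head dimension are illustrative assumptions:

# Illustrative sketch: apply 2D RoPE to toy query/key tensors.
import torch
from ar_network_vit import apply_rotary_emb, precompute_freqs_cis_2d  # assumed import path

B, Hp, Wp, n_heads, head_dim = 2, 4, 4, 16, 64
freqs_cis = precompute_freqs_cis_2d(dim=head_dim, height=Hp, width=Wp, theta=10000.0)
positions = torch.stack(
    torch.meshgrid(torch.arange(Hp), torch.arange(Wp), indexing="ij"), dim=-1
).reshape(-1, 2)
freqs = freqs_cis[positions[:, 0], positions[:, 1]]   # (Hp*Wp, head_dim // 2)

xq = torch.randn(B, Hp * Wp, n_heads, head_dim)
xk = torch.randn(B, Hp * Wp, n_heads, head_dim)
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs)
assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape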
ar_tokenizer_discrete_video.py ADDED
@@ -0,0 +1,360 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import torch
19
+ from einops import rearrange
20
+
21
+ from .ar_tokenizer_quantizers import FSQuantizer
22
+
23
+ # Make sure the JIT model output stays consistent across consecutive calls
24
+ # Check here: https://github.com/pytorch/pytorch/issues/74534
25
+ torch._C._jit_set_texpr_fuser_enabled(False)
26
+
27
+
28
+ def load_jit_model(jit_filepath: str = None, device: str = "cuda") -> torch.jit.ScriptModule:
29
+ """Loads a torch.jit.ScriptModule from a filepath.
30
+
31
+ Args:
32
+ jit_filepath: The filepath to the JIT-compiled model.
33
+ device: The device to load the model onto, default=cuda.
34
+ Returns:
35
+ The JIT compiled model loaded to device and on eval mode.
36
+ """
37
+ # Make sure the JIT model output stays consistent across consecutive calls
38
+ # Check here: https://github.com/pytorch/pytorch/issues/74534
39
+ torch._C._jit_set_texpr_fuser_enabled(False)
40
+
41
+ model = torch.jit.load(jit_filepath)
42
+ return model.eval().to(device)
43
+
44
+
45
+ class BaseDiscreteVideoFSQTokenizer(torch.nn.Module):
46
+ """
47
+ A base class for Discrete Video FSQ Tokenizers that handles data type conversions and normalization
48
+ using provided mean and standard deviation values for latent space representation.
49
+ Derived classes should load pre-trained encoder and decoder components into the `encoder` and `decoder` attributes.
50
+
51
+ Attributes:
52
+ encoder (Module | Callable): Encoder loaded from storage.
53
+ decoder (Module | Callable): Decoder loaded from storage.
54
+ dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
55
+
56
+ Args:
57
+ name (str): Name of the model, used for differentiating cache file paths.
58
+ latent_ch (int, optional): Number of latent channels (default is 6).
59
+ is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
60
+ pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
61
+ latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
62
+ max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
63
+ levels (list[int]): The quantization levels used by the FSQ quantizer.
64
+ compression_ratio (list[int]): The compression factor for (T, H, W).
65
+ """
66
+
67
+ def __init__(
68
+ self,
69
+ name: str,
70
+ latent_ch: int = 6,
71
+ is_bf16: bool = True,
72
+ pixel_chunk_duration: int = 25,
73
+ latent_chunk_duration: int = 4,
74
+ max_enc_batch_size: int = 8,
75
+ max_dec_batch_size: int = 4,
76
+ levels: list[int] = [8, 8, 8, 5, 5, 5],
77
+ compression_ratio: list[int] = [8, 16, 16],
78
+ ):
79
+ super().__init__()
80
+ self.channel = latent_ch
81
+ self.name = name
82
+ dtype = torch.bfloat16 if is_bf16 else torch.float32
83
+ self.dtype = dtype
84
+ self.pixel_chunk_duration = pixel_chunk_duration
85
+ self.latent_chunk_duration = latent_chunk_duration
86
+ self.max_enc_batch_size = max_enc_batch_size
87
+ self.max_dec_batch_size = max_dec_batch_size
88
+ self.levels = levels
89
+ self.compress_ratio = compression_ratio
90
+ self.fsq_quantizer = FSQuantizer(levels)
91
+
92
+ @property
93
+ def latent_ch(self) -> int:
94
+ """
95
+ Returns the number of latent channels in the tokenizer.
96
+ """
97
+ return self.channel
98
+
99
+ @torch.no_grad()
100
+ def encode(self, state: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]:
101
+ B, C, T, H, W = state.shape
102
+ if pixel_chunk_duration is None:
103
+ # Use the default pixel chunk duration and latent chunk duration
104
+ pixel_chunk_duration = self.pixel_chunk_duration
105
+ latent_chunk_duration = self.latent_chunk_duration
106
+ else:
107
+ # Update the latent chunk duration based on the given pixel chunk duration
108
+ latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
109
+
110
+ assert (
111
+ T % pixel_chunk_duration == 0
112
+ ), f"Temporal dimension {T} is not divisible by chunk_length {pixel_chunk_duration}"
113
+ state = rearrange(state, "b c (n t) h w -> (b n) c t h w", t=pixel_chunk_duration)
114
+
115
+ # use max_enc_batch_size to avoid OOM
116
+ if state.shape[0] > self.max_enc_batch_size:
117
+ quantized_out_list = []
118
+ indices_list = []
119
+ for i in range(0, state.shape[0], self.max_enc_batch_size):
120
+ indices, quantized_out, _ = self.encoder(state[i : i + self.max_enc_batch_size].to(self.dtype))
121
+ quantized_out_list.append(quantized_out)
122
+ indices_list.append(indices)
123
+ quantized_out = torch.cat(quantized_out_list, dim=0)
124
+ indices = torch.cat(indices_list, dim=0)
125
+ else:
126
+ indices, quantized_out, _ = self.encoder(state.to(self.dtype))
127
+ assert quantized_out.shape[2] == latent_chunk_duration
128
+ return rearrange(quantized_out, "(b n) c t h w -> b c (n t) h w", b=B), rearrange(
129
+ indices, "(b n) t h w -> b (n t) h w", b=B
130
+ )
131
+
132
+ @torch.no_grad()
133
+ def decode(self, indices: torch.Tensor, pixel_chunk_duration: Optional[int] = None) -> torch.Tensor:
134
+ B, T, _, _ = indices.shape
135
+ if pixel_chunk_duration is None:
136
+ pixel_chunk_duration = self.pixel_chunk_duration
137
+ latent_chunk_duration = self.latent_chunk_duration
138
+ else:
139
+ latent_chunk_duration = 1 + (pixel_chunk_duration - 1) // self.compress_ratio[0]
140
+ assert (
141
+ T % latent_chunk_duration == 0
142
+ ), f"Temporal dimension {T} is not divisible by chunk_length {latent_chunk_duration}"
143
+ indices = rearrange(indices, "b (n t) h w -> (b n) t h w", t=latent_chunk_duration)
144
+
145
+ # use max_dec_batch_size to avoid OOM
146
+ if indices.shape[0] > self.max_dec_batch_size:
147
+ state = []
148
+ for i in range(0, indices.shape[0], self.max_dec_batch_size):
149
+ state.append(self.decoder(indices[i : i + self.max_dec_batch_size]))
150
+ state = torch.cat(state, dim=0)
151
+ else:
152
+ state = self.decoder(indices)
153
+
154
+ assert state.shape[2] == pixel_chunk_duration
155
+ return rearrange(state, "(b n) c t h w -> b c (n t) h w", b=B)
156
+
157
+ def reset_dtype(self, *args, **kwargs):
158
+ """
159
+ Resets the data type of the encoder and decoder to the model's default data type.
160
+
161
+ Args:
162
+ *args, **kwargs: Unused, present to allow flexibility in method calls.
163
+ """
164
+ del args, kwargs
165
+ self.decoder.to(self.dtype)
166
+ self.encoder.to(self.dtype)
167
+
168
+
169
+ class DiscreteVideoFSQJITTokenizer(BaseDiscreteVideoFSQTokenizer):
170
+ """
171
+ A JIT compiled Discrete Video FSQ Tokenizer that loads pre-trained encoder
172
+ and decoder components from a remote store, handles data type conversions, and normalization
173
+ using provided mean and standard deviation values for latent space representation.
174
+
175
+ Attributes:
176
+ encoder (Module): The JIT compiled encoder loaded from storage.
177
+ decoder (Module): The JIT compiled decoder loaded from storage.
178
+ dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
179
+
180
+ Args:
181
+ enc_fp (str): File path to the encoder's JIT file on the remote store.
182
+ dec_fp (str): File path to the decoder's JIT file on the remote store.
183
+ name (str): Name of the model, used for differentiating cache file paths.
184
+ latent_ch (int, optional): Number of latent channels (default is 6).
185
+ is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
186
+ pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
187
+ latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
188
+ max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
189
+ levels (list[int]): The quantization levels used by the FSQ quantizer.
190
+ compression_ratio (list[int]): The compression factor for (T, H, W).
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ enc_fp: str,
196
+ dec_fp: str,
197
+ name: str,
198
+ latent_ch: int = 6,
199
+ is_bf16: bool = True,
200
+ pixel_chunk_duration: int = 25,
201
+ latent_chunk_duration: int = 4,
202
+ max_enc_batch_size: int = 8,
203
+ max_dec_batch_size: int = 4,
204
+ levels: list[int] = [8, 8, 8, 5, 5, 5],
205
+ compression_ratio: list[int] = [8, 16, 16],
206
+ ):
207
+ super().__init__(
208
+ name,
209
+ latent_ch,
210
+ is_bf16,
211
+ pixel_chunk_duration,
212
+ latent_chunk_duration,
213
+ max_enc_batch_size,
214
+ max_dec_batch_size,
215
+ levels,
216
+ compression_ratio,
217
+ )
218
+
219
+ self.load_encoder(enc_fp)
220
+ self.load_decoder(dec_fp)
221
+
222
+ def load_encoder(self, enc_fp: str) -> None:
223
+ """
224
+ Load the encoder from the remote store.
225
+
226
+ Args:
227
+ - enc_fp (str): File path to the encoder's JIT file on the remote store.
228
+ """
229
+ self.encoder = load_jit_model(enc_fp, device="cuda")
230
+ self.encoder.eval()
231
+ for param in self.encoder.parameters():
232
+ param.requires_grad = False
233
+ self.encoder.to(self.dtype)
234
+
235
+ def load_decoder(self, dec_fp: str) -> None:
236
+ """
237
+ Load the decoder from the remote store.
238
+
239
+ Args:
240
+ - dec_fp (str): File path to the decoder's JIT file on the remote store.
241
+ """
242
+ self.decoder = load_jit_model(dec_fp, device="cuda")
243
+ self.decoder.eval()
244
+ for param in self.decoder.parameters():
245
+ param.requires_grad = False
246
+ self.decoder.to(self.dtype)
247
+
248
+
249
+ class DiscreteVideoFSQStateDictTokenizer(BaseDiscreteVideoFSQTokenizer):
250
+ """
251
+ A Discrete Video FSQ Tokenizer that loads the weights of a pre-trained JIT-compiled encoder into an
252
+ nn.Module (so the encoder can be wrapped with torch.compile()), keeps the JIT-compiled decoder, handles
253
+ data type conversions, and performs normalization using provided mean and standard deviation values for the
254
+ latent space representation.
255
+
256
+ Attributes:
257
+ tokenizer_module (Module): Tokenizer module with weights loaded from JIT checkpoints
258
+ encoder (Callable): tokenizer_module's encode method
259
+ decoder (Callable): tokenizer_module's decode method
260
+ dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.
261
+
262
+ Args:
263
+ enc_fp (str): File path to the encoder's JIT file on the remote store.
264
+ dec_fp (str): File path to the decoder's JIT file on the remote store.
265
+ tokenizer_module (Module): Tokenizer module that will have its weights loaded
266
+ name (str): Name of the model, used for differentiating cache file paths.
267
+ latent_ch (int, optional): Number of latent channels (default is 6).
268
+ is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
269
+ pixel_chunk_duration (int): The duration (in number of frames) of each chunk of video data at the pixel level.
270
+ latent_chunk_duration (int): The duration (in number of frames) of each chunk at the latent representation level.
271
+ max_enc_batch_size (int): The maximum batch size to process in one go to avoid memory overflow.
272
+ levels (list[int]): The quantization levels used by the FSQ quantizer.
273
+ compression_ratio (list[int]): The compression factor for (T, H, W).
274
+ """
275
+
276
+ def __init__(
277
+ self,
278
+ enc_fp: str,
279
+ dec_fp: str,
280
+ tokenizer_module: torch.nn.Module,
281
+ name: str,
282
+ latent_ch: int = 6,
283
+ is_bf16: bool = True,
284
+ pixel_chunk_duration: int = 25,
285
+ latent_chunk_duration: int = 4,
286
+ max_enc_batch_size: int = 8,
287
+ max_dec_batch_size: int = 4,
288
+ levels: list[int] = [8, 8, 8, 5, 5, 5],
289
+ compression_ratio: list[int] = [8, 16, 16],
290
+ ):
291
+ super().__init__(
292
+ name,
293
+ latent_ch,
294
+ is_bf16,
295
+ pixel_chunk_duration,
296
+ latent_chunk_duration,
297
+ max_enc_batch_size,
298
+ max_dec_batch_size,
299
+ levels,
300
+ compression_ratio,
301
+ )
302
+
303
+ self.load_encoder_and_decoder(enc_fp, dec_fp, tokenizer_module)
304
+
305
+ def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, tokenizer_module: torch.nn.Module) -> None:
306
+ """
307
+ Load the encoder and decoder from the remote store.
308
+
309
+ Args:
310
+ - enc_fp (str): File path to the encoder's JIT file on the remote store.
311
+ - dec_fp (str): File path to the decoder's JIT file on the remote store.
312
+ - tokenizer_module (Module): Tokenizer module that was used to create JIT checkpoints
313
+ """
314
+ self.decoder = load_jit_model(dec_fp)
315
+
316
+ self.decoder.eval()
317
+ for param in self.decoder.parameters():
318
+ param.requires_grad = False
319
+ self.decoder.to(self.dtype)
320
+
321
+ encoder_sd = load_jit_model(enc_fp).state_dict()
322
+
323
+ del tokenizer_module.post_quant_conv
324
+ del tokenizer_module.decoder
325
+
326
+ state_dict = {
327
+ k: v
328
+ for k, v in (encoder_sd).items()
329
+ # Variables captured by JIT
330
+ if k
331
+ not in (
332
+ "encoder.patcher3d.wavelets",
333
+ "encoder.patcher3d._arange",
334
+ "encoder.patcher3d.patch_size_buffer",
335
+ "quantizer._levels",
336
+ "quantizer._basis",
337
+ "quantizer.implicit_codebook",
338
+ )
339
+ }
340
+
341
+ tokenizer_module.load_state_dict(state_dict)
342
+
343
+ tokenizer_module.eval()
344
+ for param in tokenizer_module.parameters():
345
+ param.requires_grad = False
346
+ tokenizer_module.to(self.dtype)
347
+
348
+ self.tokenizer_module = tokenizer_module
349
+ self.encoder = self.tokenizer_module.encode
350
+
351
+ def reset_dtype(self, *args, **kwargs):
352
+ """
353
+ Resets the data type of the encoder and decoder to the model's default data type.
354
+
355
+ Args:
356
+ *args, **kwargs: Unused, present to allow flexibility in method calls.
357
+ """
358
+ del args, kwargs
359
+ self.decoder.to(self.dtype)
360
+ self.tokenizer_module.to(self.dtype)
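
A small sketch of the pixel-to-latent chunk-duration arithmetic used by the FSQ tokenizers above; 8 is the default temporal compression factor (compression_ratio[0]), and the helper name is an illustrative assumption:

# Illustrative sketch of the chunk-duration relationship used in encode()/decode().
def latent_chunk_duration(pixel_chunk_duration: int, temporal_compression: int = 8) -> int:
    return 1 + (pixel_chunk_duration - 1) // temporal_compression

assert latent_chunk_duration(25) == 4   # the class defaults: 25 pixel frames -> 4 latent frames
assert latent_chunk_duration(33) == 5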
ar_tokenizer_image_text_tokenizer.py ADDED
@@ -0,0 +1,318 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ import transformers
21
+ from transformers import AutoImageProcessor
22
+ from transformers.image_utils import ImageInput, is_valid_image, load_image
23
+
24
+ from .ar_tokenizer_text_tokenizer import TextTokenizer
25
+ from .log import log
26
+
27
+ # Configuration for different vision-language models
28
+ IMAGE_CONFIGS = {
29
+ "pixtral": {
30
+ "patch_size": 16,
31
+ "image_token": "[IMG]",
32
+ "image_break_token": "[IMG_BREAK]",
33
+ "image_end_token": "[IMG_END]",
34
+ }
35
+ }
36
+
37
+ # Chat template for Pixtral-12B-Instruct
38
+ PIXTRAL_CHAT_TEMPLATE = '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["content"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {{- message["content"] + eos_token}}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}'
39
+
40
+
41
+ # Copied from transformers.models.pixtral.processing_pixtral.is_url
42
+ def is_url(val) -> bool:
43
+ """Check if the given value is a URL."""
44
+ return isinstance(val, str) and val.startswith("http")
45
+
46
+
47
+ # Copied from transformers.models.pixtral.processing_pixtral.is_image_or_image_url
48
+ def is_image_or_image_url(elem):
49
+ """Check if the given element is an image or an image URL."""
50
+ return is_url(elem) or is_valid_image(elem)
51
+
52
+
53
+ def load_image_list(
54
+ image_list: List[Union[str, "PIL.Image.Image"]], timeout: Optional[float] = None
55
+ ) -> List["PIL.Image.Image"]:
56
+ """
57
+ Load a list of images.
58
+
59
+ Args:
60
+ image_list (List[Union[str, PIL.Image.Image]]): The list of images to load.
61
+ timeout (Optional[float]): The timeout for loading the image.
62
+
63
+ Returns:
64
+ List[PIL.Image.Image]: The list of loaded images.
65
+ """
66
+ return [load_image(image, timeout=timeout) for image in image_list]
67
+
68
+
69
+ class ImageTextTokenizer(TextTokenizer):
70
+ """
71
+ Image-text tokenizer class that extends the text tokenizer to support vision tokens as well.
72
+ """
73
+
74
+ def __init__(
75
+ self,
76
+ model_family: str,
77
+ is_instruct_model: bool,
78
+ tokenizer_path: str,
79
+ image_processor_path: str,
80
+ ):
81
+ """
82
+ Initialize the ImageTextTokenizer.
83
+
84
+ Args:
85
+ model_family (str): The model family.
86
+ is_instruct_model (bool): Whether the model is an instruct model.
87
+ s3_credential_path (str): The path to the s3 credential file. Defaults to "credentials/pbss_dir.secret".
88
+
89
+ Raises:
90
+ AssertionError: If the model family is not supported or if the transformers version is incompatible.
91
+ """
92
+ super().__init__(
93
+ model_family=model_family,
94
+ is_instruct_model=is_instruct_model,
95
+ local_path=tokenizer_path,
96
+ )
97
+ assert model_family in ["pixtral"], f"Unsupported model family: {model_family}"
98
+ if model_family == "pixtral":
99
+ # Need transformers>=4.45.0
100
+ assert transformers.__version__ >= "4.45.0", "Pixtral requires transformers>=4.45.0"
101
+ assert is_instruct_model, "Pixtral requires is_instruct_model=True"
102
+ if not hasattr(self.tokenizer, "chat_template") or self.tokenizer.chat_template is None:
103
+ setattr(self.tokenizer, "chat_template", PIXTRAL_CHAT_TEMPLATE)
104
+ log.debug(f"Pixtral tokenizer chat template set to: {PIXTRAL_CHAT_TEMPLATE}")
105
+
106
+ # Set up image-specific configurations
107
+ image_config = IMAGE_CONFIGS[model_family]
108
+ self.patch_size = image_config["patch_size"]
109
+ self.image_token = image_config["image_token"]
110
+ self.image_break_token = image_config["image_break_token"]
111
+ self.image_end_token = image_config["image_end_token"]
112
+
113
+ # Initialize the image processor
114
+ self.image_processor = AutoImageProcessor.from_pretrained(image_processor_path)
115
+
116
+ def encode(
117
+ self,
118
+ text: Union[str, List[str], List[int]],
119
+ *, # Enforce keyword-only arguments
120
+ images: Optional[ImageInput] = None,
121
+ image_kwargs: Optional[Dict[str, Any]] = None,
122
+ **text_kwargs,
123
+ ) -> List[int]:
124
+ """
125
+ Process the images and return the tokenized images and text.
126
+
127
+ Args:
128
+ text (`str`, `List[str]`, `List[List[str]]`):
129
+ The sequence or batch of sequences to be encoded.
130
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
131
+ The image or batch of images to be prepared.
132
+ image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
133
+ **text_kwargs: Additional keyword arguments for text processing.
134
+
135
+ Returns:
136
+ A dictionary with the following fields:
137
+ - **input_ids** -- List of token ids to be fed to a model.
138
+                 - **image_sizes** -- Sizes of the processed images, used to expand the image tokens in the text.
139
+ - **pixel_values** -- Pixel values to be fed to a model.
140
+
141
+ Raises:
142
+ ValueError: If the input images are in an invalid format.
143
+ """
144
+
145
+ output_dict, image_inputs = {}, {}
146
+ if images is not None:
147
+ # Preprocess images
148
+ if is_image_or_image_url(images):
149
+ images = [[images]]
150
+ elif isinstance(images, list) and is_image_or_image_url(images[0]):
151
+ images = [images]
152
+ elif (
153
+ not isinstance(images, list)
154
+ and not isinstance(images[0], list)
155
+ and not is_image_or_image_url(images[0][0])
156
+ ):
157
+ raise ValueError(
158
+ "Invalid input images. Please provide a single image or a list of images or a list of list of images."
159
+ )
160
+
161
+ # Load and process images
162
+ images = [load_image_list(sample) for sample in images]
163
+ image_kwargs = image_kwargs or {}
164
+ image_inputs = self.image_processor(images, patch_size=self.patch_size, return_tensors="np", **image_kwargs)
165
+
166
+ # Validate image inputs
167
+ assert "pixel_values" in image_inputs, "pixel_values not found in image_inputs"
168
+ assert "image_sizes" in image_inputs, "image_sizes not found in image_inputs"
169
+             assert len(image_inputs.keys()) == 2, "Only pixel_values and image_sizes are expected in image_inputs, got {}".format(
170
+ image_inputs.keys()
171
+ )
172
+
173
+ # Extract pixel values and image sizes
174
+ pixel_values = image_inputs["pixel_values"][0]
175
+ image_sizes = image_inputs["image_sizes"][0]
176
+ unique_sizes = np.unique(image_sizes, axis=0)
177
+
178
+ assert len(unique_sizes) == 1, "All images must have the same size, got {}".format(unique_sizes)
179
+
180
+ # Convert pixel values to PyTorch tensor
181
+ pixel_values = np.asarray(pixel_values)
182
+ pixel_values = torch.from_numpy(pixel_values)
183
+ output_dict["pixel_values"] = pixel_values
184
+ output_dict["image_sizes"] = image_sizes
185
+
186
+ # Expand image tokens in text
187
+ if image_inputs.get("pixel_values") is not None:
188
+ replace_strings = []
189
+ # Calculate the number of tokens needed for each image and create a placeholder
190
+ for image_size in image_sizes:
191
+ height, width = image_size
192
+ num_height_tokens = height // self.patch_size
193
+ num_width_tokens = width // self.patch_size
194
+ replace_tokens = [[self.image_token] * num_width_tokens + [self.image_break_token]] * num_height_tokens
195
+ # Flatten list
196
+ replace_tokens = [item for sublist in replace_tokens for item in sublist]
197
+ replace_tokens[-1] = self.image_end_token
198
+ replace_str = "".join(replace_tokens)
199
+ replace_strings.append(replace_str)
200
+ text = text.replace(self.image_token, "<placeholder>", 1)
201
+
202
+ # Replace placeholders with actual image token sequences
203
+ while "<placeholder>" in text:
204
+ replace_str = replace_strings.pop(0)
205
+ text = text.replace("<placeholder>", replace_str, 1)
206
+
207
+ # Encode the text
208
+ text_inputs = super(ImageTextTokenizer, self).encode(text, **text_kwargs)
209
+
210
+ output_dict["input_ids"] = text_inputs
211
+ return output_dict
212
+
213
+ def apply_chat_template(
214
+ self,
215
+ conversation: List[Dict[str, Any]] | List[List[Dict[str, Any]]],
216
+ *,
217
+ images: Optional[ImageInput] = None,
218
+ image_kwargs: Optional[Dict[str, Any]] = None,
219
+ add_generation_prompt: bool = False,
220
+ tokenize: bool = True,
221
+ padding: bool = False,
222
+ truncation: bool = False,
223
+ max_length: Optional[int] = None,
224
+ return_tensors: Optional[str] = None,
225
+ return_dict: bool = True,
226
+ return_assistant_tokens_mask: bool = False,
227
+ generation_prefix: str = "",
228
+ tokenizer_kwargs: Optional[Dict[str, Any]] = None,
229
+ **kwargs,
230
+ ):
231
+ """
232
+ Apply the chat template to the conversation.
233
+
234
+ Args:
235
+ conversation (List[Dict[str, Any]] | List[List[Dict[str, Any]]]): The conversation to process.
236
+ images (Optional[ImageInput]): Images to include in the conversation.
237
+ image_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for image processing.
238
+ add_generation_prompt (bool): Whether to add a generation prompt.
239
+ tokenize (bool): Whether to tokenize the output.
240
+ padding (bool): Whether to pad the output.
241
+ truncation (bool): Whether to truncate the output.
242
+ max_length (Optional[int]): Maximum length of the output.
243
+ return_tensors (Optional[str]): The type of tensors to return.
244
+ return_dict (bool): Whether to return a dictionary.
245
+ return_assistant_tokens_mask (bool): Whether to return the assistant tokens mask.
246
+ generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "".
247
+ tokenizer_kwargs (Optional[Dict[str, Any]]): Additional keyword arguments for the tokenizer.
248
+ **kwargs: Additional keyword arguments.
249
+
250
+ Returns:
251
+ The processed conversation with applied chat template.
252
+
253
+ Raises:
254
+ AssertionError: If return_dict is False or if the conversation format is invalid.
255
+ """
256
+ assert return_dict, "return_dict must be True for ImageTextTokenizer"
257
+ assert isinstance(conversation, list), "conversation must be a list"
258
+ if isinstance(conversation[0], list):
259
+ assert len(conversation) == 1, "Only support single-conversation input, got {}".format(conversation)
260
+ conversation = conversation[0]
261
+
262
+ # Extract images from the conversation if not provided
263
+ if images is None:
264
+ images = []
265
+ for msg in conversation:
266
+ if msg.get("images", None) is not None:
267
+ images = images + (msg["images"])
268
+ images = load_image_list(images)
269
+ # In case the input does not have images, will ignore
270
+ # Useful in feeding VLM inputs with and without images
271
+ if isinstance(images, list) and len(images) == 0:
272
+ images = None
273
+
274
+ # Apply the chat template to the text
275
+ text = super().apply_chat_template(
276
+ conversation,
277
+ tokenize=False,
278
+ add_generation_prompt=add_generation_prompt,
279
+ padding=padding,
280
+ truncation=truncation,
281
+ max_length=max_length,
282
+ return_tensors=return_tensors,
283
+ return_dict=False,
284
+ return_assistant_tokens_mask=return_assistant_tokens_mask,
285
+ generation_prefix=generation_prefix,
286
+ tokenizer_kwargs=tokenizer_kwargs,
287
+ **kwargs,
288
+ )
289
+
290
+ if tokenizer_kwargs is None:
291
+ tokenizer_kwargs = {}
292
+
293
+ # Encode the text and images
294
+ output = self.encode(
295
+ text,
296
+ images=images,
297
+ image_kwargs=image_kwargs,
298
+ tokenize=tokenize,
299
+ padding=padding,
300
+ truncation=truncation,
301
+ max_length=max_length,
302
+ add_special_tokens=False,
303
+ return_tensors=return_tensors,
304
+ **tokenizer_kwargs,
305
+ )
306
+ return output
307
+
308
+ @property
309
+ def model_input_names(self):
310
+ """
311
+ Get the combined model input names from both the text tokenizer and image processor.
312
+
313
+ Returns:
314
+ List[str]: A list of unique input names.
315
+ """
316
+ tokenizer_input_names = self.tokenizer.model_input_names
317
+ image_processor_input_names = self.image_processor.model_input_names
318
+ return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
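The chat template and the image-token expansion above are easiest to see end to end in a short usage sketch. This is illustrative only: the module import assumes this repo's flat file layout is on the import path, and the checkpoint paths, image file, and conversation contents are placeholders, not shipped defaults.

# Illustrative sketch -- paths and the image file are placeholders.
from ar_tokenizer_image_text_tokenizer import ImageTextTokenizer

tokenizer = ImageTextTokenizer(
    model_family="pixtral",
    is_instruct_model=True,
    tokenizer_path="<text-tokenizer-path-or-repo-id>",
    image_processor_path="<image-processor-path-or-repo-id>",
)
conversation = [
    {
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "content": "Describe this image."}],
        "images": ["<path/to/image.jpg>"],
    }
]
batch = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
# batch["input_ids"] holds the chat-formatted text with the per-patch image tokens expanded to
# match the image resolution; batch["pixel_values"] and batch["image_sizes"] feed the vision encoder.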
ar_tokenizer_modules.py ADDED
@@ -0,0 +1,560 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The model definition for 3D layers
17
+
18
+ Adapted from: https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/
19
+ magvit2_pytorch/magvit2_pytorch.py#L889
20
+
21
+ [MIT License Copyright (c) 2023 Phil Wang]
22
+ https://github.com/lucidrains/magvit2-pytorch/blob/9f49074179c912736e617d61b32be367eb5f993a/LICENSE
23
+ """
24
+ import math
25
+ from typing import Tuple, Union
26
+
27
+ import numpy as np
28
+ import torch
29
+ import torch.nn as nn
30
+ import torch.nn.functional as F
31
+
32
+ from .ar_tokenizer_patching import Patcher3D, UnPatcher3D
33
+ from .ar_tokenizer_utils import (
34
+ CausalNormalize,
35
+ batch2space,
36
+ batch2time,
37
+ cast_tuple,
38
+ is_odd,
39
+ nonlinearity,
40
+ replication_pad,
41
+ space2batch,
42
+ time2batch,
43
+ )
44
+ from .log import log
45
+
46
+
47
+ class CausalConv3d(nn.Module):
48
+ def __init__(
49
+ self,
50
+ chan_in: int = 1,
51
+ chan_out: int = 1,
52
+ kernel_size: Union[int, Tuple[int, int, int]] = 3,
53
+ pad_mode: str = "constant",
54
+ **kwargs,
55
+ ):
56
+ super().__init__()
57
+ kernel_size = cast_tuple(kernel_size, 3)
58
+
59
+ time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
60
+
61
+ assert is_odd(height_kernel_size) and is_odd(width_kernel_size)
62
+
63
+ dilation = kwargs.pop("dilation", 1)
64
+ stride = kwargs.pop("stride", 1)
65
+ time_stride = kwargs.pop("time_stride", 1)
66
+ time_dilation = kwargs.pop("time_dilation", 1)
67
+ padding = kwargs.pop("padding", 1)
68
+
69
+ self.pad_mode = pad_mode
70
+ time_pad = time_dilation * (time_kernel_size - 1) + (1 - time_stride)
71
+ self.time_pad = time_pad
72
+
73
+ self.spatial_pad = (padding, padding, padding, padding)
74
+
75
+ stride = (time_stride, stride, stride)
76
+ dilation = (time_dilation, dilation, dilation)
77
+ self.conv3d = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
78
+
79
+ def _replication_pad(self, x: torch.Tensor) -> torch.Tensor:
80
+ x_prev = x[:, :, :1, ...].repeat(1, 1, self.time_pad, 1, 1)
81
+ x = torch.cat([x_prev, x], dim=2)
82
+ padding = self.spatial_pad + (0, 0)
83
+ return F.pad(x, padding, mode=self.pad_mode, value=0.0)
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ x = self._replication_pad(x)
87
+ return self.conv3d(x)
88
+
89
+
90
+ class CausalHybridUpsample3d(nn.Module):
91
+ def __init__(self, in_channels: int, spatial_up: bool = True, temporal_up: bool = True, **ignore_kwargs) -> None:
92
+ super().__init__()
93
+ self.conv1 = (
94
+ CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=1, padding=0)
95
+ if temporal_up
96
+ else nn.Identity()
97
+ )
98
+ self.conv2 = (
99
+ CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=1, time_stride=1, padding=1)
100
+ if spatial_up
101
+ else nn.Identity()
102
+ )
103
+ self.conv3 = (
104
+ CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
105
+ if spatial_up or temporal_up
106
+ else nn.Identity()
107
+ )
108
+ self.spatial_up = spatial_up
109
+ self.temporal_up = temporal_up
110
+
111
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
112
+ if not self.spatial_up and not self.temporal_up:
113
+ return x
114
+
115
+ # hybrid upsample temporally.
116
+ if self.temporal_up:
117
+ time_factor = 1.0 + 1.0 * (x.shape[2] > 1)
118
+ if isinstance(time_factor, torch.Tensor):
119
+ time_factor = time_factor.item()
120
+ x = x.repeat_interleave(int(time_factor), dim=2)
121
+ x = x[..., int(time_factor - 1) :, :, :]
122
+ x = self.conv1(x) + x
123
+
124
+ # hybrid upsample spatially.
125
+ if self.spatial_up:
126
+ x = x.repeat_interleave(2, dim=3).repeat_interleave(2, dim=4)
127
+ x = self.conv2(x) + x
128
+
129
+ # final 1x1x1 conv.
130
+ x = self.conv3(x)
131
+ return x
132
+
133
+
134
+ class CausalHybridDownsample3d(nn.Module):
135
+ def __init__(
136
+ self, in_channels: int, spatial_down: bool = True, temporal_down: bool = True, **ignore_kwargs
137
+ ) -> None:
138
+ super().__init__()
139
+ self.conv1 = (
140
+ CausalConv3d(in_channels, in_channels, kernel_size=(1, 3, 3), stride=2, time_stride=1, padding=0)
141
+ if spatial_down
142
+ else nn.Identity()
143
+ )
144
+ self.conv2 = (
145
+ CausalConv3d(in_channels, in_channels, kernel_size=(3, 1, 1), stride=1, time_stride=2, padding=0)
146
+ if temporal_down
147
+ else nn.Identity()
148
+ )
149
+ self.conv3 = (
150
+ CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, time_stride=1, padding=0)
151
+ if spatial_down or temporal_down
152
+ else nn.Identity()
153
+ )
154
+ self.spatial_down = spatial_down
155
+ self.temporal_down = temporal_down
156
+
157
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
158
+ if not self.spatial_down and not self.temporal_down:
159
+ return x
160
+
161
+ # hybrid downsample spatially.
162
+ if self.spatial_down:
163
+ pad = (0, 1, 0, 1, 0, 0)
164
+ x = F.pad(x, pad, mode="constant", value=0)
165
+ x1 = self.conv1(x)
166
+ x2 = F.avg_pool3d(x, kernel_size=(1, 2, 2), stride=(1, 2, 2))
167
+ x = x1 + x2
168
+
169
+ # hybrid downsample temporally.
170
+ if self.temporal_down:
171
+ x = replication_pad(x)
172
+ x1 = self.conv2(x)
173
+ x2 = F.avg_pool3d(x, kernel_size=(2, 1, 1), stride=(2, 1, 1))
174
+ x = x1 + x2
175
+
176
+ # final 1x1x1 conv.
177
+ x = self.conv3(x)
178
+ return x
179
+
180
+
181
+ class CausalResnetBlockFactorized3d(nn.Module):
182
+ def __init__(self, *, in_channels: int, out_channels: int = None, dropout: float, num_groups: int) -> None:
183
+ super().__init__()
184
+ self.in_channels = in_channels
185
+ out_channels = in_channels if out_channels is None else out_channels
186
+
187
+ self.norm1 = CausalNormalize(in_channels, num_groups=1)
188
+ self.conv1 = nn.Sequential(
189
+ CausalConv3d(in_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
190
+ CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
191
+ )
192
+ self.norm2 = CausalNormalize(out_channels, num_groups=num_groups)
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = nn.Sequential(
195
+ CausalConv3d(out_channels, out_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
196
+ CausalConv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
197
+ )
198
+ self.nin_shortcut = (
199
+ CausalConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
200
+ if in_channels != out_channels
201
+ else nn.Identity()
202
+ )
203
+
204
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
205
+ h = x
206
+ h = self.norm1(h)
207
+ h = nonlinearity(h)
208
+ h = self.conv1(h)
209
+
210
+ h = self.norm2(h)
211
+ h = nonlinearity(h)
212
+ h = self.dropout(h)
213
+ h = self.conv2(h)
214
+ x = self.nin_shortcut(x)
215
+
216
+ return x + h
217
+
218
+
219
+ class CausalAttnBlock(nn.Module):
220
+ def __init__(self, in_channels: int, num_groups: int) -> None:
221
+ super().__init__()
222
+
223
+ self.norm = CausalNormalize(in_channels, num_groups=num_groups)
224
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
225
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
226
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
227
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
228
+
229
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
230
+ h_ = x
231
+ h_ = self.norm(h_)
232
+ q = self.q(h_)
233
+ k = self.k(h_)
234
+ v = self.v(h_)
235
+
236
+ # compute attention
237
+ q, batch_size = time2batch(q)
238
+ k, batch_size = time2batch(k)
239
+ v, batch_size = time2batch(v)
240
+
241
+ b, c, h, w = q.shape
242
+ q = q.reshape(b, c, h * w)
243
+ q = q.permute(0, 2, 1)
244
+ k = k.reshape(b, c, h * w)
245
+ w_ = torch.bmm(q, k)
246
+ w_ = w_ * (int(c) ** (-0.5))
247
+ w_ = F.softmax(w_, dim=2)
248
+
249
+ # attend to values
250
+ v = v.reshape(b, c, h * w)
251
+ w_ = w_.permute(0, 2, 1)
252
+ h_ = torch.bmm(v, w_)
253
+ h_ = h_.reshape(b, c, h, w)
254
+
255
+ h_ = batch2time(h_, batch_size)
256
+ h_ = self.proj_out(h_)
257
+ return x + h_
258
+
259
+
260
+ class CausalTemporalAttnBlock(nn.Module):
261
+ def __init__(self, in_channels: int, num_groups: int) -> None:
262
+ super().__init__()
263
+
264
+ self.norm = CausalNormalize(in_channels, num_groups=num_groups)
265
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
266
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
267
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
268
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
269
+
270
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
271
+ h_ = x
272
+ h_ = self.norm(h_)
273
+ q = self.q(h_)
274
+ k = self.k(h_)
275
+ v = self.v(h_)
276
+
277
+ # compute attention
278
+ q, batch_size, height = space2batch(q)
279
+ k, _, _ = space2batch(k)
280
+ v, _, _ = space2batch(v)
281
+
282
+ bhw, c, t = q.shape
283
+ q = q.permute(0, 2, 1) # (bhw, t, c)
284
+ k = k.permute(0, 2, 1) # (bhw, t, c)
285
+ v = v.permute(0, 2, 1) # (bhw, t, c)
286
+
287
+ w_ = torch.bmm(q, k.permute(0, 2, 1)) # (bhw, t, t)
288
+ w_ = w_ * (int(c) ** (-0.5))
289
+
290
+ # Apply causal mask
291
+ mask = torch.tril(torch.ones_like(w_))
292
+ w_ = w_.masked_fill(mask == 0, float("-inf"))
293
+ w_ = F.softmax(w_, dim=2)
294
+
295
+ # attend to values
296
+ h_ = torch.bmm(w_, v) # (bhw, t, c)
297
+ h_ = h_.permute(0, 2, 1).reshape(bhw, c, t) # (bhw, c, t)
298
+
299
+ h_ = batch2space(h_, batch_size, height)
300
+ h_ = self.proj_out(h_)
301
+ return x + h_
302
+
303
+
304
+ class EncoderFactorized(nn.Module):
305
+ def __init__(
306
+ self,
307
+ in_channels: int,
308
+ channels: int,
309
+ channels_mult: list[int],
310
+ num_res_blocks: int,
311
+ attn_resolutions: list[int],
312
+ dropout: float,
313
+ resolution: int,
314
+ z_channels: int,
315
+ spatial_compression: int,
316
+ temporal_compression: int,
317
+ **ignore_kwargs,
318
+ ) -> None:
319
+ super().__init__()
320
+ self.num_resolutions = len(channels_mult)
321
+ self.num_res_blocks = num_res_blocks
322
+
323
+ # Patcher.
324
+ patch_size = ignore_kwargs.get("patch_size", 1)
325
+ self.patcher3d = Patcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
326
+ in_channels = in_channels * patch_size * patch_size * patch_size
327
+
328
+ # calculate the number of downsample operations
329
+ self.num_spatial_downs = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
330
+ assert (
331
+ self.num_spatial_downs <= self.num_resolutions
332
+ ), f"Spatially downsample {self.num_resolutions} times at most"
333
+
334
+ self.num_temporal_downs = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
335
+ assert (
336
+ self.num_temporal_downs <= self.num_resolutions
337
+ ), f"Temporally downsample {self.num_resolutions} times at most"
338
+
339
+ # downsampling
340
+ self.conv_in = nn.Sequential(
341
+ CausalConv3d(in_channels, channels, kernel_size=(1, 3, 3), stride=1, padding=1),
342
+ CausalConv3d(channels, channels, kernel_size=(3, 1, 1), stride=1, padding=0),
343
+ )
344
+
345
+ curr_res = resolution // patch_size
346
+ in_ch_mult = (1,) + tuple(channels_mult)
347
+ self.in_ch_mult = in_ch_mult
348
+ self.down = nn.ModuleList()
349
+ for i_level in range(self.num_resolutions):
350
+ block = nn.ModuleList()
351
+ attn = nn.ModuleList()
352
+ block_in = channels * in_ch_mult[i_level]
353
+ block_out = channels * channels_mult[i_level]
354
+ for _ in range(self.num_res_blocks):
355
+ block.append(
356
+ CausalResnetBlockFactorized3d(
357
+ in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
358
+ )
359
+ )
360
+ block_in = block_out
361
+ if curr_res in attn_resolutions:
362
+ attn.append(
363
+ nn.Sequential(
364
+ CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
365
+ )
366
+ )
367
+ down = nn.Module()
368
+ down.block = block
369
+ down.attn = attn
370
+ if i_level != self.num_resolutions - 1:
371
+ spatial_down = i_level < self.num_spatial_downs
372
+ temporal_down = i_level < self.num_temporal_downs
373
+ down.downsample = CausalHybridDownsample3d(
374
+ block_in, spatial_down=spatial_down, temporal_down=temporal_down
375
+ )
376
+ curr_res = curr_res // 2
377
+ self.down.append(down)
378
+
379
+ # middle
380
+ self.mid = nn.Module()
381
+ self.mid.block_1 = CausalResnetBlockFactorized3d(
382
+ in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
383
+ )
384
+ self.mid.attn_1 = nn.Sequential(
385
+ CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
386
+ )
387
+ self.mid.block_2 = CausalResnetBlockFactorized3d(
388
+ in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
389
+ )
390
+
391
+ # end
392
+ self.norm_out = CausalNormalize(block_in, num_groups=1)
393
+ self.conv_out = nn.Sequential(
394
+ CausalConv3d(block_in, z_channels, kernel_size=(1, 3, 3), stride=1, padding=1),
395
+ CausalConv3d(z_channels, z_channels, kernel_size=(3, 1, 1), stride=1, padding=0),
396
+ )
397
+
398
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
399
+ x = self.patcher3d(x)
400
+
401
+ # downsampling
402
+ h = self.conv_in(x)
403
+ for i_level in range(self.num_resolutions):
404
+ for i_block in range(self.num_res_blocks):
405
+ h = self.down[i_level].block[i_block](h)
406
+ if len(self.down[i_level].attn) > 0:
407
+ h = self.down[i_level].attn[i_block](h)
408
+ if i_level != self.num_resolutions - 1:
409
+ h = self.down[i_level].downsample(h)
410
+
411
+ # middle
412
+ h = self.mid.block_1(h)
413
+ h = self.mid.attn_1(h)
414
+ h = self.mid.block_2(h)
415
+
416
+ # end
417
+ h = self.norm_out(h)
418
+ h = nonlinearity(h)
419
+ h = self.conv_out(h)
420
+ return h
421
+
422
+
423
+ class DecoderFactorized(nn.Module):
424
+ def __init__(
425
+ self,
426
+ out_channels: int,
427
+ channels: int,
428
+ channels_mult: list[int],
429
+ num_res_blocks: int,
430
+ attn_resolutions: list[int],
431
+ dropout: float,
432
+ resolution: int,
433
+ z_channels: int,
434
+ spatial_compression: int,
435
+ temporal_compression: int,
436
+ **ignore_kwargs,
437
+ ):
438
+ super().__init__()
439
+ self.num_resolutions = len(channels_mult)
440
+ self.num_res_blocks = num_res_blocks
441
+
442
+ # UnPatcher.
443
+ patch_size = ignore_kwargs.get("patch_size", 1)
444
+ self.unpatcher3d = UnPatcher3D(patch_size, ignore_kwargs.get("patch_method", "rearrange"))
445
+ out_ch = out_channels * patch_size * patch_size * patch_size
446
+
447
+ # calculate the number of upsample operations
448
+ self.num_spatial_ups = int(math.log2(spatial_compression)) - int(math.log2(patch_size))
449
+ assert self.num_spatial_ups <= self.num_resolutions, f"Spatially upsample {self.num_resolutions} times at most"
450
+ self.num_temporal_ups = int(math.log2(temporal_compression)) - int(math.log2(patch_size))
451
+ assert (
452
+ self.num_temporal_ups <= self.num_resolutions
453
+ ), f"Temporally upsample {self.num_resolutions} times at most"
454
+
455
+ block_in = channels * channels_mult[self.num_resolutions - 1]
456
+ curr_res = (resolution // patch_size) // 2 ** (self.num_resolutions - 1)
457
+ self.z_shape = (1, z_channels, curr_res, curr_res)
458
+ log.debug("Working with z of shape {} = {} dimensions.".format(self.z_shape, np.prod(self.z_shape)))
459
+
460
+ # z to block_in
461
+ self.conv_in = nn.Sequential(
462
+ CausalConv3d(z_channels, block_in, kernel_size=(1, 3, 3), stride=1, padding=1),
463
+ CausalConv3d(block_in, block_in, kernel_size=(3, 1, 1), stride=1, padding=0),
464
+ )
465
+
466
+ # middle
467
+ self.mid = nn.Module()
468
+ self.mid.block_1 = CausalResnetBlockFactorized3d(
469
+ in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
470
+ )
471
+ self.mid.attn_1 = nn.Sequential(
472
+ CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
473
+ )
474
+ self.mid.block_2 = CausalResnetBlockFactorized3d(
475
+ in_channels=block_in, out_channels=block_in, dropout=dropout, num_groups=1
476
+ )
477
+
478
+ legacy_mode = ignore_kwargs.get("legacy_mode", False)
479
+ # upsampling
480
+ self.up = nn.ModuleList()
481
+ for i_level in reversed(range(self.num_resolutions)):
482
+ block = nn.ModuleList()
483
+ attn = nn.ModuleList()
484
+ block_out = channels * channels_mult[i_level]
485
+ for _ in range(self.num_res_blocks + 1):
486
+ block.append(
487
+ CausalResnetBlockFactorized3d(
488
+ in_channels=block_in, out_channels=block_out, dropout=dropout, num_groups=1
489
+ )
490
+ )
491
+ block_in = block_out
492
+ if curr_res in attn_resolutions:
493
+ attn.append(
494
+ nn.Sequential(
495
+ CausalAttnBlock(block_in, num_groups=1), CausalTemporalAttnBlock(block_in, num_groups=1)
496
+ )
497
+ )
498
+ up = nn.Module()
499
+ up.block = block
500
+ up.attn = attn
501
+ if i_level != 0:
502
+ # The layer index for temporal/spatial downsampling performed in the encoder should correspond
503
+                 # to the layer index, in reverse order, where upsampling is performed in the decoder.
504
+                 # If you have a pre-trained model, you can simply fine-tune it.
505
+ # For example:
506
+ # Input tensor = (1, 3, 17, 32, 32)
507
+ # Patch size = 4 for 3D wavelet transform
508
+ # Compression rate = (8x16x16)
509
+ #
510
+ # We expect successive downsampling in the encoder and upsampling in the decoder to be mirrored.
511
+ # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
512
+ # DECODER: `(...,3,2,2) -> (...,3,4,4) -> (...,5,8,8)`
513
+ #
514
+ # if legacy_mode is True, the temporal upsampling is not perfectly mirrored.
515
+ # ENCODER: `(...,5,8,8) -> (...,3,4,4) -> (...,3,2,2)`
516
+ # DECODER: `(...,3,2,2) -> (...,5,4,4) -> (...,5,8,8)`
517
+ #
518
+ # Most of the CV and DV tokenizers were trained before 09/01/2024 with upsampling that's not mirrored.
519
+ # Going forward, new CV/DV tokenizers will adopt `legacy_mode=False`, i.e. use mirrored upsampling.
520
+ i_level_reverse = self.num_resolutions - i_level - 1
521
+ if legacy_mode:
522
+ temporal_up = i_level_reverse < self.num_temporal_ups
523
+ else:
524
+ temporal_up = 0 < i_level_reverse < self.num_temporal_ups + 1
525
+ spatial_up = temporal_up or (
526
+ i_level_reverse < self.num_spatial_ups and self.num_spatial_ups > self.num_temporal_ups
527
+ )
528
+ up.upsample = CausalHybridUpsample3d(block_in, spatial_up=spatial_up, temporal_up=temporal_up)
529
+ curr_res = curr_res * 2
530
+ self.up.insert(0, up) # prepend to get consistent order
531
+
532
+ # end
533
+ self.norm_out = CausalNormalize(block_in, num_groups=1)
534
+ self.conv_out = nn.Sequential(
535
+ CausalConv3d(block_in, out_ch, kernel_size=(1, 3, 3), stride=1, padding=1),
536
+ CausalConv3d(out_ch, out_ch, kernel_size=(3, 1, 1), stride=1, padding=0),
537
+ )
538
+
539
+ def forward(self, z):
540
+ h = self.conv_in(z)
541
+
542
+ # middle block.
543
+ h = self.mid.block_1(h)
544
+ h = self.mid.attn_1(h)
545
+ h = self.mid.block_2(h)
546
+
547
+ # decoder blocks.
548
+ for i_level in reversed(range(self.num_resolutions)):
549
+ for i_block in range(self.num_res_blocks + 1):
550
+ h = self.up[i_level].block[i_block](h)
551
+ if len(self.up[i_level].attn) > 0:
552
+ h = self.up[i_level].attn[i_block](h)
553
+ if i_level != 0:
554
+ h = self.up[i_level].upsample(h)
555
+
556
+ h = self.norm_out(h)
557
+ h = nonlinearity(h)
558
+ h = self.conv_out(h)
559
+ h = self.unpatcher3d(h)
560
+ return h
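The causal convolution that all of the blocks above build on pads only toward the past (the first frame is replicated `time_pad` times), so the temporal length is preserved and frame t never sees later frames. A minimal sanity-check sketch, assuming the flat module layout of this repo is importable:

import torch
from ar_tokenizer_modules import CausalConv3d

conv = CausalConv3d(chan_in=8, chan_out=8, kernel_size=3, padding=1)
x = torch.randn(1, 8, 5, 16, 16)  # (batch, channels, time, height, width)
y = conv(x)
print(y.shape)  # torch.Size([1, 8, 5, 16, 16]); the causal left-pad keeps the time length unchanged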
ar_tokenizer_networks.py ADDED
@@ -0,0 +1,63 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import namedtuple
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+ from .ar_tokenizer_modules import CausalConv3d, DecoderFactorized, EncoderFactorized
22
+ from .ar_tokenizer_quantizers import FSQuantizer
23
+ from .log import log
24
+
25
+ NetworkEval = namedtuple("NetworkEval", ["reconstructions", "quant_loss", "quant_info"])
26
+
27
+
28
+ class CausalDiscreteVideoTokenizer(nn.Module):
29
+ def __init__(self, z_channels: int, z_factor: int, embedding_dim: int, **kwargs) -> None:
30
+ super().__init__()
31
+ self.name = kwargs.get("name", "CausalDiscreteVideoTokenizer")
32
+ self.embedding_dim = embedding_dim
33
+ self.encoder = EncoderFactorized(z_channels=z_factor * z_channels, **kwargs)
34
+ self.decoder = DecoderFactorized(z_channels=z_channels, **kwargs)
35
+
36
+ self.quant_conv = CausalConv3d(z_factor * z_channels, embedding_dim, kernel_size=1, padding=0)
37
+ self.post_quant_conv = CausalConv3d(embedding_dim, z_channels, kernel_size=1, padding=0)
38
+
39
+ self.quantizer = FSQuantizer(**kwargs)
40
+
41
+ num_parameters = sum(param.numel() for param in self.parameters())
42
+ log.debug(f"model={self.name}, num_parameters={num_parameters:,}")
43
+ log.debug(f"z_channels={z_channels}, embedding_dim={self.embedding_dim}.")
44
+
45
+ def to(self, *args, **kwargs):
46
+ setattr(self.quantizer, "dtype", kwargs.get("dtype", torch.bfloat16))
47
+ return super(CausalDiscreteVideoTokenizer, self).to(*args, **kwargs)
48
+
49
+ def encode(self, x):
50
+ h = self.encoder(x)
51
+ h = self.quant_conv(h)
52
+ return self.quantizer(h)
53
+
54
+ def decode(self, quant):
55
+ quant = self.post_quant_conv(quant)
56
+ return self.decoder(quant)
57
+
58
+ def forward(self, input):
59
+ quant_info, quant_codes, quant_loss = self.encode(input)
60
+ reconstructions = self.decode(quant_codes)
61
+ if self.training:
62
+ return dict(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
63
+ return NetworkEval(reconstructions=reconstructions, quant_loss=quant_loss, quant_info=quant_info)
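A brief sketch of how these pieces are typically exercised. Here `tokenizer` stands for an already constructed and checkpoint-loaded CausalDiscreteVideoTokenizer, and `video` is a hypothetical (B, C, T, H, W) float tensor; only the call pattern is shown.

# `tokenizer` and `video` are assumed to exist; this only illustrates the call pattern.
indices, codes, _ = tokenizer.encode(video)  # FSQ indices (discrete tokens) and quantized latents
reconstruction = tokenizer.decode(codes)     # decode back to pixel space from the quantized latents
output = tokenizer(video)                    # in eval mode: NetworkEval(reconstructions, quant_loss, quant_info)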
ar_tokenizer_patching.py ADDED
@@ -0,0 +1,279 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """The patcher and unpatcher implementation for 2D and 3D data."""
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from einops import rearrange
21
+
22
+ _WAVELETS = {
23
+ "haar": torch.tensor([0.7071067811865476, 0.7071067811865476]),
24
+ "rearrange": torch.tensor([1.0, 1.0]),
25
+ }
26
+ _PERSISTENT = False
27
+
28
+
29
+ class Patcher(torch.nn.Module):
30
+ """A module to convert image tensors into patches using torch operations.
31
+
32
+ The main difference from `class Patching` is that this module implements
33
+     all operations using torch, rather than Python or NumPy, for efficiency purposes.
34
+
35
+ It's bit-wise identical to the Patching module outputs, with the added
36
+ benefit of being torch.jit scriptable.
37
+ """
38
+
39
+ def __init__(self, patch_size=1, patch_method="haar"):
40
+ super().__init__()
41
+ self.patch_size = patch_size
42
+ self.patch_method = patch_method
43
+ self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
44
+ self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
45
+ self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT)
46
+ for param in self.parameters():
47
+ param.requires_grad = False
48
+
49
+ def forward(self, x):
50
+ if self.patch_method == "haar":
51
+ return self._haar(x)
52
+ elif self.patch_method == "rearrange":
53
+ return self._arrange(x)
54
+ else:
55
+ raise ValueError("Unknown patch method: " + self.patch_method)
56
+
57
+ def _dwt(self, x, mode="reflect", rescale=False):
58
+ dtype = x.dtype
59
+ h = self.wavelets
60
+
61
+ n = h.shape[0]
62
+ g = x.shape[1]
63
+ hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
64
+ hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
65
+ hh = hh.to(dtype=dtype)
66
+ hl = hl.to(dtype=dtype)
67
+
68
+ x = F.pad(x, pad=(n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
69
+ xl = F.conv2d(x, hl.unsqueeze(2), groups=g, stride=(1, 2))
70
+ xh = F.conv2d(x, hh.unsqueeze(2), groups=g, stride=(1, 2))
71
+ xll = F.conv2d(xl, hl.unsqueeze(3), groups=g, stride=(2, 1))
72
+ xlh = F.conv2d(xl, hh.unsqueeze(3), groups=g, stride=(2, 1))
73
+ xhl = F.conv2d(xh, hl.unsqueeze(3), groups=g, stride=(2, 1))
74
+ xhh = F.conv2d(xh, hh.unsqueeze(3), groups=g, stride=(2, 1))
75
+
76
+ out = torch.cat([xll, xlh, xhl, xhh], dim=1)
77
+ if rescale:
78
+ out = out / 2
79
+ return out
80
+
81
+ def _haar(self, x):
82
+ for _ in self.range:
83
+ x = self._dwt(x, rescale=True)
84
+ return x
85
+
86
+ def _arrange(self, x):
87
+ x = rearrange(x, "b c (h p1) (w p2) -> b (c p1 p2) h w", p1=self.patch_size, p2=self.patch_size).contiguous()
88
+ return x
89
+
90
+
91
+ class Patcher3D(Patcher):
92
+ """A 3D discrete wavelet transform for video data, expects 5D tensor, i.e. a batch of videos."""
93
+
94
+ def __init__(self, patch_size=1, patch_method="haar"):
95
+ super().__init__(patch_method=patch_method, patch_size=patch_size)
96
+ self.register_buffer(
97
+ "patch_size_buffer", patch_size * torch.ones([1], dtype=torch.int32), persistent=_PERSISTENT
98
+ )
99
+
100
+ def _dwt(self, x, mode="reflect", rescale=False):
101
+ dtype = x.dtype
102
+ h = self.wavelets
103
+
104
+ n = h.shape[0]
105
+ g = x.shape[1]
106
+ hl = h.flip(0).reshape(1, 1, -1).repeat(g, 1, 1)
107
+ hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
108
+ hh = hh.to(dtype=dtype)
109
+ hl = hl.to(dtype=dtype)
110
+
111
+ # Handles temporal axis.
112
+ x = F.pad(x, pad=(max(0, n - 2), n - 1, n - 2, n - 1, n - 2, n - 1), mode=mode).to(dtype)
113
+ xl = F.conv3d(x, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
114
+ xh = F.conv3d(x, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
115
+
116
+ # Handles spatial axes.
117
+ xll = F.conv3d(xl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
118
+ xlh = F.conv3d(xl, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
119
+ xhl = F.conv3d(xh, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
120
+ xhh = F.conv3d(xh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
121
+
122
+ xlll = F.conv3d(xll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
123
+ xllh = F.conv3d(xll, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
124
+ xlhl = F.conv3d(xlh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
125
+ xlhh = F.conv3d(xlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
126
+ xhll = F.conv3d(xhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
127
+ xhlh = F.conv3d(xhl, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
128
+ xhhl = F.conv3d(xhh, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
129
+ xhhh = F.conv3d(xhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
130
+
131
+ out = torch.cat([xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh], dim=1)
132
+ if rescale:
133
+ out = out / (2 * torch.sqrt(torch.tensor(2.0)))
134
+ return out
135
+
136
+ def _haar(self, x):
137
+ xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
138
+ x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
139
+ for _ in self.range:
140
+ x = self._dwt(x, rescale=True)
141
+ return x
142
+
143
+ def _arrange(self, x):
144
+ xi, xv = torch.split(x, [1, x.shape[2] - 1], dim=2)
145
+ x = torch.cat([xi.repeat_interleave(self.patch_size, dim=2), xv], dim=2)
146
+ x = rearrange(
147
+ x,
148
+ "b c (t p1) (h p2) (w p3) -> b (c p1 p2 p3) t h w",
149
+ p1=self.patch_size,
150
+ p2=self.patch_size,
151
+ p3=self.patch_size,
152
+ ).contiguous()
153
+ return x
154
+
155
+
156
+ class UnPatcher(torch.nn.Module):
157
+     """A module to convert patches into image tensors using torch operations.
158
+
159
+ The main difference from `class Unpatching` is that this module implements
160
+     all operations using torch, rather than Python or NumPy, for efficiency purposes.
161
+
162
+ It's bit-wise identical to the Unpatching module outputs, with the added
163
+ benefit of being torch.jit scriptable.
164
+ """
165
+
166
+ def __init__(self, patch_size=1, patch_method="haar"):
167
+ super().__init__()
168
+ self.patch_size = patch_size
169
+ self.patch_method = patch_method
170
+ self.register_buffer("wavelets", _WAVELETS[patch_method], persistent=_PERSISTENT)
171
+ self.range = range(int(torch.log2(torch.tensor(self.patch_size)).item()))
172
+ self.register_buffer("_arange", torch.arange(_WAVELETS[patch_method].shape[0]), persistent=_PERSISTENT)
173
+ for param in self.parameters():
174
+ param.requires_grad = False
175
+
176
+ def forward(self, x):
177
+ if self.patch_method == "haar":
178
+ return self._ihaar(x)
179
+ elif self.patch_method == "rearrange":
180
+ return self._iarrange(x)
181
+ else:
182
+ raise ValueError("Unknown patch method: " + self.patch_method)
183
+
184
+ def _idwt(self, x, rescale=False):
185
+ dtype = x.dtype
186
+ h = self.wavelets
187
+ n = h.shape[0]
188
+
189
+ g = x.shape[1] // 4
190
+ hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
191
+ hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
192
+ hh = hh.to(dtype=dtype)
193
+ hl = hl.to(dtype=dtype)
194
+
195
+ xll, xlh, xhl, xhh = torch.chunk(x.to(dtype), 4, dim=1)
196
+
197
+ # Inverse transform.
198
+ yl = torch.nn.functional.conv_transpose2d(xll, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
199
+ yl += torch.nn.functional.conv_transpose2d(xlh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
200
+ yh = torch.nn.functional.conv_transpose2d(xhl, hl.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
201
+ yh += torch.nn.functional.conv_transpose2d(xhh, hh.unsqueeze(3), groups=g, stride=(2, 1), padding=(n - 2, 0))
202
+ y = torch.nn.functional.conv_transpose2d(yl, hl.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
203
+ y += torch.nn.functional.conv_transpose2d(yh, hh.unsqueeze(2), groups=g, stride=(1, 2), padding=(0, n - 2))
204
+
205
+ if rescale:
206
+ y = y * 2
207
+ return y
208
+
209
+ def _ihaar(self, x):
210
+ for _ in self.range:
211
+ x = self._idwt(x, rescale=True)
212
+ return x
213
+
214
+ def _iarrange(self, x):
215
+ x = rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=self.patch_size, p2=self.patch_size)
216
+ return x
217
+
218
+
219
+ class UnPatcher3D(UnPatcher):
220
+ """A 3D inverse discrete wavelet transform for video wavelet decompositions."""
221
+
222
+ def __init__(self, patch_size=1, patch_method="haar"):
223
+ super().__init__(patch_method=patch_method, patch_size=patch_size)
224
+
225
+ def _idwt(self, x, rescale=False):
226
+ dtype = x.dtype
227
+ h = self.wavelets
228
+
229
+         g = x.shape[1] // 8  # split into 8 spatio-temporal filtered tensors.
230
+ hl = h.flip([0]).reshape(1, 1, -1).repeat([g, 1, 1])
231
+ hh = (h * ((-1) ** self._arange)).reshape(1, 1, -1).repeat(g, 1, 1)
232
+ hl = hl.to(dtype=dtype)
233
+ hh = hh.to(dtype=dtype)
234
+
235
+ xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
236
+
237
+         # Handles width transposed convolutions.
238
+ xll = F.conv_transpose3d(xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
239
+ xll += F.conv_transpose3d(xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
240
+
241
+ xlh = F.conv_transpose3d(xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
242
+ xlh += F.conv_transpose3d(xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
243
+
244
+ xhl = F.conv_transpose3d(xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
245
+ xhl += F.conv_transpose3d(xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
246
+
247
+ xhh = F.conv_transpose3d(xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
248
+ xhh += F.conv_transpose3d(xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2))
249
+
250
+         # Handles height transposed convolutions.
251
+ xl = F.conv_transpose3d(xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
252
+ xl += F.conv_transpose3d(xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
253
+ xh = F.conv_transpose3d(xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
254
+ xh += F.conv_transpose3d(xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1))
255
+
256
+ # Handles time axis transposed convolutions.
257
+ x = F.conv_transpose3d(xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
258
+ x += F.conv_transpose3d(xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1))
259
+
260
+ if rescale:
261
+ x = x * (2 * torch.sqrt(torch.tensor(2.0)))
262
+ return x
263
+
264
+ def _ihaar(self, x):
265
+ for _ in self.range:
266
+ x = self._idwt(x, rescale=True)
267
+ x = x[:, :, self.patch_size - 1 :, ...]
268
+ return x
269
+
270
+ def _iarrange(self, x):
271
+ x = rearrange(
272
+ x,
273
+ "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)",
274
+ p1=self.patch_size,
275
+ p2=self.patch_size,
276
+ p3=self.patch_size,
277
+ )
278
+ x = x[:, :, self.patch_size - 1 :, ...]
279
+ return x
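A small round-trip sketch of the 3D Haar patcher/unpatcher above. The sizes are hypothetical; note that the time axis is handled causally (the first frame is replicated before the transform and cropped after the inverse), so 1 + 4*k input frames map to 1 + k patched frames for patch_size=4.

import torch
from ar_tokenizer_patching import Patcher3D, UnPatcher3D

patcher = Patcher3D(patch_size=4, patch_method="haar")
unpatcher = UnPatcher3D(patch_size=4, patch_method="haar")

video = torch.randn(1, 3, 9, 32, 32)  # (B, C, 1 + 4*k frames, H, W) with k = 2
patched = patcher(video)
print(patched.shape)    # torch.Size([1, 192, 3, 8, 8]); channels x 4**3, H and W / 4, frames 1 + 4*k -> 1 + k
restored = unpatcher(patched)
print(restored.shape)   # torch.Size([1, 3, 9, 32, 32]); the original shape is recovered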
ar_tokenizer_quantizers.py ADDED
@@ -0,0 +1,165 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Quantizers for discrete image and video tokenization."""
17
+
18
+ from typing import Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ from einops import rearrange
23
+
24
+ from .ar_tokenizer_utils import default, pack_one, round_ste, unpack_one
25
+
26
+
27
+ class FSQuantizer(nn.Module):
28
+ """Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505
29
+
30
+ Adapted from: https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/
31
+ vector_quantize_pytorch/finite_scalar_quantization.py
32
+ [Copyright (c) 2020 Phil Wang]
33
+ https://github.com/lucidrains/vector-quantize-pytorch/blob/9502a1f447876d53fd37685b226bf28f250dc4a3/LICENSE
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ levels: list[int],
39
+ dim: Optional[int] = None,
40
+ num_codebooks=1,
41
+ keep_num_codebooks_dim: Optional[bool] = None,
42
+ scale: Optional[float] = None,
43
+ **ignore_kwargs,
44
+ ):
45
+ super().__init__()
46
+ self.dtype = ignore_kwargs.get("dtype", torch.float32)
47
+ _levels = torch.tensor(levels, dtype=torch.int32)
48
+ self.register_buffer("_levels", _levels, persistent=False)
49
+
50
+ _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
51
+ self.register_buffer("_basis", _basis, persistent=False)
52
+
53
+ self.scale = scale
54
+
55
+ codebook_dim = len(levels)
56
+ self.codebook_dim = codebook_dim
57
+
58
+ effective_codebook_dim = codebook_dim * num_codebooks
59
+ self.num_codebooks = num_codebooks
60
+ self.effective_codebook_dim = effective_codebook_dim
61
+
62
+ keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1)
63
+ assert not (num_codebooks > 1 and not keep_num_codebooks_dim)
64
+ self.keep_num_codebooks_dim = keep_num_codebooks_dim
65
+
66
+ self.dim = default(dim, len(_levels) * num_codebooks)
67
+
68
+ has_projections = self.dim != effective_codebook_dim
69
+ self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
70
+ self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
71
+ self.has_projections = has_projections
72
+
73
+ self.codebook_size = self._levels.prod().item()
74
+
75
+ implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
76
+ self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
77
+
78
+ def bound(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
79
+ """Bound `z`, an array of shape (..., d)."""
80
+ half_l = (self._levels - 1) * (1 + eps) / 2
81
+ offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
82
+ shift = (offset / half_l).atanh()
83
+ return (z + shift).tanh() * half_l - offset
84
+
85
+ def quantize(self, z: torch.Tensor) -> torch.Tensor:
86
+ """Quantizes z, returns quantized zhat, same shape as z."""
87
+ quantized = round_ste(self.bound(z))
88
+ half_width = self._levels // 2 # Renormalize to [-1, 1].
89
+ return quantized / half_width
90
+
91
+ def _scale_and_shift(self, zhat_normalized: torch.Tensor) -> torch.Tensor:
92
+ half_width = self._levels // 2
93
+ return (zhat_normalized * half_width) + half_width
94
+
95
+ def _scale_and_shift_inverse(self, zhat: torch.Tensor) -> torch.Tensor:
96
+ half_width = self._levels // 2
97
+ return (zhat - half_width) / half_width
98
+
99
+ def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
100
+ """Converts a `code` to an index in the codebook."""
101
+ assert zhat.shape[-1] == self.codebook_dim
102
+ zhat = self._scale_and_shift(zhat).float()
103
+ return (zhat * self._basis).sum(dim=-1).to(torch.int32)
104
+
105
+ def indices_to_codes(self, indices: torch.Tensor, project_out=True) -> torch.Tensor:
106
+ """Inverse of `codes_to_indices`."""
107
+ is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
108
+ indices = rearrange(indices, "... -> ... 1")
109
+ codes_non_centered = (indices // self._basis) % self._levels
110
+ codes = self._scale_and_shift_inverse(codes_non_centered)
111
+
112
+ if self.keep_num_codebooks_dim:
113
+ codes = rearrange(codes, "... c d -> ... (c d)")
114
+
115
+ if project_out:
116
+ codes = self.project_out(codes)
117
+
118
+ if is_img_or_video:
119
+ codes = rearrange(codes, "b ... d -> b d ...")
120
+
121
+ return codes.to(self.dtype)
122
+
123
+ def forward(self, z: torch.Tensor) -> torch.Tensor:
124
+ """
125
+ einstein notation
126
+ b - batch
127
+ n - sequence (or flattened spatial dimensions)
128
+ d - feature dimension, which is also log2(codebook size)
129
+ c - number of codebook dim
130
+ """
131
+ is_img_or_video = z.ndim >= 4
132
+
133
+ # standardize image or video into (batch, seq, dimension)
134
+
135
+ if is_img_or_video:
136
+ z = rearrange(z, "b d ... -> b ... d")
137
+ z, ps = pack_one(z, "b * d")
138
+
139
+ assert z.shape[-1] == self.dim, f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}"
140
+
141
+ z = self.project_in(z)
142
+
143
+ z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks)
144
+
145
+ codes = self.quantize(z)
146
+ indices = self.codes_to_indices(codes)
147
+
148
+ codes = rearrange(codes, "b n c d -> b n (c d)")
149
+
150
+ out = self.project_out(codes)
151
+
152
+ # reconstitute image or video dimensions
153
+
154
+ if is_img_or_video:
155
+ out = unpack_one(out, ps, "b * d")
156
+ out = rearrange(out, "b ... d -> b d ...")
157
+ indices = unpack_one(indices, ps, "b * c")
158
+ dummy_loss = torch.zeros_like(out.mean(dim=[1, 2, 3], keepdim=True))
159
+ else:
160
+ dummy_loss = torch.zeros_like(out.mean(dim=[1, 2], keepdim=True)).unsqueeze(1)
161
+
162
+ if not self.keep_num_codebooks_dim:
163
+ indices = rearrange(indices, "... 1 -> ...")
164
+
165
+ return (indices, out.to(self.dtype), dummy_loss)
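A tiny worked example of the FSQ round trip above. The levels here are hypothetical, chosen so the implicit codebook has 8 * 5 * 5 * 5 = 1000 entries; the released tokenizers configure their own levels.

import torch
from ar_tokenizer_quantizers import FSQuantizer

fsq = FSQuantizer(levels=[8, 5, 5, 5])  # codebook size 8 * 5 * 5 * 5 = 1000
z = torch.randn(2, 16, 4)               # (batch, sequence, dim == len(levels))
indices, quantized, _ = fsq(z)
print(indices.shape)     # torch.Size([2, 16]); integer codes in [0, 1000)
print(quantized.shape)   # torch.Size([2, 16, 4]); each channel snapped to its own small grid in [-1, 1]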
ar_tokenizer_text_tokenizer.py ADDED
@@ -0,0 +1,317 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, List, Optional, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+ from transformers import AutoTokenizer
21
+
22
+ from .log import log
23
+
24
+
25
+ def get_tokenizer_path(model_family: str, is_instruct_model: bool = False):
26
+ """
27
+ Get the tokenizer path from the model family and instruct model flag.
28
+ Args:
29
+ model_family (str): The model family.
30
+ is_instruct_model (bool): Whether the model is an instruct model.
31
+ Returns:
32
+         str: The HuggingFace tokenizer path (hub repo ID).
33
+ """
34
+ model_family = model_family.lower()
35
+ if model_family == "mistral":
36
+ return "mistralai/Mistral-Nemo-Instruct-2407"
37
+ else:
38
+ assert model_family in ["llama3", "llama3.1"]
39
+ if model_family == "llama3":
40
+ model_path = "meta-llama/Meta-Llama-3-8B"
41
+ elif model_family == "llama3.1":
42
+ model_path = "meta-llama/Llama-3.1-8B"
43
+ else:
44
+ raise ValueError(f"Unsupported model family: {model_family}")
45
+ suffix = "-Instruct" if is_instruct_model else ""
46
+ model_path = f"{model_path}{suffix}"
47
+ return model_path
48
+
49
+
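For reference, a hypothetical call pattern for the helper above; the returned repo IDs follow directly from its branches.

get_tokenizer_path("llama3.1", is_instruct_model=True)  # -> "meta-llama/Llama-3.1-8B-Instruct"
get_tokenizer_path("mistral")                           # -> "mistralai/Mistral-Nemo-Instruct-2407"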
50
+ class TextTokenizer:
51
+ """
52
+ Text tokenizer class built on HuggingFace's Fast Tokenizer (Rust based).
53
+ """
54
+
55
+ def __init__(
56
+ self,
57
+ model_family: str,
58
+ is_instruct_model: bool,
59
+ local_path: Optional[str] = None,
60
+ ):
61
+ """
62
+ Initialize the TextTokenizer.
63
+ Args:
64
+ model_family (str): The model family.
65
+ is_instruct_model (bool): Whether the model is an instruct model.
66
+ local_path (Optional[str]): The local path to the tokenizer. If not provided, the tokenizer will be downloaded from the remote path.
67
+ """
68
+ if local_path is None:
69
+ tokenizer_path = get_tokenizer_path(model_family, is_instruct_model)
70
+ else:
71
+ tokenizer_path = local_path
72
+
73
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
74
+ self.stop_tokens = {
75
+ self.tokenizer.eos_token_id,
76
+ }
77
+ self.model_family = model_family
78
+ self.is_instruct_model = is_instruct_model
79
+ self.eos_id = self.tokenizer.eos_token_id
80
+ if self.tokenizer.pad_token is None:
81
+ if model_family.startswith("llama"):
82
+ self.pad_id = 128004 # "<|finetune_right_pad_id|>"
83
+ elif model_family == "mistral":
84
+ self.pad_id = 10 # "<pad>"
85
+ elif model_family == "pixtral":
86
+ self.pad_id = 11 # "<pad>"
87
+ else:
88
+ raise ValueError(f"pad_id not defined for model_family {model_family}")
89
+ else:
90
+ self.pad_id = self.tokenizer.pad_token_id
91
+
92
+ def tokenize(self, text: str, *, add_special_tokens: bool = False, **kwargs) -> List[str]:
93
+ """
94
+ Converts a string into a sequence of tokens, replacing unknown tokens with the `unk_token`.
95
+
96
+ Args:
97
+ text (`str`):
98
+ The sequence to be encoded.
99
+ add_special_tokens (`bool`, *optional*, defaults to `False`):
100
+ Whether or not to add the special tokens associated with the corresponding model.
101
+ Returns:
102
+ `List[str]`: The list of tokens.
103
+ """
104
+ return self.tokenizer.tokenize(text, add_special_tokens=add_special_tokens, **kwargs)
105
+
106
+ def encode(
107
+ self,
108
+ text: Union[str, List[str], List[int]],
109
+ *, # Enforce keyword-only arguments
110
+ add_special_tokens: bool = True,
111
+ padding: Union[bool, str] = False,
112
+ truncation: Union[bool, str] = None,
113
+ max_length: Optional[int] = None,
114
+ stride: int = 0,
115
+ return_tensors: Optional[str] = None,
116
+ **kwargs,
117
+ ) -> List[int]:
118
+ """
119
+ Converts a string to a sequence of ids (integer), using the tokenizer and vocabulary.
120
+
121
+ Args:
122
+ text (`str`, `List[str]` or `List[int]`):
123
+ The first sequence to be encoded. This can be a string, a list of strings (tokenized string using the
124
+ `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
125
+ method).
126
+ add_special_tokens (`bool`, *optional*, defaults to `True`):
127
+ Whether or not to add special tokens when encoding the sequences. This will use the underlying
128
+ `PretrainedTokenizerBase.build_inputs_with_special_tokens` function, which defines which tokens are
129
+ automatically added to the input ids. This is useful if you want to add `bos` or `eos` tokens
130
+ automatically.
131
+ padding (`bool`, `str`, *optional*, defaults to `False`):
132
+ Activates and controls padding. Accepts the following values:
133
+
134
+ - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
135
+ sequence is provided).
136
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
137
+ acceptable input length for the model if that argument is not provided.
138
+ - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
139
+ lengths).
140
+ truncation (`bool`, `str`, *optional*, defaults to `False`):
141
+ Activates and controls truncation. Accepts the following values:
142
+
143
+ - `True` or `'longest_first'`: Truncate to a maximum length specified with the argument `max_length` or
144
+ to the maximum acceptable input length for the model if that argument is not provided. This will
145
+ truncate token by token, removing a token from the longest sequence in the pair if a pair of
146
+ sequences (or a batch of pairs) is provided.
147
+ - `'only_first'`: Truncate to a maximum length specified with the argument `max_length` or to the
148
+ maximum acceptable input length for the model if that argument is not provided. This will only
149
+ truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
150
+ - `'only_second'`: Truncate to a maximum length specified with the argument `max_length` or to the
151
+ maximum acceptable input length for the model if that argument is not provided. This will only
152
+ truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided.
153
+ - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
154
+ greater than the model maximum admissible input size).
155
+ max_length (`int`, *optional*):
156
+ Controls the maximum length to use by one of the truncation/padding parameters.
157
+
158
+ If left unset or set to `None`, this will use the predefined model maximum length if a maximum length
159
+ is required by one of the truncation/padding parameters. If the model has no specific maximum input
160
+ length (like XLNet) truncation/padding to a maximum length will be deactivated.
161
+ stride (`int`, *optional*, defaults to 0):
162
+ If set to a number along with `max_length`, the overflowing tokens returned when
163
+ `return_overflowing_tokens=True` will contain some tokens from the end of the truncated sequence
164
+ returned to provide some overlap between truncated and overflowing sequences. The value of this
165
+ argument defines the number of overlapping tokens.
166
+ is_split_into_words (`bool`, *optional*, defaults to `False`):
167
+ Whether or not the input is already pre-tokenized (e.g., split into words). If set to `True`, the
168
+ tokenizer assumes the input is already split into words (for instance, by splitting it on whitespace)
169
+ which it will tokenize. This is useful for NER or token classification.
170
+ pad_to_multiple_of (`int`, *optional*):
171
+ If set will pad the sequence to a multiple of the provided value. Requires `padding` to be activated.
172
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
173
+ `>= 7.5` (Volta).
174
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
175
+ If set, will return tensors instead of list of python integers. Acceptable values are:
176
+
177
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
178
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
179
+ - `'np'`: Return Numpy `np.ndarray` objects.
180
+ """
181
+ return self.tokenizer.encode(
182
+ text,
183
+ add_special_tokens=add_special_tokens,
184
+ padding=padding,
185
+ truncation=truncation,
186
+ max_length=max_length,
187
+ stride=stride,
188
+ return_tensors=return_tensors,
189
+ )
190
+
191
+ def decode(
192
+ self,
193
+ token_ids: Union[int, List[int], "np.ndarray", "torch.Tensor"],
194
+ *, # Enforce keyword-only arguments
195
+ skip_special_tokens: bool = False,
196
+ clean_up_tokenization_spaces: bool = None,
197
+ **kwargs,
198
+ ) -> str:
199
+ """
200
+ Converts a sequence of ids in a string, using the tokenizer and vocabulary with options to remove special
201
+ tokens and clean up tokenization spaces.
202
+
203
+ Args:
204
+ token_ids (`Union[int, List[int], np.ndarray, torch.Tensor, tf.Tensor]`):
205
+ List of tokenized input ids. Can be obtained using the `__call__` method.
206
+ skip_special_tokens (`bool`, *optional*, defaults to `False`):
207
+ Whether or not to remove special tokens in the decoding.
208
+ clean_up_tokenization_spaces (`bool`, *optional*):
209
+ Whether or not to clean up the tokenization spaces. If `None`, will default to
210
+ `self.clean_up_tokenization_spaces`.
211
+ kwargs (additional keyword arguments, *optional*):
212
+ Will be passed to the underlying model specific decode method.
213
+
214
+ Returns:
215
+ `str`: The decoded sentence.
216
+ """
217
+ return self.tokenizer.decode(
218
+ token_ids,
219
+ skip_special_tokens=skip_special_tokens,
220
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces,
221
+ **kwargs,
222
+ )
223
+
224
+ def apply_chat_template(
225
+ self,
226
+ conversation: Union[List[Dict[str, str]], List[List[Dict[str, str]]]],
227
+ *,
228
+ add_generation_prompt: bool = False,
229
+ tokenize: bool = True,
230
+ padding: bool = False,
231
+ truncation: bool = False,
232
+ max_length: Optional[int] = None,
233
+ return_tensors: Optional[str] = None,
234
+ return_dict: bool = False,
235
+ return_assistant_tokens_mask: bool = False,
236
+ generation_prefix: str = "",
237
+ tokenizer_kwargs: Optional[Dict[str, Any]] = None,
238
+ **kwargs,
239
+ ):
240
+ """
241
+ Converts a list of dictionaries with `"role"` and `"content"` keys to a list of token
242
+ ids. This method is intended for use with chat models, and will read the tokenizer's chat_template attribute to determine the format and control tokens to use when converting.
243
+
244
+ More details can be found at https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template
245
+
246
+ Args:
247
+ conversation (Union[List[Dict[str, str]], List[List[Dict[str, str]]]]): A list of dicts
248
+ with "role" and "content" keys, representing the chat history so far.
249
+ add_generation_prompt (bool, *optional*):
250
+ If this is set, a prompt with the token(s) that indicate
251
+ the start of an assistant message will be appended to the formatted output. This is useful when you want to generate a response from the model.
252
+ Note that this argument will be passed to the chat template, and so it must be supported in the
253
+ template for this argument to have any effect.
254
+ continue_final_message (bool, *optional*):
255
+ If this is set, the chat will be formatted so that the final
256
+ message in the chat is open-ended, without any EOS tokens. The model will continue this message
257
+ rather than starting a new one. This allows you to "prefill" part of
258
+ the model's response for it. Cannot be used at the same time as `add_generation_prompt`.
259
+ tokenize (`bool`, defaults to `True`):
260
+ Whether to tokenize the output. If `False`, the output will be a string.
261
+ padding (`bool`, defaults to `False`):
262
+ Whether to pad sequences to the maximum length. Has no effect if tokenize is `False`.
263
+ truncation (`bool`, defaults to `False`):
264
+ Whether to truncate sequences at the maximum length. Has no effect if tokenize is `False`.
265
+ max_length (`int`, *optional*):
266
+ Maximum length (in tokens) to use for padding or truncation. Has no effect if tokenize is `False`. If
267
+ not specified, the tokenizer's `max_length` attribute will be used as a default.
268
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
269
+ If set, will return tensors of a particular framework. Has no effect if tokenize is `False`. Acceptable
270
+ values are:
271
+ - `'tf'`: Return TensorFlow `tf.Tensor` objects.
272
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
273
+ - `'np'`: Return NumPy `np.ndarray` objects.
274
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
275
+ return_dict (`bool`, defaults to `False`):
276
+ Whether to return a dictionary with named outputs. Has no effect if tokenize is `False`.
277
+ generation_prefix (str): Prefix to add before asking model to generate. Helpful to guide the generation. Defaults to "".
278
+ tokenizer_kwargs (`Dict[str: Any]`, *optional*): Additional kwargs to pass to the tokenizer.
279
+ return_assistant_tokens_mask (`bool`, defaults to `False`):
280
+ Whether to return a mask of the assistant generated tokens. For tokens generated by the assistant,
281
+ the mask will contain 1. For user and system tokens, the mask will contain 0.
282
+ This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
283
+ **kwargs: Additional kwargs to pass to the template renderer. Will be accessible by the chat template.
284
+
285
+ Returns:
286
+ `Union[List[int], Dict]`: A list of token ids representing the tokenized chat so far, including control tokens. This
287
+ output is ready to pass to the model, either directly or via methods like `generate()`. If `return_dict` is
288
+ set, will return a dict of tokenizer outputs instead.
289
+ """
290
+ if not self.is_instruct_model:
291
+ raise ValueError(
292
+ "apply_chat_template is only supported for instruct models. You should pass argument is_instruct_model=True to the TextTokenizer constructor."
293
+ )
294
+ # Since generation_prefix is added to the text in the end, ensure that the setting is correct
295
+ if generation_prefix:
296
+ assert not tokenize, "tokenize must be False when generation_prefix is provided."
297
+ assert add_generation_prompt, "add_generation_prompt must be set when generation_prefix is provided."
298
+ formatted_text: Union[str, List[int]] = self.tokenizer.apply_chat_template(
299
+ conversation,
300
+ add_generation_prompt=add_generation_prompt,
301
+ tokenize=tokenize,
302
+ padding=padding,
303
+ truncation=truncation,
304
+ max_length=max_length,
305
+ return_tensors=return_tensors,
306
+ return_dict=return_dict,
307
+ return_assistant_tokens_mask=return_assistant_tokens_mask,
308
+ tokenizer_kwargs=tokenizer_kwargs,
309
+ **kwargs,
310
+ )
311
+ if generation_prefix:
312
+ formatted_text: str = formatted_text + generation_prefix
313
+ log.debug(
314
+ f"Adding generation prefix: {generation_prefix} to the formatted text\n"
315
+ f"Formatted text: {formatted_text}"
316
+ )
317
+ return formatted_text
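
A brief usage sketch of the TextTokenizer wrapper above (editorial illustration, not part of the commit). It assumes `transformers` is installed and the referenced checkpoint is reachable (or cached locally); the prompt strings are made up.

# Hedged sketch: encode/decode round trip and chat formatting with the TextTokenizer above.
tokenizer = TextTokenizer(model_family="mistral", is_instruct_model=True)

# Plain text -> token ids -> text.
ids = tokenizer.encode("A robot arm stacks three boxes.", add_special_tokens=True)
text = tokenizer.decode(ids, skip_special_tokens=True)

# Chat-style formatting for instruct models. With tokenize=False the formatted prompt is
# returned as a string, so a generation_prefix can be appended to steer the reply.
conversation = [{"role": "user", "content": "Describe the next camera frame."}]
formatted = tokenizer.apply_chat_template(
    conversation,
    add_generation_prompt=True,
    tokenize=False,
    generation_prefix=" The scene shows",
)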
ar_tokenizer_tokenizer.py ADDED
@@ -0,0 +1,322 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import defaultdict
17
+ from typing import Optional
18
+
19
+ import torch
20
+ from einops import rearrange
21
+
22
+ from .ar_config_base_tokenizer import TokenizerConfig
23
+ from .lazy_config_init import instantiate as lazy_instantiate
24
+
25
+
26
+ def update_vocab_size(
27
+ existing_vocab_size,
28
+ to_be_added_vocab_size,
29
+ training_type,
30
+ add_special_tokens,
31
+ video_special_tokens={},
32
+ ):
33
+ # New vocab size
34
+ if add_special_tokens:
35
+ existing_vocab_size += to_be_added_vocab_size + len(video_special_tokens)
36
+ # For text_to_video, we add one <bov> special token at the beginning of the video
37
+ elif training_type == "text_to_video":
38
+ existing_vocab_size += to_be_added_vocab_size + 1
39
+ else:
40
+ existing_vocab_size += to_be_added_vocab_size
41
+ return existing_vocab_size
42
+
43
+
44
+ class DiscreteMultimodalTokenizer:
45
+ def __init__(self, tokenizer_config: TokenizerConfig):
46
+ self.tokenizer_config = tokenizer_config
47
+ self.vocab_size = 0
48
+ self.total_seq_len = tokenizer_config.seq_len
49
+ self.pad_to_multiple_of = tokenizer_config.pad_to_multiple_of
50
+ self.training_type = tokenizer_config.training_type
51
+ assert self.training_type in [
52
+ "text_only",
53
+ "text_to_video",
54
+ "video_to_video",
55
+ "image_text_interleaved",
56
+ ], f"{self.training_type} not supported"
57
+
58
+ self._build_text_tokenizer()
59
+ self._build_video_tokenizer()
60
+
61
+ def _build_text_tokenizer(self):
62
+ r"""Function to initialize the text tokenizer model."""
63
+ if self.tokenizer_config.text_tokenizer is not None:
64
+ self.text_tokenizer = lazy_instantiate(self.tokenizer_config.text_tokenizer.config)
65
+ self.vocab_size += self.tokenizer_config.text_tokenizer.vocab_size
66
+ else:
67
+ self.text_tokenizer = None
68
+
69
+ def _build_video_tokenizer(self):
70
+ r"""Function to initialize the video tokenizer model."""
71
+ if self.tokenizer_config.video_tokenizer is not None:
72
+ self.video_tokenizer = lazy_instantiate(self.tokenizer_config.video_tokenizer.config)
73
+ self.video_tokenizer = self.video_tokenizer.to("cuda")
74
+ self.video_vocab_size = self.tokenizer_config.video_tokenizer.vocab_size
75
+ special_token_offset = (
76
+ self.tokenizer_config.video_tokenizer.tokenizer_offset
77
+ + self.tokenizer_config.video_tokenizer.vocab_size
78
+ )
79
+ self.video_special_tokens = {
80
+ "<|begin_of_video|>": special_token_offset,
81
+ "<|end_of_video|>": special_token_offset + 1,
82
+ "<|pad_token_video|>": special_token_offset + 2,
83
+ }
84
+
85
+ self.vocab_size = update_vocab_size(
86
+ existing_vocab_size=self.vocab_size,
87
+ to_be_added_vocab_size=self.tokenizer_config.video_tokenizer.vocab_size,
88
+ training_type=self.training_type,
89
+ add_special_tokens=self.tokenizer_config.add_special_tokens,
90
+ video_special_tokens=self.video_special_tokens,
91
+ )
92
+ else:
93
+ self.video_tokenizer = None
94
+
95
+ @property
96
+ def pad_id(self):
97
+ r"""Returns the pad_id."""
98
+
99
+ if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
100
+ pad_id = self.text_tokenizer.pad_id
101
+ elif self.training_type in ["text_to_video", "video_to_video"]:
102
+ pad_id = self.video_special_tokens["<|pad_token_video|>"]
103
+ else:
104
+ raise ValueError(f"training_type {self.training_type} not defined")
105
+ return pad_id
106
+
107
+ @property
108
+ def ignore_index(self):
109
+ r"""Returns which token should be ignored during loss computation."""
110
+ if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
111
+ if self.text_tokenizer.pad_id == self.text_tokenizer.eos_id:
112
+ # If the PAD token is the same as the EOS token, we do not ignore it during loss
113
+ # computation, since we want the model to be able to predict EOS tokens in inference.
114
+ # The PyTorch default ignore_index for the cross-entropy loss is -100.
115
+ ignore_index = -100
116
+ else:
117
+ ignore_index = self.text_tokenizer.pad_id
118
+ elif self.training_type in ["text_to_video", "video_to_video"]:
119
+ ignore_index = self.pad_id
120
+ else:
121
+ raise ValueError(f"training_type {self.training_type} not defined")
122
+ return ignore_index
123
+
124
+ @property
125
+ def stop_tokens(self):
126
+ r"""Returns the stop tokens."""
127
+ if self.training_type == "text_only" or self.training_type == "image_text_interleaved":
128
+ stop_tokens = self.text_tokenizer.stop_tokens
129
+ elif self.training_type in ["text_to_video", "video_to_video"]:
130
+ stop_tokens = set([self.video_special_tokens["<|end_of_video|>"]])
131
+ else:
132
+ raise ValueError(f"training_type {self.training_type} not defined")
133
+ return stop_tokens
134
+
135
+ def _tokenize_text(self, raw_text: list[str], max_text_seq_len: int = -1):
136
+ r"""Function to tokenize text.
137
+ Args:
138
+ raw_text (list[str]): List of input strings
139
+ max_text_seq_len (int): Maximum sequence length returned by text tokenizer
140
+ Returns:
141
+ text_tokens (list[list[int]]): List of text tokens
142
+ """
143
+
144
+ batch_size = len(raw_text)
145
+ text_tokens = [self.text_tokenizer.encode(raw_text[i], bos=True, eos=True) for i in range(batch_size)]
146
+
147
+ # Clipping the text tokens so that the sequence length does not exceed max_text_seq_len
148
+ if max_text_seq_len > -1:
149
+ for i in range(len(text_tokens)):
150
+ if len(text_tokens[i]) > max_text_seq_len:
151
+ # Simply clip and add end of seq token
152
+ text_tokens[i] = text_tokens[i][0 : max_text_seq_len - 1] + [self.text_tokenizer.eos_id]
153
+ return text_tokens
154
+
155
+ def _tokenize_class(self, cls_labels: list[str]):
156
+ r"""Function to tokenize the class label.
157
+ Args:
158
+ cls_labels (list[str]): List of class indices
159
+ Returns:
160
+ class_tokens (list[list[int]]): List of class tokens
161
+ """
162
+
163
+ # tokenizer_offset tells what offset should be added to the tokens.
164
+ # This is needed for vocab expansion.
165
+ class_tokens = [[int(x) + self.tokenizer_config.class_tokenizer.tokenizer_offset] for x in cls_labels]
166
+
167
+ return class_tokens
168
+
169
+ def _tokenize_video(self, videos: torch.Tensor, pixel_chunk_duration: Optional[int] = None):
170
+ r"""Function to tokenize video.
171
+ Args:
172
+ videos (torch.Tensor): Input video data tensor
173
+ pixel_chunk_duration (Optional[float]): Pixel chunk duration. If provided, we pass it to the video tokenizer.
174
+ Returns:
175
+ video_tokens (list[list[int]]): List of video tokens
176
+ """
177
+
178
+ video_tokens = []
179
+ batch_size = videos.shape[0]
180
+
181
+ quantized_out, _ = self.video_tokenizer.encode(videos, pixel_chunk_duration=pixel_chunk_duration)
182
+ indices = self.video_tokenizer.fsq_quantizer.codes_to_indices(quantized_out.permute(0, 2, 3, 4, 1))
183
+
184
+ # Flatten the indices
185
+ indices = rearrange(indices, "B T H W -> B (T H W)")
186
+
187
+ # tokenizer_offset tells what offset should be added to the tokens.
188
+ # This is needed for vocab expansion.
189
+ indices += self.tokenizer_config.video_tokenizer.tokenizer_offset
190
+
191
+ # Add begin and end of video tokens
192
+ bov_token = self.video_special_tokens["<|begin_of_video|>"]
193
+ eov_token = self.video_special_tokens["<|end_of_video|>"]
194
+
195
+ # Append bov and eov tokens
196
+ if self.tokenizer_config.add_special_tokens:
197
+ for i in range(batch_size):
198
+ video_tokens.append([bov_token] + indices[i].tolist() + [eov_token])
199
+ else:
200
+ if self.training_type == "text_to_video":
201
+ for i in range(batch_size):
202
+ video_tokens.append([bov_token] + indices[i].tolist())
203
+ else:
204
+ for i in range(batch_size):
205
+ video_tokens.append(indices[i].tolist())
206
+ assert (
207
+ len(video_tokens[-1]) == self.tokenizer_config.video_tokenizer.max_seq_len
208
+ ), f"Expected {self.tokenizer_config.video_tokenizer.max_seq_len} tokens, got {len(video_tokens[-1])}; video shape: {videos.shape}"
209
+
210
+ return video_tokens
211
+
212
+ def tokenize(self, data_batch: dict):
213
+ r"""Function to tokenize data_dict.
214
+ Args:
215
+ data_batch (dict): Input data dict
216
+ Returns:
217
+ tokens (torch.LongTensor): Token tensor dict
218
+ """
219
+
220
+ if (
221
+ self.training_type in ["text_only", "image_text_interleaved"]
222
+ and not self.tokenizer_config.text_tokenizer.tokenize_here
223
+ ):
224
+ # In case of pre-computed tokens, just return the data_batch
225
+ return data_batch["tokens"], None
226
+
227
+ # Online tokenization
228
+ tokens = []
229
+ token_boundaries = defaultdict(list)
230
+
231
+ # Obtain maximum sequence length
232
+ max_text_seq_len = -1
233
+ max_visual_seq_len = -1
234
+
235
+ if self.training_type in ["text_to_video", "video_to_video"]:
236
+ max_visual_seq_len = self.tokenizer_config.video_tokenizer.max_seq_len
237
+
238
+ # If max visual sequence length is specified, make sure that text is clipped so that
239
+ # the full video/image is always seen.
240
+ if max_visual_seq_len > -1:
241
+ if self.tokenizer_config.add_special_tokens:
242
+ max_visual_seq_len = max_visual_seq_len + 2  # The two extra special tokens are the [bov, eov] or [boi, eoi] pair
243
+ elif self.training_type == "text_to_video":
244
+ max_visual_seq_len = max_visual_seq_len + 1
245
+ else:
246
+ max_visual_seq_len = max_visual_seq_len
247
+ assert (
248
+ max_visual_seq_len <= self.total_seq_len
249
+ ), f"max_visual_seq_len ({max_visual_seq_len}) is greater than the total sequence length ({self.total_seq_len})"
250
+ max_text_seq_len = self.total_seq_len - max_visual_seq_len
251
+
252
+ # Tokenize the text
253
+ if (
254
+ "text" in self.training_type
255
+ and self.text_tokenizer is not None
256
+ and self.tokenizer_config.text_tokenizer.tokenize_here
257
+ ):
258
+ key = self.tokenizer_config.text_tokenizer.data_key
259
+ batch_size = len(data_batch[key])
260
+ assert key in data_batch, f"Key {key} should be present in data for text tokenizer"
261
+ tokens = self._tokenize_text(data_batch["caption"], max_text_seq_len)
262
+
263
+ for i in range(batch_size):
264
+ token_boundaries["text"].append((0, len(tokens[i])))
265
+ else:
266
+ tokens = []
267
+ batch_size = None
268
+
269
+ # Tokenize the class label
270
+ if "class" in self.training_type and self.tokenizer_config.class_tokenizer is not None:
271
+ key = self.tokenizer_config.class_tokenizer.data_key
272
+ assert key in data_batch, f"Key {key} should be present in data for class tokenizer"
273
+ batch_size = len(data_batch[key]) if batch_size is None else batch_size
274
+ tokens_class = self._tokenize_class(data_batch[key])
275
+ if len(tokens) == 0:
276
+ tokens = tokens_class
277
+ for i in range(batch_size):
278
+ token_boundaries["class"].append((0, len(tokens[i])))
279
+ else:
280
+ for i in range(batch_size):
281
+ token_boundaries["class"].append((len(tokens[i]), len(tokens[i]) + len(tokens_class[i])))
282
+ tokens[i] = tokens[i] + tokens_class[i]
283
+
284
+ # Tokenize the video
285
+ if self.video_tokenizer is not None and self.tokenizer_config.video_tokenizer.tokenize_here:
286
+ key = self.tokenizer_config.video_tokenizer.data_key
287
+ assert key in data_batch, f"Key {key} should be present in data for video tokenizer"
288
+ batch_size = len(data_batch[key]) if batch_size is None else batch_size
289
+
290
+ pixel_chunk_duration = (
291
+ None # If not specified, we assume it's a video dataset and use the default chunk duration
292
+ )
293
+ dataset_name = data_batch.get("dataset_name", None)
294
+ if dataset_name is not None and dataset_name.startswith("image"):
295
+ # If it's an image dataset, we use a pixel chunk duration of 1
296
+ pixel_chunk_duration = 1
297
+ tokens_video = self._tokenize_video(data_batch[key], pixel_chunk_duration=pixel_chunk_duration)
298
+ if len(tokens) == 0:
299
+ tokens = tokens_video
300
+ for i in range(batch_size):
301
+ token_boundaries["video"].append((0, len(tokens[i])))
302
+ # [B,] each entry is ((0, len(tokens[i])))
303
+ else:
304
+ for i in range(batch_size):
305
+ token_boundaries["video"].append((len(tokens[i]), len(tokens[i]) + len(tokens_video[i])))
306
+ tokens[i] = tokens[i] + tokens_video[i]
307
+
308
+ # Combine the tokens and do padding
309
+ max_seq_len_in_batch = max([len(token) for token in tokens])
310
+ if self.pad_to_multiple_of is not None:
311
+ # Pad the sequence length to the nearest multiple of pad_to_multiple_of
312
+ max_seq_len_in_batch = ((max_seq_len_in_batch - 1) // self.pad_to_multiple_of + 1) * self.pad_to_multiple_of
313
+ pad_to_len = min(max_seq_len_in_batch, self.total_seq_len)
314
+ for i in range(len(tokens)):
315
+ if len(tokens[i]) < pad_to_len:
316
+ tokens[i] = tokens[i] + [self.pad_id] * (pad_to_len - len(tokens[i]))
317
+ else:
318
+ tokens[i] = tokens[i][0:pad_to_len]
319
+
320
+ # Convert it to long tensor
321
+ tokens = torch.LongTensor(tokens)
322
+ return tokens, token_boundaries
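
A standalone sketch of the padding arithmetic used at the end of tokenize() above: the longest sequence in the batch is rounded up to a multiple of pad_to_multiple_of and then capped at the model's total sequence length. The helper name and the numbers below are illustrative, not part of the codebase.

# Hypothetical helper mirroring the in-line arithmetic in tokenize().
def compute_pad_length(max_seq_len_in_batch: int, pad_to_multiple_of: int, total_seq_len: int) -> int:
    # Round up to the nearest multiple, then cap at the model's context length.
    rounded = ((max_seq_len_in_batch - 1) // pad_to_multiple_of + 1) * pad_to_multiple_of
    return min(rounded, total_seq_len)

assert compute_pad_length(1021, pad_to_multiple_of=64, total_seq_len=12864) == 1024
assert compute_pad_length(12900, pad_to_multiple_of=64, total_seq_len=12864) == 12864  # capped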
ar_tokenizer_utils.py ADDED
@@ -0,0 +1,101 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any
17
+
18
+ import torch
19
+ from einops import pack, rearrange, unpack
20
+
21
+
22
+ def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
23
+ batch_size = x.shape[0]
24
+ return rearrange(x, "b c t h w -> (b t) c h w"), batch_size
25
+
26
+
27
+ def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
28
+ return rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
29
+
30
+
31
+ def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
32
+ batch_size, height = x.shape[0], x.shape[-2]
33
+ return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
34
+
35
+
36
+ def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
37
+ return rearrange(x, "(b h w) c t -> b c t h w", b=batch_size, h=height)
38
+
39
+
40
+ def cast_tuple(t: Any, length: int = 1) -> Any:
41
+ return t if isinstance(t, tuple) else ((t,) * length)
42
+
43
+
44
+ def replication_pad(x):
45
+ return torch.cat([x[:, :, :1, ...], x], dim=2)
46
+
47
+
48
+ def divisible_by(num: int, den: int) -> bool:
49
+ return (num % den) == 0
50
+
51
+
52
+ def is_odd(n: int) -> bool:
53
+ return not divisible_by(n, 2)
54
+
55
+
56
+ def nonlinearity(x):
57
+ return x * torch.sigmoid(x)
58
+
59
+
60
+ class CausalNormalize(torch.nn.Module):
61
+ def __init__(self, in_channels, num_groups=1):
62
+ super().__init__()
63
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
64
+ self.num_groups = num_groups
65
+
66
+ def forward(self, x):
67
+ # If num_groups != 1, we apply a spatio-temporal group norm for backward-compatibility purposes.
68
+ # All new models should use num_groups=1, otherwise causality is not guaranteed.
69
+ if self.num_groups == 1:
70
+ x, batch_size = time2batch(x)
71
+ return batch2time(self.norm(x), batch_size)
72
+ return self.norm(x)
73
+
74
+
75
+ def exists(v):
76
+ return v is not None
77
+
78
+
79
+ def default(*args):
80
+ for arg in args:
81
+ if exists(arg):
82
+ return arg
83
+ return None
84
+
85
+
86
+ def pack_one(t, pattern):
87
+ return pack([t], pattern)
88
+
89
+
90
+ def unpack_one(t, ps, pattern):
91
+ return unpack(t, ps, pattern)[0]
92
+
93
+
94
+ def round_ste(z: torch.Tensor) -> torch.Tensor:
95
+ """Round with straight through gradients."""
96
+ zhat = z.round()
97
+ return z + (zhat - z).detach()
98
+
99
+
100
+ def log(t, eps=1e-5):
101
+ return t.clamp(min=eps).log()
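
Two small sketches of the helpers above, assuming they are importable from this module: round_ste rounds in the forward pass while letting gradients flow as if it were the identity, and time2batch / batch2time fold the temporal axis into the batch so per-frame 2D ops can be applied.

import torch

# Straight-through rounding: forward gives integers, backward behaves like the identity.
z = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
zhat = round_ste(z)            # tensor([0., 2., -1.])
zhat.sum().backward()
print(z.grad)                  # tensor([1., 1., 1.]) -- not the zero gradient of round()

# time2batch / batch2time round trip on a (B, C, T, H, W) video tensor.
video = torch.randn(2, 3, 5, 8, 8)
frames, b = time2batch(video)  # shape (B*T, C, H, W) == (10, 3, 8, 8)
assert torch.equal(batch2time(frames, b), video)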
ar_utils_checkpoint.py ADDED
@@ -0,0 +1,76 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict, Optional
17
+
18
+ import torch
19
+
20
+ # Substrings to ignore when processing state dicts
21
+ substrings_to_ignore = [
22
+ "_extra_state", # Extra states (BytesIO type) added by TransformerEngine for FP8 handling
23
+ ]
24
+
25
+
26
+ def get_partial_state_dict(
27
+ state_dict: Dict[str, torch.Tensor],
28
+ prefix: str,
29
+ ) -> Dict[str, torch.Tensor]:
30
+ """
31
+ Get a partial state dict with keys starting with the given prefix
32
+ """
33
+ return {k: v for k, v in state_dict.items() if k.startswith(prefix)}
34
+
35
+
36
+ def process_state_dict(
37
+ state_dict: Dict[str, torch.Tensor],
38
+ device: str = None,
39
+ dtype: torch.dtype = None,
40
+ prefix_to_remove: Optional[str] = None,
41
+ ) -> Dict[str, torch.Tensor]:
42
+ """
43
+ - Remove items with substring "_extra_state" in keys (TransformerEngine adds these for FP8)
44
+ - Move tensors to specified device and dtype if provided
45
+
46
+ Args:
47
+ state_dict (Dict[str, torch.Tensor]): The state dict to process
48
+ device (str, optional): The device to move tensors to. Defaults to None.
49
+ dtype (torch.dtype, optional): The dtype to move tensors to. Defaults to None.
50
+ prefix_to_remove (str, optional): The prefix to remove from the keys of the state dict. Defaults to None.
51
+
52
+ Returns:
53
+ Dict[str, torch.Tensor]: The processed state dict
54
+ """
55
+ new_state_dict = {}
56
+ tensor_kwargs = {}
57
+ if device is not None:
58
+ tensor_kwargs["device"] = device
59
+ if dtype is not None:
60
+ tensor_kwargs["dtype"] = dtype
61
+
62
+ for key, value in state_dict.items():
63
+ # Check if any of the substrings to ignore are in the key
64
+ skip = False
65
+ for substr in substrings_to_ignore:
66
+ if substr in key:
67
+ skip = True
68
+ break
69
+ if skip:
70
+ continue
71
+ if len(tensor_kwargs) > 0:
72
+ value = value.to(**tensor_kwargs)
73
+ if prefix_to_remove is not None and key.startswith(prefix_to_remove):
74
+ key = key[len(prefix_to_remove) :]
75
+ new_state_dict[key] = value
76
+ return new_state_dict
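
A typical (illustrative) way to combine these helpers when loading a checkpoint. The path and key prefixes below are assumptions for the sketch, not values taken from this repository.

import torch

# Load on CPU, drop TransformerEngine "_extra_state" entries, cast to bfloat16 on GPU,
# and strip a hypothetical "model." prefix before selecting the vision-encoder weights.
state_dict = torch.load("checkpoints/model.pt", map_location="cpu")
state_dict = process_state_dict(state_dict, device="cuda", dtype=torch.bfloat16, prefix_to_remove="model.")
vit_state_dict = get_partial_state_dict(state_dict, prefix="vision_encoder.")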
ar_utils_inference.py ADDED
@@ -0,0 +1,360 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import json
18
+ import math
19
+ import os
20
+ from pathlib import Path
21
+ from typing import List
22
+
23
+ import numpy as np
24
+ import torch
25
+ import torchvision
26
+ from PIL import Image
27
+
28
+ from .ar_config_inference_inference_config import SamplingConfig
29
+ from .log import log
30
+
31
+ _IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp"]  # compared against os.path.splitext() output, which includes the leading dot
32
+ _VIDEO_EXTENSIONS = [".mp4"]
33
+ _SUPPORTED_CONTEXT_LEN = [1, 9] # Input frames
34
+ NUM_TOTAL_FRAMES = 33
35
+
36
+
37
+ def add_common_arguments(parser):
38
+ """Add common command line arguments.
39
+
40
+ Args:
41
+ parser (ArgumentParser): Argument parser to add arguments to
42
+ """
43
+ parser.add_argument(
44
+ "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
45
+ )
46
+ parser.add_argument(
47
+ "--video_save_name",
48
+ type=str,
49
+ default="output",
50
+ help="Output filename for generating a single video",
51
+ )
52
+ parser.add_argument("--video_save_folder", type=str, default="outputs/", help="Output folder for saving videos")
53
+ parser.add_argument(
54
+ "--input_image_or_video_path",
55
+ type=str,
56
+ help="Input path for input image or video",
57
+ )
58
+ parser.add_argument(
59
+ "--batch_input_path",
60
+ type=str,
61
+ help="Input folder containing all input images or videos",
62
+ )
63
+ parser.add_argument(
64
+ "--num_input_frames",
65
+ type=int,
66
+ default=9,
67
+ help="Number of input frames for world generation",
68
+ choices=_SUPPORTED_CONTEXT_LEN,
69
+ )
70
+ parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
71
+ parser.add_argument("--top_p", type=float, default=0.8, help="Top-p value for sampling")
72
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
73
+ parser.add_argument("--disable_diffusion_decoder", action="store_true", help="Disable diffusion decoder")
74
+ parser.add_argument(
75
+ "--offload_guardrail_models",
76
+ action="store_true",
77
+ help="Offload guardrail models after inference",
78
+ )
79
+ parser.add_argument(
80
+ "--offload_diffusion_decoder",
81
+ action="store_true",
82
+ help="Offload diffusion decoder after inference",
83
+ )
84
+ parser.add_argument(
85
+ "--offload_ar_model",
86
+ action="store_true",
87
+ help="Offload AR model after inference",
88
+ )
89
+ parser.add_argument(
90
+ "--offload_tokenizer",
91
+ action="store_true",
92
+ help="Offload discrete tokenizer model after inference",
93
+ )
94
+
95
+
96
+ def validate_args(args: argparse.Namespace, inference_type: str):
97
+ """Validate command line arguments for base and video2world generation."""
98
+ assert inference_type in [
99
+ "base",
100
+ "video2world",
101
+ ], "Invalid inference_type, must be 'base' or 'video2world'"
102
+ if args.input_type in ["image", "text_and_image"] and args.num_input_frames != 1:
103
+ args.num_input_frames = 1
104
+ log.info(f"Set num_input_frames to 1 for {args.input_type} input")
105
+
106
+ if args.num_input_frames == 1:
107
+ if "4B" in args.ar_model_dir:
108
+ log.warning(
109
+ "The failure rate for the 4B model with image input is ~15%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
110
+ )
111
+ elif "5B" in args.ar_model_dir:
112
+ log.warning(
113
+ "The failure rate for the 5B model with image input is ~7%. The 12B / 13B models have a smaller failure rate. Please be cautious and refer to README.md for more details."
114
+ )
115
+
116
+ # Validate prompt/image/video args for single or batch generation
117
+ assert (
118
+ args.input_image_or_video_path or args.batch_input_path
119
+ ), "--input_image_or_video_path or --batch_input_path must be provided."
120
+ if inference_type == "video2world" and (not args.batch_input_path):
121
+ assert args.prompt, "--prompt is required for single video generation."
122
+ args.data_resolution = [640, 1024]
123
+
124
+ # Validate number of GPUs
125
+ num_gpus = int(os.getenv("WORLD_SIZE", 1))
126
+ assert num_gpus <= 1, "We support only single GPU inference for now"
127
+
128
+ # Create output folder
129
+ Path(args.video_save_folder).mkdir(parents=True, exist_ok=True)
130
+
131
+ sampling_config = SamplingConfig(
132
+ echo=True,
133
+ temperature=args.temperature,
134
+ top_p=args.top_p,
135
+ compile_sampling=True,
136
+ )
137
+ return sampling_config
138
+
139
+
140
+ def resize_input(video: torch.Tensor, resolution: list[int]):
141
+ r"""
142
+ Function to perform aspect ratio preserving resizing and center cropping.
143
+ This is needed to make the video into target resolution.
144
+ Args:
145
+ video (torch.Tensor): Input video tensor
146
+ resolution (list[int]): Data resolution
147
+ Returns:
148
+ Cropped video
149
+ """
150
+
151
+ orig_h, orig_w = video.shape[2], video.shape[3]
152
+ target_h, target_w = resolution
153
+
154
+ scaling_ratio = max((target_w / orig_w), (target_h / orig_h))
155
+ resizing_shape = (int(math.ceil(scaling_ratio * orig_h)), int(math.ceil(scaling_ratio * orig_w)))
156
+ video_resized = torchvision.transforms.functional.resize(video, resizing_shape)
157
+ video_cropped = torchvision.transforms.functional.center_crop(video_resized, resolution)
158
+ return video_cropped
159
+
160
+
161
+ def load_image_from_list(flist, data_resolution: List[int]) -> dict:
162
+ """
163
+ Function to load images from a list of image paths.
164
+ Args:
165
+ flist (List[str]): List of image paths
166
+ data_resolution (List[int]): Data resolution
167
+ Returns:
168
+ Dict containing input images
169
+ """
170
+ all_videos = dict()
171
+ for img_path in flist:
172
+ ext = os.path.splitext(img_path)[1]
173
+ if ext in _IMAGE_EXTENSIONS:
174
+ # Read the image
175
+ img = Image.open(img_path)
176
+
177
+ # Convert to tensor
178
+ img = torchvision.transforms.functional.to_tensor(img)
179
+ static_vid = img.unsqueeze(0).repeat(NUM_TOTAL_FRAMES, 1, 1, 1)
180
+ static_vid = static_vid * 2 - 1
181
+
182
+ log.debug(
183
+ f"Resizing input image of shape ({static_vid.shape[2]}, {static_vid.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
184
+ )
185
+ static_vid = resize_input(static_vid, data_resolution)
186
+ fname = os.path.basename(img_path)
187
+ all_videos[fname] = static_vid.transpose(0, 1).unsqueeze(0)
188
+
189
+ return all_videos
190
+
191
+
192
+ def read_input_images(batch_input_path: str, data_resolution: List[int]) -> dict:
193
+ """
194
+ Function to read input images from a JSONL file.
195
+
196
+ Args:
197
+ batch_input_path (str): Path to JSONL file containing visual input paths
198
+ data_resolution (list[int]): Data resolution
199
+
200
+ Returns:
201
+ Dict containing input images
202
+ """
203
+ # Read visual inputs from JSONL
204
+ flist = []
205
+ with open(batch_input_path, "r") as f:
206
+ for line in f:
207
+ data = json.loads(line.strip())
208
+ flist.append(data["visual_input"])
209
+
210
+ return load_image_from_list(flist, data_resolution=data_resolution)
211
+
212
+
213
+ def read_input_image(input_path: str, data_resolution: List[int]) -> dict:
214
+ """
215
+ Function to read input image.
216
+ Args:
217
+ input_path (str): Path to input image
218
+ data_resolution (List[int]): Data resolution
219
+ Returns:
220
+ Dict containing input image
221
+ """
222
+ flist = [input_path]
223
+ return load_image_from_list(flist, data_resolution=data_resolution)
224
+
225
+
226
+ def read_input_videos(batch_input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
227
+ r"""
228
+ Function to read input videos.
229
+ Args:
230
+ batch_input_path (str): Path to JSONL file containing visual input paths
231
+ data_resolution (list[int]): Data resolution
232
+ Returns:
233
+ Dict containing input videos
234
+ """
235
+ # Read visual inputs from JSONL
236
+ flist = []
237
+ with open(batch_input_path, "r") as f:
238
+ for line in f:
239
+ data = json.loads(line.strip())
240
+ flist.append(data["visual_input"])
241
+ return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)
242
+
243
+
244
+ def read_input_video(input_path: str, data_resolution: List[int], num_input_frames: int) -> dict:
245
+ """
246
+ Function to read input video.
247
+ Args:
248
+ input_path (str): Path to input video
249
+ data_resolution (List[int]): Data resolution
250
+ num_input_frames (int): Number of frames in context
251
+ Returns:
252
+ Dict containing input video
253
+ """
254
+ flist = [input_path]
255
+ return load_videos_from_list(flist, data_resolution=data_resolution, num_input_frames=num_input_frames)
256
+
257
+
258
+ def load_videos_from_list(flist: List[str], data_resolution: List[int], num_input_frames: int) -> dict:
259
+ """
260
+ Function to load videos from a list of video paths.
261
+ Args:
262
+ flist (List[str]): List of video paths
263
+ data_resolution (List[int]): Data resolution
264
+ num_input_frames (int): Number of frames in context
265
+ Returns:
266
+ Dict containing input videos
267
+ """
268
+ all_videos = dict()
269
+
270
+ for video_path in flist:
271
+ ext = os.path.splitext(video_path)[-1]
272
+ if ext in _VIDEO_EXTENSIONS:
273
+ video, _, _ = torchvision.io.read_video(video_path, pts_unit="sec")
274
+ video = video.float() / 255.0
275
+ video = video * 2 - 1
276
+
277
+ # Resize the videos to the required dimension
278
+ nframes_in_video = video.shape[0]
279
+ if nframes_in_video < num_input_frames:
280
+ fname = os.path.basename(video_path)
281
+ log.warning(
282
+ f"Video {fname} has {nframes_in_video} frames, fewer than the required {num_input_frames} frames. Skipping."
283
+ )
284
+ continue
285
+
286
+ video = video[-num_input_frames:, :, :, :]
287
+
288
+ # Pad the video to NUM_TOTAL_FRAMES (because the tokenizer expects inputs of NUM_TOTAL_FRAMES)
289
+ video = torch.cat(
290
+ (video, video[-1, :, :, :].unsqueeze(0).repeat(NUM_TOTAL_FRAMES - num_input_frames, 1, 1, 1)),
291
+ dim=0,
292
+ )
293
+
294
+ video = video.permute(0, 3, 1, 2)
295
+
296
+ log.debug(
297
+ f"Resizing input video of shape ({video.shape[2]}, {video.shape[3]}) -> ({data_resolution[0]}, {data_resolution[1]})"
298
+ )
299
+ video = resize_input(video, data_resolution)
300
+
301
+ fname = os.path.basename(video_path)
302
+ all_videos[fname] = video.transpose(0, 1).unsqueeze(0)
303
+
304
+ return all_videos
305
+
306
+
307
+ def load_vision_input(
308
+ input_type: str,
309
+ batch_input_path: str,
310
+ input_image_or_video_path: str,
311
+ data_resolution: List[int],
312
+ num_input_frames: int,
313
+ ):
314
+ """
315
+ Function to load vision input.
316
+ Note: We pad the frames of the input image/video to NUM_TOTAL_FRAMES here, and feed the padded video tensors to the video tokenizer to obtain tokens. The tokens will be truncated based on num_input_frames when feeding to the autoregressive model.
317
+ Args:
318
+ input_type (str): Type of input
319
+ batch_input_path (str): Folder containing input images or videos
320
+ input_image_or_video_path (str): Path to input image or video
321
+ data_resolution (List[int]): Data resolution
322
+ num_input_frames (int): Number of frames in context
323
+ Returns:
324
+ Dict containing input videos
325
+ """
326
+ if batch_input_path:
327
+ log.info(f"Reading batch inputs from path: {batch_input_path}")
328
+ if input_type == "image" or input_type == "text_and_image":
329
+ input_videos = read_input_images(batch_input_path, data_resolution=data_resolution)
330
+ elif input_type == "video" or input_type == "text_and_video":
331
+ input_videos = read_input_videos(
332
+ batch_input_path,
333
+ data_resolution=data_resolution,
334
+ num_input_frames=num_input_frames,
335
+ )
336
+ else:
337
+ raise ValueError(f"Invalid input type {input_type}")
338
+ else:
339
+ if input_type == "image" or input_type == "text_and_image":
340
+ input_videos = read_input_image(input_image_or_video_path, data_resolution=data_resolution)
341
+ elif input_type == "video" or input_type == "text_and_video":
342
+ input_videos = read_input_video(
343
+ input_image_or_video_path,
344
+ data_resolution=data_resolution,
345
+ num_input_frames=num_input_frames,
346
+ )
347
+ else:
348
+ raise ValueError(f"Invalid input type {input_type}")
349
+ return input_videos
350
+
351
+
352
+ def prepare_video_batch_for_saving(video_batch: List[torch.Tensor]) -> List[np.ndarray]:
353
+ """
354
+ Function to convert output tensors to numpy format for saving.
355
+ Args:
356
+ video_batch (List[torch.Tensor]): List of output tensors
357
+ Returns:
358
+ List of numpy arrays
359
+ """
360
+ return [(video * 255).to(torch.uint8).permute(1, 2, 3, 0).cpu().numpy() for video in video_batch]
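
A quick shape check for resize_input (a sketch, assuming the function above is importable and torchvision is installed): the scale factor is chosen so the frame covers the target resolution, and the center crop then lands exactly on it.

import torch

frames = torch.randn(33, 3, 720, 1280)       # (T, C, H, W), e.g. a 720p clip
cropped = resize_input(frames, [640, 1024])  # scale by max(1024/1280, 640/720), then center-crop
print(cropped.shape)                         # torch.Size([33, 3, 640, 1024])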
ar_utils_misc.py ADDED
@@ -0,0 +1,52 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from omegaconf import DictConfig, OmegaConf
17
+
18
+
19
+ class CustomSimpleNamespace:
20
+ """
21
+ A simple namespace class that supports both attribute-style and dictionary-style access.
22
+ """
23
+
24
+ def __init__(self, d):
25
+ self._d = d
26
+
27
+ def __getattr__(self, attr):
28
+ # Attribute-style access: config.key
29
+ try:
30
+ return self._d[attr]
31
+ except KeyError:
32
+ raise AttributeError(f"'CustomSimpleNamespace' object has no attribute '{attr}'")
33
+
34
+ def __getitem__(self, key):
35
+ # Dictionary-style access: config['key']
36
+ return self._d[key]
37
+
38
+
39
+ def maybe_convert_to_namespace(config):
40
+ """
41
+ This function casts an OmegaConf DictConfig or a standard dict to a CustomSimpleNamespace, which supports both
42
+ attribute-style and dictionary-style access.
43
+ Note: We need to convert OmegaConf's DictConfig since it is not compatible with torch.compile.
44
+ """
45
+ # If input is OmegaConf's DictConfig, convert to a standard dict
46
+ if isinstance(config, DictConfig):
47
+ config = OmegaConf.to_container(config, resolve=True)
48
+
49
+ if isinstance(config, dict):
50
+ return CustomSimpleNamespace(config)
51
+ else:
52
+ return config
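
A small usage sketch for maybe_convert_to_namespace, assuming omegaconf is installed; the config keys are illustrative.

from omegaconf import OmegaConf

cfg = maybe_convert_to_namespace(OmegaConf.create({"num_layers": 16, "dim": 4096}))
assert cfg.num_layers == 16                       # attribute-style access
assert cfg["dim"] == 4096                         # dictionary-style access
assert maybe_convert_to_namespace(3.14) == 3.14   # non-dict inputs pass through unchanged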
ar_utils_sampling.py ADDED
@@ -0,0 +1,195 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from .ar_network_transformer import Transformer
21
+
22
+
23
+ def sample_top_p(logits, temperature, top_p, return_probs: bool = False):
24
+ """
25
+ Perform top-p (nucleus) sampling on a probability distribution.
26
+
27
+ Args:
28
+ logits (torch.Tensor): Logits of the probability distribution.
29
+ temperature (float): Temperature for sampling.
30
+ top_p (float): Probability threshold for top-p sampling.
31
+
32
+ Returns:
33
+ Tuple[torch.Tensor, Optional[torch.Tensor]]: Sampled token indices and, if return_probs is True, the renormalized probabilities (otherwise None).
34
+
35
+ Note:
36
+ Top-p sampling selects the smallest set of tokens whose cumulative probability mass
37
+ exceeds the threshold p. The distribution is renormalized based on the selected tokens.
38
+ """
39
+ probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1)
40
+ # Sort the probabilities in descending order and get their indices.
41
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
42
+ # Compute the cumulative sum of the sorted probabilities.
43
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
44
+ # Create a mask where the cumulative probability exceeds the threshold p.
45
+ mask = probs_sum - probs_sort > top_p
46
+ # Set the probabilities that exceed the threshold to 0.
47
+ probs_sort[mask] = 0.0
48
+ # Renormalize the remaining probabilities so they sum to 1.
49
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
50
+ # Sample from the renormalized probability distribution.
51
+ # next_token = torch.multinomial(probs_sort, num_samples=1)
52
+ next_token = multinomial_sample_one_no_sync(probs_sort, dtype=torch.int64)
53
+ # Gather the indices of the sampled tokens.
54
+ next_token = torch.gather(probs_idx, -1, next_token)
55
+ if return_probs:
56
+ # Initialize a tensor for unsorted probabilities
57
+ probs_unsorted = torch.zeros_like(probs_sort)
58
+ # Scatter the sorted probabilities back to their original order
59
+ probs_unsorted.scatter_(-1, probs_idx, probs_sort)
60
+ else:
61
+ probs_unsorted = None
62
+ return next_token, probs_unsorted
63
+
64
+
65
+ def multinomial_sample_one_no_sync(probs_sort, dtype=torch.int):
66
+ """
67
+ Multinomial sampling without a cuda synchronization.
68
+ Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
69
+ """
70
+ q = torch.empty_like(probs_sort).exponential_(1)
71
+ return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=dtype)
72
+
73
+
74
+ def logits_to_probs(
75
+ logits,
76
+ temperature: float = 1.0,
77
+ top_k: Optional[int] = None,
78
+ ):
79
+ logits = logits / max(temperature, 1e-5)
80
+
81
+ if top_k is not None:
82
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
83
+ pivot = v.select(-1, -1).unsqueeze(-1)
84
+ logits = torch.where(logits < pivot, -float("Inf"), logits)
85
+ probs = torch.nn.functional.softmax(logits, dim=-1)
86
+ return probs
87
+
88
+
89
+ def sample_top_k(logits, temperature: float = 1.0, top_k: Optional[int] = None):
90
+ """
91
+ Sample from the logits using top-k sampling.
92
+ Source: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
93
+ """
94
+ # logits: [batch_size, seq_len, vocab_size]
95
+ if temperature == 0.0:
96
+ idx_next = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
97
+ probs = None
98
+ else:
99
+ probs = logits_to_probs(logits[:, -1, :], temperature, top_k)
100
+ idx_next = multinomial_sample_one_no_sync(probs)
101
+ return idx_next, probs
102
+
103
+
104
+ def prefill(
105
+ model: Transformer,
106
+ input_pos: torch.Tensor,
107
+ tokens: torch.Tensor = None,
108
+ token_embeddings: torch.Tensor = None,
109
+ temperature: float = 1.0,
110
+ top_k: Optional[int] = None,
111
+ top_p: Optional[float] = None,
112
+ **kwargs,
113
+ ) -> torch.Tensor:
114
+ logits = model(tokens=tokens, token_embeddings=token_embeddings, input_pos=input_pos, **kwargs)
115
+ # Only top-p or top-k can be provided
116
+ assert (
117
+ top_p is None or top_k is None
118
+ ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
119
+ if top_p is not None:
120
+ return sample_top_p(logits, temperature=temperature, top_p=top_p)[0]
121
+ else:
122
+ return sample_top_k(logits, temperature=temperature, top_k=top_k)[0]
123
+
124
+
125
+ def decode_one_token(
126
+ model: Transformer,
127
+ tokens: torch.Tensor,
128
+ input_pos: torch.Tensor,
129
+ temperature: float = 1.0,
130
+ top_k: Optional[int] = None,
131
+ top_p: Optional[float] = None,
132
+ **kwargs,
133
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
134
+ """
135
+ Decode a single token from the autoregressive model.
136
+ """
137
+ logits = model(tokens=tokens, input_pos=input_pos, **kwargs)
138
+ if top_p is not None:
139
+ return sample_top_p(logits, temperature=temperature, top_p=top_p)
140
+ else:
141
+ return sample_top_k(logits, temperature=temperature, top_k=top_k)
142
+
143
+
144
+ def decode_n_tokens(
145
+ model: Transformer,
146
+ cur_token: torch.Tensor,
147
+ input_pos: torch.Tensor,
148
+ num_new_tokens: int,
149
+ stop_tokens: torch.Tensor = None,
150
+ temperature: float = 1.0,
151
+ top_p: Optional[float] = None,
152
+ top_k: Optional[int] = None,
153
+ return_probs: bool = False,
154
+ decode_one_token_function=decode_one_token,
155
+ **kwargs,
156
+ ):
157
+ """
158
+ Decode n tokens from the autoregressive model.
159
+ Adapted from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
160
+ """
161
+ new_tokens, new_probs = [], []
162
+ batch_size = cur_token.shape[0]
163
+ assert (
164
+ top_p is None or top_k is None
165
+ ), f"Only one of top-p or top-k can be provided, got top-p={top_p} and top-k={top_k}"
166
+ if stop_tokens is not None:
167
+ # Indicator for whether the EOS token (stop token) has been reached for each sample in the batch
168
+ eos_reached = torch.tensor([False] * batch_size, device="cuda")
169
+ for t in range(num_new_tokens):
170
+ with torch.backends.cuda.sdp_kernel(
171
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
172
+ ): # Actually better for Inductor to codegen attention here
173
+ next_token, next_prob = decode_one_token_function(
174
+ model,
175
+ tokens=cur_token,
176
+ input_pos=input_pos,
177
+ temperature=temperature,
178
+ top_k=top_k,
179
+ top_p=top_p,
180
+ **kwargs,
181
+ )
182
+ input_pos += 1
183
+ if stop_tokens is not None and len(stop_tokens) > 0:
184
+ eos_reached = eos_reached | (torch.isin(next_token, stop_tokens))
185
+ if eos_reached.all():
186
+ break
187
+ new_tokens.append(next_token.clone())
188
+ if return_probs:
189
+ new_probs.append(next_prob.clone())
190
+ cur_token = next_token.clone()
191
+
192
+ if return_probs:
193
+ return new_tokens, new_probs
194
+ else:
195
+ return new_tokens
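
A toy sketch of sample_top_p on a hand-made distribution (values chosen purely for illustration): with top_p=0.5 only the two most probable tokens survive, and the returned probabilities are renormalized over them.

import torch

torch.manual_seed(0)
# softmax(log(p)) == p, so these logits reproduce the probabilities [0.05, 0.30, 0.10, 0.40, 0.15].
logits = torch.log(torch.tensor([[[0.05, 0.30, 0.10, 0.40, 0.15]]]))  # (batch=1, seq=1, vocab=5)
next_token, probs = sample_top_p(logits, temperature=1.0, top_p=0.5, return_probs=True)
print(next_token)       # tensor([[3]]) or tensor([[1]]) -- only tokens 3 and 1 are kept
print(probs.squeeze())  # ~[0.0000, 0.4286, 0.0000, 0.5714, 0.0000] after renormalization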
assets/cosmos-logo.png ADDED
assets/diffusion_decoder_image_output.mp4 ADDED
Binary file (371 kB).
assets/diffusion_decoder_video_output.mp4 ADDED
Binary file (200 kB).
assets/image_output.mp4 ADDED
Binary file (234 kB).
assets/video_output.mp4 ADDED
Binary file (109 kB).
base.py ADDED
@@ -0,0 +1,116 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ import imageio
20
+ import torch
21
+
22
+ from .world_generation_pipeline import ARBaseGenerationPipeline
23
+ from .ar_utils_inference import add_common_arguments, load_vision_input, validate_args
24
+ from .log import log
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser(description="Video to world generation demo script")
29
+ # Add common arguments
30
+ add_common_arguments(parser)
31
+ parser.add_argument(
32
+ "--ar_model_dir",
33
+ type=str,
34
+ default="Cosmos-1.0-Autoregressive-4B",
35
+ )
36
+ parser.add_argument("--input_type", type=str, default="video", help="Type of input", choices=["image", "video"])
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ def main(args):
42
+ """Run video-to-world generation demo.
43
+
44
+ This function handles the main video-to-world generation pipeline, including:
45
+ - Setting up the random seed for reproducibility
46
+ - Initializing the generation pipeline with the provided configuration
47
+ - Processing single or multiple images/videos from input
48
+ - Generating videos from images/videos
49
+ - Saving the generated videos to disk
50
+
51
+ Args:
52
+ cfg (argparse.Namespace): Configuration namespace containing:
53
+ - Model configuration (checkpoint paths, model settings)
54
+ - Generation parameters (temperature, top_p)
55
+ - Input/output settings (images/videos, save paths)
56
+ - Performance options (model offloading settings)
57
+
58
+ The function will save:
59
+ - Generated MP4 video files
60
+
61
+ If guardrails block the generation, a critical log message is displayed
62
+ and the function continues to the next prompt if available.
63
+ """
64
+ inference_type = "base"  # When inference_type is "base", the AR model takes no text input; world generation is conditioned purely on the input video
65
+ sampling_config = validate_args(args, inference_type)
66
+
67
+ # Initialize base generation model pipeline
68
+ pipeline = ARBaseGenerationPipeline(
69
+ inference_type=inference_type,
70
+ checkpoint_dir=args.checkpoint_dir,
71
+ checkpoint_name=args.ar_model_dir,
72
+ disable_diffusion_decoder=args.disable_diffusion_decoder,
73
+ offload_guardrail_models=args.offload_guardrail_models,
74
+ offload_diffusion_decoder=args.offload_diffusion_decoder,
75
+ offload_network=args.offload_ar_model,
76
+ offload_tokenizer=args.offload_tokenizer,
77
+ )
78
+
79
+ # Load input image(s) or video(s)
80
+ input_videos = load_vision_input(
81
+ input_type=args.input_type,
82
+ batch_input_path=args.batch_input_path,
83
+ input_image_or_video_path=args.input_image_or_video_path,
84
+ data_resolution=args.data_resolution,
85
+ num_input_frames=args.num_input_frames,
86
+ )
87
+
88
+ for idx, input_filename in enumerate(input_videos):
89
+ inp_vid = input_videos[input_filename]
90
+ # Generate video
91
+ log.info(f"Run with image or video path: {input_filename}")
92
+ out_vid = pipeline.generate(
93
+ inp_vid=inp_vid,
94
+ num_input_frames=args.num_input_frames,
95
+ seed=args.seed,
96
+ sampling_config=sampling_config,
97
+ )
98
+ if out_vid is None:
99
+ log.critical("Guardrail blocked base generation.")
100
+ continue
101
+
102
+ # Save video
103
+ if args.input_image_or_video_path:
104
+ out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4")
105
+ else:
106
+ out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4")
107
+
108
+ imageio.mimsave(out_vid_path, out_vid, fps=25)
109
+
110
+ log.info(f"Saved video to {out_vid_path}")
111
+
112
+
113
+ if __name__ == "__main__":
114
+ torch._C._jit_set_texpr_fuser_enabled(False)
115
+ args = parse_args()
116
+ main(args)
base_world_generation_pipeline.py ADDED
@@ -0,0 +1,358 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import os
18
+ from abc import ABC
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from .t5_text_encoder import CosmosT5TextEncoder
25
+ from .guardrail_common_presets import presets as guardrail_presets
26
+
27
+
28
+ class BaseWorldGenerationPipeline(ABC):
29
+ def __init__(
30
+ self,
31
+ inference_type: str | None = None,
32
+ checkpoint_dir: str | None = None,
33
+ checkpoint_name: str | None = None,
34
+ has_text_input: bool = False,
35
+ offload_network: bool = False,
36
+ offload_tokenizer: bool = False,
37
+ offload_text_encoder_model: bool = False,
38
+ offload_guardrail_models: bool = False,
39
+ ):
40
+ """Initialize base world generation pipeline.
41
+
42
+ This abstract base class provides core functionality for world generation models including:
43
+ - Model loading and initialization
44
+ - Text encoding and embedding
45
+ - Safety checks and content filtering
46
+ - Memory management through model offloading
47
+
48
+ Args:
49
+ inference_type: The type of inference pipeline ("text2world" or "video2world")
50
+ checkpoint_dir: Root directory containing model checkpoints
51
+ checkpoint_name: Name of the specific checkpoint file to load
52
+ has_text_input: Whether the pipeline takes text input for world generation
53
+ offload_network: If True, moves main model to CPU after inference
54
+ offload_tokenizer: If True, moves tokenizer to CPU after use
55
+ offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding
56
+ offload_guardrail_models: If True, moves safety models to CPU after checks
57
+ """
58
+ self.inference_type = inference_type
59
+ self.checkpoint_dir = checkpoint_dir
60
+ self.checkpoint_name = checkpoint_name
61
+ self.guardrail_dir = "Cosmos-1.0-Guardrail"
62
+ self.has_text_input = has_text_input
63
+
64
+ # Add offloading flags
65
+ self.offload_network = offload_network
66
+ self.offload_tokenizer = offload_tokenizer
67
+ self.offload_text_encoder_model = offload_text_encoder_model
68
+ self.offload_guardrail_models = offload_guardrail_models
69
+
70
+ # Initialize model instances
71
+ self.text_guardrail = None
72
+ self.video_guardrail = None
73
+ self.text_encoder = None
74
+ self.model = None
75
+
76
+ self._load_model()
77
+
78
+ if not self.offload_text_encoder_model:
79
+ self._load_text_encoder_model()
80
+ if not self.offload_guardrail_models:
81
+ if self.has_text_input:
82
+ self._load_text_guardrail()
83
+ self._load_video_guardrail()
84
+ if not self.offload_network:
85
+ self._load_network()
86
+ if not self.offload_tokenizer:
87
+ self._load_tokenizer()
88
+
89
+ def _load_tokenizer(self):
90
+ pass
91
+
92
+ def _load_network(self):
93
+ pass
94
+
95
+ def _load_model(self, checkpoint_name: str) -> Any:
96
+ """Load the world generation model from a checkpoint.
97
+
98
+ This abstract method must be implemented by subclasses to load their specific
99
+ model architecture and weights.
100
+
101
+ Args:
102
+ checkpoint_name: Path to the model checkpoint file
103
+
104
+ Returns:
105
+ The loaded model instance
106
+
107
+ Raises:
108
+ NotImplementedError: Must be implemented by subclasses
109
+ """
110
+ pass
111
+
112
+ def _load_text_encoder_model(self):
113
+ """Load the T5 text encoder model.
114
+
115
+ Initializes and loads the T5 encoder model used for converting text prompts
116
+ into embeddings that condition the world generation model.
117
+
118
+ Returns:
119
+ Loaded T5 text encoder model instance
120
+ """
121
+ self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)
122
+
123
+ def _load_text_guardrail(self):
124
+ """Load text safety classifier models.
125
+
126
+ Initializes models used for checking input prompts against safety policies.
127
+ Models are loaded from the specified guardrail directory.
128
+ """
129
+ self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
130
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
131
+ )
132
+
133
+ def _load_video_guardrail(self):
134
+ """Load video safety classifier models.
135
+
136
+ Initializes models used for validating generated video content against
137
+ safety policies. Models are loaded from the specified guardrail directory.
138
+ """
139
+ self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
140
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
141
+ )
142
+
143
+ def _offload_network(self):
144
+ if self.model.model:
145
+ del self.model.model
146
+ self.model.model = None
147
+ gc.collect()
148
+ torch.cuda.empty_cache()
149
+
150
+ def _offload_tokenizer(self):
151
+ if self.model.tokenizer:
152
+ del self.model.tokenizer
153
+ self.model.tokenizer = None
154
+ gc.collect()
155
+ torch.cuda.empty_cache()
156
+
157
+ def _offload_guardrail_models(self):
158
+ """Offload safety classifier models to reduce memory usage.
159
+
160
+ Moves safety models to CPU and clears GPU memory if they are no longer needed.
161
+ This helps manage memory when processing multiple inputs sequentially.
162
+ """
163
+ if self.text_guardrail:
164
+ del self.text_guardrail
165
+ self.text_guardrail = None
166
+ if self.video_guardrail:
167
+ del self.video_guardrail
168
+ self.video_guardrail = None
169
+ gc.collect()
170
+ torch.cuda.empty_cache()
171
+
172
+ def _offload_text_encoder_model(self):
173
+ """Offload T5 text encoder to reduce memory usage.
174
+
175
+ Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete.
176
+ This helps manage memory when processing multiple inputs sequentially.
177
+ """
178
+ if self.text_encoder:
179
+ del self.text_encoder
180
+ self.text_encoder = None
181
+ gc.collect()
182
+ torch.cuda.empty_cache()
183
+
184
+ def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
185
+ """Generate world latents using the model.
186
+
187
+ This abstract method must be implemented by subclasses to define their specific
188
+ generation process.
189
+
190
+ Args:
191
+ *args: Variable positional arguments for model inference
192
+ **kwargs: Variable keyword arguments for model inference
193
+
194
+ Returns:
195
+ torch.Tensor: Generated world representation tensor
196
+ """
197
+ pass
198
+
199
+ def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
200
+ """Generate world representation with memory management.
201
+
202
+ Handles loading the model before inference and offloading afterward if enabled.
203
+ This helps minimize GPU memory usage during inference.
204
+
205
+ Args:
206
+ *args: Arguments passed to _run_model
207
+ **kwargs: Keyword arguments passed to _run_model
208
+
209
+ Returns:
210
+ np.ndarray: Generated world representation as numpy array
211
+ """
212
+ pass
213
+
214
+ def _run_guardrail_on_prompt(self, prompt: str) -> bool:
215
+ """Check if prompt meets safety requirements.
216
+
217
+ Validates the input prompt against safety policies using loaded guardrail models.
218
+
219
+ Args:
220
+ prompt: Raw text prompt to validate
221
+
222
+ Returns:
223
+ bool: True if prompt passes all safety checks, False otherwise
224
+ """
225
+ return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)
226
+
227
+ def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
228
+ """Check prompt safety with memory management.
229
+
230
+ Validates prompt safety while handling model loading/offloading to manage memory.
231
+
232
+ Args:
233
+ prompt: Raw text prompt to validate
234
+
235
+ Returns:
236
+ bool: True if prompt passes all safety checks, False otherwise
237
+ """
238
+ if self.offload_guardrail_models:
239
+ self._load_text_guardrail()
240
+
241
+ is_safe = self._run_guardrail_on_prompt(prompt)
242
+
243
+ if self.offload_guardrail_models:
244
+ self._offload_guardrail_models()
245
+
246
+ return is_safe
247
+
248
+ def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
249
+ """Check if video meets safety requirements.
250
+
251
+ Validates generated video content against safety policies using guardrail models.
252
+
253
+ Args:
254
+ video: Video frames to validate
255
+
256
+ Returns:
257
+ np.ndarray: Processed video if safe, None if unsafe
258
+ """
259
+ return guardrail_presets.run_video_guardrail(video, self.video_guardrail)
260
+
261
+ def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
262
+ """Check if generated video meets safety requirements.
263
+
264
+ Args:
265
+ video: Video frames to validate
266
+
267
+ Returns:
268
+ np.ndarray: Processed video frames if safe, None otherwise
269
+
270
+ Note:
271
+ Guardrail models are offloaded after checks if enabled.
272
+ """
273
+ if self.offload_guardrail_models:
274
+ self._load_video_guardrail()
275
+
276
+ video = self._run_guardrail_on_video(video)
277
+
278
+ if self.offload_guardrail_models:
279
+ self._offload_guardrail_models()
280
+ return video
281
+
282
+ def _run_text_embedding_on_prompt(
283
+ self, prompts: list[str], **kwargs: Any
284
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
285
+ """Convert text prompts to embeddings.
286
+
287
+ Processes text prompts into embedding tensors that condition the generation model.
288
+
289
+ Args:
290
+ prompts: List of text prompts to encode
291
+ **kwargs: Additional arguments for text encoding
292
+
293
+ Returns:
294
+ tuple containing:
295
+ - List of text embedding tensors for each prompt
296
+ - List of attention masks for each embedding
297
+ """
298
+
299
+ embeddings = []
300
+ masks = []
301
+ for prompt in prompts:
302
+ embedding, mask = self.text_encoder.encode_prompts(
303
+ [prompt],
304
+ **kwargs,
305
+ )
306
+ embeddings.append(embedding)
307
+ masks.append(mask)
308
+
309
+ return embeddings, masks
310
+
311
+ def _run_text_embedding_on_prompt_with_offload(
312
+ self, prompts: list[str], **kwargs: Any
313
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
314
+ """Convert text prompt into embeddings using T5 encoder.
315
+
316
+ Args:
317
+ prompt: Processed and validated text prompt
318
+
319
+ Returns:
320
+ Text embedding tensor to condition diffusion model
321
+
322
+ Note:
323
+ T5 model is offloaded after encoding if enabled.
324
+ """
325
+ if self.offload_text_encoder_model:
326
+ self._load_text_encoder_model()
327
+
328
+ embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)
329
+
330
+ if self.offload_text_encoder_model:
331
+ self._offload_text_encoder_model()
332
+ return embeddings, masks
333
+
334
+ def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
335
+ """Decode model outputs into final world representation.
336
+
337
+ This abstract method must be implemented by subclasses to convert raw model
338
+ outputs into their specific world representation format.
339
+
340
+ Args:
341
+ samples: Raw output tensor from the generation model
342
+
343
+ Returns:
344
+ np.ndarray: Decoded world representation
345
+ """
346
+ pass
347
+
348
+ def generate(self, *args: Any, **kwargs: Any):
349
+ """Generate world representation.
350
+
351
+ This abstract method must be implemented by subclasses to convert raw model
352
+ outputs into their specific world representation format.
353
+
354
+ Args:
355
+ *args: Variable positional arguments for model inference
356
+ **kwargs: Variable keyword arguments for model inference
357
+ """
358
+ pass
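Every `offload_*` flag above follows the same load-just-in-time, free-after-use pattern. Below is a minimal, self-contained sketch of that pattern; the class and names are illustrative, not part of the repository API.

```python
import gc
import torch

class OffloadedModule:
    """Builds a module right before use and frees it afterwards to cap peak GPU memory."""

    def __init__(self, builder, offload: bool = True):
        self.builder = builder                        # callable that constructs/loads the module
        self.offload = offload
        self.module = None if offload else builder()  # keep resident when offloading is disabled

    def run(self, *args, **kwargs):
        if self.offload and self.module is None:
            self.module = self.builder()              # load just in time
        out = self.module(*args, **kwargs)
        if self.offload:
            del self.module                           # release after use
            self.module = None
            gc.collect()
            torch.cuda.empty_cache()
        return out

net = OffloadedModule(lambda: torch.nn.Linear(8, 8))
y = net.run(torch.randn(2, 8))
```

The pipeline applies this idea separately to the network, tokenizer, T5 text encoder, and guardrail models, trading reload time for a lower peak GPU memory footprint.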
config.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "architectures": [
3
+ "ARVideo2World"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "video2world_hf.ARVideo2WorldConfig",
7
+ "AutoModel": "video2world_hf.ARVideo2World"
8
+ },
9
+ "model_type": "AutoModel"
10
+ }
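The `auto_map` entries point `transformers` at the custom classes defined in `video2world_hf.py`. As a hedged sketch, a checkpoint with such a mapping is typically loaded like this (the repo id is a placeholder, and `trust_remote_code=True` is required for custom classes):

```python
from transformers import AutoConfig, AutoModel

repo_id = "path/or/hub-id/of-this-checkpoint"  # placeholder
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```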
config.py ADDED
@@ -0,0 +1,165 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, TypeVar
19
+
20
+ import attrs
21
+
22
+ from .lazy_config_init import LazyDict
23
+ from .misc import Color
24
+
25
+ T = TypeVar("T")
26
+
27
+
28
+ def _is_attrs_instance(obj: object) -> bool:
29
+ """
30
+ Helper function to check if an object is an instance of an attrs-defined class.
31
+
32
+ Args:
33
+ obj: The object to check.
34
+
35
+ Returns:
36
+ bool: True if the object is an instance of an attrs-defined class, False otherwise.
37
+ """
38
+ return hasattr(obj, "__attrs_attrs__")
39
+
40
+
41
+ def make_freezable(cls: T) -> T:
42
+ """
43
+ A decorator that adds the capability to freeze instances of an attrs-defined class.
44
+
45
+ NOTE: This requires the wrapped attrs to be defined with attrs.define(slots=False) because we need
46
+ to hack on a "_is_frozen" attribute.
47
+
48
+ This decorator enhances an attrs-defined class with the ability to be "frozen" at runtime.
49
+ Once an instance is frozen, its attributes cannot be changed. It also recursively freezes
50
+ any attrs-defined objects that are attributes of the class.
51
+
52
+ Usage:
53
+ @make_freezable
54
+ @attrs.define(slots=False)
55
+ class MyClass:
56
+ attribute1: int
57
+ attribute2: str
58
+
59
+ obj = MyClass(1, 'a')
60
+ obj.freeze() # Freeze the instance
61
+ obj.attribute1 = 2 # Raises AttributeError
62
+
63
+ Args:
64
+ cls: The class to be decorated.
65
+
66
+ Returns:
67
+ The decorated class with added freezing capability.
68
+ """
69
+
70
+ if not hasattr(cls, "__dict__"):
71
+ raise TypeError(
72
+ "make_freezable cannot be used with classes that do not define __dict__. Make sure that the wrapped "
73
+ "class was defined with `@attrs.define(slots=False)`"
74
+ )
75
+
76
+ original_setattr = cls.__setattr__
77
+
78
+ def setattr_override(self, key, value) -> None: # noqa: ANN001
79
+ """
80
+ Override __setattr__ to allow modifications during initialization
81
+ and prevent modifications once the instance is frozen.
82
+ """
83
+ if hasattr(self, "_is_frozen") and self._is_frozen and key != "_is_frozen":
84
+ raise AttributeError("Cannot modify frozen instance")
85
+ original_setattr(self, key, value) # type: ignore
86
+
87
+ cls.__setattr__ = setattr_override # type: ignore
88
+
89
+ def freeze(self: object) -> None:
90
+ """
91
+ Freeze the instance and all its attrs-defined attributes.
92
+ """
93
+ for _, value in attrs.asdict(self, recurse=False).items():
94
+ if _is_attrs_instance(value) and hasattr(value, "freeze"):
95
+ value.freeze()
96
+ self._is_frozen = True # type: ignore
97
+
98
+ cls.freeze = freeze # type: ignore
99
+
100
+ return cls
101
+
102
+
103
+ def _pretty_print_attrs_instance(obj: object, indent: int = 0, use_color: bool = False) -> str:
104
+ """
105
+ Recursively pretty prints attrs objects with color.
106
+ """
107
+
108
+ assert attrs.has(obj.__class__)
109
+
110
+ lines: list[str] = []
111
+ for attribute in attrs.fields(obj.__class__):
112
+ value = getattr(obj, attribute.name)
113
+ if attrs.has(value.__class__):
114
+ if use_color:
115
+ lines.append(" " * indent + Color.cyan("* ") + Color.green(attribute.name) + ":")
116
+ else:
117
+ lines.append(" " * indent + "* " + attribute.name + ":")
118
+ lines.append(_pretty_print_attrs_instance(value, indent + 1, use_color))
119
+ else:
120
+ if use_color:
121
+ lines.append(
122
+ " " * indent + Color.cyan("* ") + Color.green(attribute.name) + ": " + Color.yellow(value)
123
+ )
124
+ else:
125
+ lines.append(" " * indent + "* " + attribute.name + ": " + str(value))
126
+ return "\n".join(lines)
127
+
128
+
129
+ @make_freezable
130
+ @attrs.define(slots=False)
131
+ class JobConfig:
132
+ # Project name.
133
+ project: str = ""
134
+ # Experiment name.
135
+ group: str = ""
136
+ # Run/job name.
137
+ name: str = ""
138
+
139
+ @property
140
+ def path(self) -> str:
141
+ return f"{self.project}/{self.group}/{self.name}"
142
+
143
+
144
+ @make_freezable
145
+ @attrs.define(slots=False)
146
+ class Config:
147
+ """Config for a job.
148
+
149
+ See /README.md/Configuration System for more info.
150
+ """
151
+
152
+ # Model configs.
153
+ model: LazyDict
154
+
155
+ # Training job configs.
156
+ job: JobConfig = attrs.field(factory=JobConfig)
157
+
158
+ def to_dict(self) -> dict[str, Any]:
159
+ return attrs.asdict(self)
160
+
161
+ def validate(self) -> None:
162
+ """Validate that the config has all required fields."""
163
+ assert self.job.project != "", "Project name is required."
164
+ assert self.job.group != "", "Group name is required."
165
+ assert self.job.name != "", "Job name is required."
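A short usage sketch of the freeze behaviour with the `JobConfig` defined above (assumes `attrs` is installed; the comments show the expected output, not captured output):

```python
job = JobConfig(project="cosmos", group="ar", name="demo")
print(job.path)        # cosmos/ar/demo
job.freeze()
try:
    job.name = "other"
except AttributeError as err:
    print(err)         # Cannot modify frozen instance
```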
config_helper.py ADDED
@@ -0,0 +1,197 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import importlib
17
+ import os
18
+ import pkgutil
19
+ import sys
20
+ from dataclasses import fields as dataclass_fields
21
+ from dataclasses import is_dataclass
22
+ from typing import Any, Dict, Optional
23
+
24
+ import attr
25
+ import attrs
26
+ from hydra import compose, initialize
27
+ from hydra.core.config_store import ConfigStore
28
+ from omegaconf import DictConfig, OmegaConf
29
+
30
+ from .log import log
31
+ from .config import Config
32
+
33
+
34
+ def is_attrs_or_dataclass(obj) -> bool:
35
+ """
36
+ Check if the object is an instance of an attrs class or a dataclass.
37
+
38
+ Args:
39
+ obj: The object to check.
40
+
41
+ Returns:
42
+ bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
43
+ """
44
+ return is_dataclass(obj) or attr.has(type(obj))
45
+
46
+
47
+ def get_fields(obj):
48
+ """
49
+ Get the fields of an attrs class or a dataclass.
50
+
51
+ Args:
52
+ obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.
53
+
54
+ Returns:
55
+ list: A list of field names.
56
+
57
+ Raises:
58
+ ValueError: If the object is neither an attrs class nor a dataclass.
59
+ """
60
+ if is_dataclass(obj):
61
+ return [field.name for field in dataclass_fields(obj)]
62
+ elif attr.has(type(obj)):
63
+ return [field.name for field in attr.fields(type(obj))]
64
+ else:
65
+ raise ValueError("The object is neither an attrs class nor a dataclass.")
66
+
67
+
68
+ def override(config: Config, overrides: Optional[list[str]] = None) -> Config:
69
+ """
70
+ :param config: the instance of class `Config` (usually from `make_config`)
71
+ :param overrides: list of overrides for config
72
+ :return: the composed instance of class `Config`
73
+ """
74
+ # Store the class of the config for reconstruction after overriding.
75
+ # config_class = type(config)
76
+
77
+ # Convert Config object to a DictConfig object
78
+ config_dict = attrs.asdict(config)
79
+ config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
80
+ # Enforce "--" separator between the script arguments and overriding configs.
81
+ if overrides:
82
+ if overrides[0] != "--":
83
+ raise ValueError('Hydra config overrides must be separated with a "--" token.')
84
+ overrides = overrides[1:]
85
+ # Use Hydra to handle overrides
86
+ cs = ConfigStore.instance()
87
+ cs.store(name="config", node=config_omegaconf)
88
+ with initialize(version_base=None):
89
+ config_omegaconf = compose(config_name="config", overrides=overrides)
90
+ OmegaConf.resolve(config_omegaconf)
91
+
92
+ def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
93
+ """
94
+ Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data
95
+
96
+ Args:
97
+ ref_instance: The reference instance to determine the type and fields when needed
98
+ kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data
99
+
100
+ Returns:
101
+ Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data
102
+
103
+ Raises:
104
+ AssertionError: If the fields do not match or if extra keys are found.
105
+ Exception: If there is an error constructing the new instance.
106
+ """
107
+ is_type = is_attrs_or_dataclass(ref_instance)
108
+ if not is_type:
109
+ return kwargs
110
+ else:
111
+ ref_fields = set(get_fields(ref_instance))
112
+ assert isinstance(kwargs, dict) or isinstance(
113
+ kwargs, DictConfig
114
+ ), "kwargs must be a dictionary or a DictConfig"
115
+ keys = set(kwargs.keys())
116
+
117
+ # ref_fields must equal to or include all keys
118
+ extra_keys = keys - ref_fields
119
+ assert ref_fields == keys or keys.issubset(
120
+ ref_fields
121
+ ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"
122
+
123
+ resolved_kwargs: Dict[str, Any] = {}
124
+ for f in keys:
125
+ resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
126
+ try:
127
+ new_instance = type(ref_instance)(**resolved_kwargs)
128
+ except Exception as e:
129
+ log.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
130
+ log.error(e)
131
+ raise e
132
+ return new_instance
133
+
134
+ config = config_from_dict(config, config_omegaconf)
135
+
136
+ return config
137
+
138
+
139
+ def get_config_module(config_file: str) -> str:
140
+ if not config_file.endswith(".py"):
141
+ log.error("Config file cannot be specified as module.")
142
+ log.error("Please provide the path to the Python config file (relative to the Cosmos root).")
143
+ assert os.path.isfile(config_file), f"Cosmos config file ({config_file}) not found."
144
+ # Convert to importable module format.
145
+ config_module = config_file.replace("/", ".").replace(".py", "")
146
+ return config_module
147
+
148
+
149
+ def import_all_modules_from_package(package_path: str, reload: bool = False, skip_underscore: bool = True) -> None:
150
+ """
151
+ Import all modules from the specified package path recursively.
152
+
153
+ This function is typically used in conjunction with Hydra to ensure that all modules
154
+ within a specified package are imported, which is necessary for registering configurations.
155
+
156
+ Example usage:
157
+ ```python
158
+ import_all_modules_from_package("cosmos1.models.diffusion.config.inference", reload=True, skip_underscore=False)
159
+ ```
160
+
161
+ Args:
162
+ package_path (str): The dotted path to the package from which to import all modules.
163
+ reload (bool): Flag to determine whether to reload modules if they're already imported.
164
+ skip_underscore (bool): If True, skips importing modules that start with an underscore.
165
+ """
166
+ return  # NOTE: early return; the module auto-import logic below is currently not executed
167
+ log.debug(f"{'Reloading' if reload else 'Importing'} all modules from package {package_path}")
168
+ package = importlib.import_module(package_path)
169
+ package_directory = package.__path__
170
+
171
+ def import_modules_recursively(directory: str, prefix: str) -> None:
172
+ """
173
+ Recursively imports or reloads all modules in the given directory.
174
+
175
+ Args:
176
+ directory (str): The file system path to the current package directory.
177
+ prefix (str): The module prefix (e.g., 'cosmos1.models.diffusion.config').
178
+ """
179
+ for _, module_name, is_pkg in pkgutil.iter_modules([directory]):
180
+ if skip_underscore and module_name.startswith("_"):
181
+ log.debug(f"Skipping module {module_name} as it starts with an underscore")
182
+ continue
183
+
184
+ full_module_name = f"{prefix}.{module_name}"
185
+ log.debug(f"{'Reloading' if reload else 'Importing'} module {full_module_name}")
186
+
187
+ if full_module_name in sys.modules and reload:
188
+ importlib.reload(sys.modules[full_module_name])
189
+ else:
190
+ importlib.import_module(full_module_name)
191
+
192
+ if is_pkg:
193
+ sub_package_directory = os.path.join(directory, module_name)
194
+ import_modules_recursively(sub_package_directory, full_module_name)
195
+
196
+ for directory in package_directory:
197
+ import_modules_recursively(directory, package_path)
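`override` routes dotted `key=value` strings through Hydra's `compose` and then rebuilds the attrs `Config`. The self-contained sketch below illustrates only the underlying dotted-override idea using plain OmegaConf; it does not call `override` itself.

```python
from omegaconf import OmegaConf

base = OmegaConf.create({"job": {"project": "", "group": "", "name": ""}})
cli = OmegaConf.from_dotlist(["job.project=cosmos", "job.name=demo_run"])
merged = OmegaConf.merge(base, cli)
print(merged.job.project, merged.job.name)  # cosmos demo_run
```

When calling `override` directly, remember that the overrides list must start with the `--` separator token, as enforced above.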
convert_pixtral_ckpt.py ADDED
@@ -0,0 +1,209 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading.
17
+
18
+ Usage:
19
+
20
+ PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import shutil
27
+ from glob import glob
28
+
29
+ import torch
30
+ from huggingface_hub import snapshot_download
31
+ from safetensors.torch import load_file
32
+
33
+
34
+ def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str):
35
+ """
36
+ Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint.
37
+
38
+ Args:
39
+ checkpoint_dir (str): Path to the checkpoint directory
40
+ checkpoint_name (str): Name of the checkpoint
41
+ vit_type (str): Type of ViT used in the Pixtral model
42
+
43
+ This function performs the following steps:
44
+ 0. Download the checkpoint from Hugging Face
45
+ 1. Loads the original Pixtral checkpoint
46
+ 2. Splits the checkpoint into vision encoder, projector, and LLM weights
47
+ 3. Reorganizes the weights to match the expected format
48
+ 4. Extracts and verifies the vision encoder configuration
49
+ 5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer
50
+ 6. Optionally saves the converted checkpoint and configuration
51
+ """
52
+
53
+ save_dir = os.path.join(checkpoint_dir, checkpoint_name)
54
+ os.makedirs(save_dir, exist_ok=True)
55
+ # Save the converted checkpoint
56
+ save_path = os.path.join(save_dir, "model.pt")
57
+ if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
58
+ print(f"Checkpoint {save_path} already exists and is not empty")
59
+ return
60
+
61
+ pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
62
+ os.makedirs(pixtral_ckpt_dir, exist_ok=True)
63
+ repo_id = "mistralai/Pixtral-12B-2409"
64
+ print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
65
+ snapshot_download(
66
+ repo_id=repo_id,
67
+ allow_patterns=["params.json", "consolidated.safetensors"],
68
+ local_dir=pixtral_ckpt_dir,
69
+ local_dir_use_symlinks=False,
70
+ )
71
+ orig_dtype = torch.get_default_dtype()
72
+ dtype = torch.bfloat16
73
+ torch.set_default_dtype(dtype)
74
+
75
+ # Load checkpoint file
76
+ ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
77
+ assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
78
+ ckpt_path = ckpt_files[0]
79
+ ckpt = load_file(ckpt_path)
80
+
81
+ # Split checkpoint into weights of vision encoder, projector, and LLM
82
+ vit_key_prefix = "vision_encoder."
83
+ vit_ckpt = {}
84
+ for key, value in ckpt.items():
85
+ if key.startswith(vit_key_prefix):
86
+ vit_ckpt[key[len(vit_key_prefix):]] = value  # slice off the exact prefix (str.lstrip strips a character set, not a prefix)
87
+
88
+ projector_key_prefix = "vision_language_adapter."
89
+ projector_ckpt = {}
90
+ substring_replacement_map = {
91
+ "w_in.": "projector.0.",
92
+ "w_out.": "projector.2.",
93
+ }
94
+ for key, value in ckpt.items():
95
+ if key.startswith(projector_key_prefix):
96
+ key = key[len(projector_key_prefix):]  # slice off the exact prefix
97
+ for old, new in substring_replacement_map.items():
98
+ key = key.replace(old, new)
99
+ projector_ckpt[key] = value
100
+
101
+ llm_ckpt = {}
102
+ for key, value in ckpt.items():
103
+ if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix):
104
+ continue
105
+ llm_ckpt[key] = value
106
+
107
+ vlm_ckpt = {}
108
+ for key, value in llm_ckpt.items():
109
+ vlm_ckpt["model." + key] = value
110
+ for key, value in projector_ckpt.items():
111
+ vlm_ckpt["mm_projector." + key] = value
112
+ for key, value in vit_ckpt.items():
113
+ vlm_ckpt["vision_encoder." + key] = value
114
+
115
+ # Load config
116
+ config_path = os.path.join(pixtral_ckpt_dir, "params.json")
117
+ with open(config_path, "r") as f:
118
+ pixtral_config = json.load(f)
119
+
120
+ # Extract the vision encoder configuration
121
+ vision_encoder_config = {
122
+ "dim": pixtral_config["vision_encoder"]["hidden_size"],
123
+ "num_channels": pixtral_config["vision_encoder"]["num_channels"],
124
+ "image_size": pixtral_config["vision_encoder"]["image_size"],
125
+ "patch_size": pixtral_config["vision_encoder"]["patch_size"],
126
+ "rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
127
+ "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
128
+ "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
129
+ "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
130
+ "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
131
+ "norm_type": "rmsnorm",
132
+ "norm_eps": pixtral_config["norm_eps"],
133
+ "image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
134
+ }
135
+ # Configuration for the 400M ViT of Pixtral 12B VLM
136
+ vit_config = dict(
137
+ dim=1024,
138
+ num_channels=3,
139
+ image_size=1024,
140
+ patch_size=16,
141
+ rope_theta=10000,
142
+ ffn_hidden_size=4096,
143
+ n_layers=24,
144
+ n_heads=16,
145
+ n_kv_heads=16,
146
+ norm_type="rmsnorm",
147
+ norm_eps=1e-5,
148
+ image_token_id=10,
149
+ )
150
+ # Compare the two configurations
151
+ for key, value in vit_config.items():
152
+ assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"
153
+
154
+ llm_config_keys = [
155
+ "dim",
156
+ "n_layers",
157
+ "head_dim",
158
+ "hidden_dim",
159
+ "n_heads",
160
+ "n_kv_heads",
161
+ "rope_theta",
162
+ "norm_eps",
163
+ "vocab_size",
164
+ ]
165
+ assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
166
+ replace_map = {
167
+ "hidden_dim": "ffn_hidden_size",
168
+ }
169
+ llm_config = {}
170
+ for k, v in pixtral_config.items():
171
+ if k in llm_config_keys:
172
+ llm_config[replace_map.get(k, k)] = v
173
+ elif k == "vision_encoder":
174
+ llm_config["vision_encoder"] = vit_type
175
+ else:
176
+ raise ValueError(f"Unknown key: {k}")
177
+
178
+ ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
179
+ torch.save(ckpt_to_save, save_path)
180
+ print(f"Model saved to {save_path}")
181
+
182
+ # Save config
183
+ config_path = os.path.join(save_dir, "config.json")
184
+ with open(config_path, "w") as f:
185
+ json.dump(llm_config, f)
186
+
187
+ torch.set_default_dtype(orig_dtype) # Reset the default dtype
188
+
189
+ # Remove the original Pixtral checkpoint
190
+ shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
191
+ print(f"Removed {pixtral_ckpt_dir}")
192
+
193
+
194
+ if __name__ == "__main__":
195
+ parser = argparse.ArgumentParser(
196
+ description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
197
+ )
198
+ parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory")
199
+ parser.add_argument(
200
+ "--checkpoint_name",
201
+ type=str,
202
+ default="Pixtral-12B",
203
+ help="Name of the checkpoint",
204
+ )
205
+ parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
206
+ args = parser.parse_args()
207
+ convert_pixtral_checkpoint(
208
+ checkpoint_dir=args.checkpoint_dir, checkpoint_name=args.checkpoint_name, vit_type=args.vit_type
209
+ )
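The converted checkpoint is a plain `torch.save` dictionary with `model`, `mm_projector`, and `vision_encoder` sub-dicts. A quick sketch for inspecting it (the path is this script's default output location):

```python
import torch

ckpt = torch.load("checkpoints/Pixtral-12B/model.pt", map_location="cpu")
print(ckpt.keys())                   # dict_keys(['model', 'mm_projector', 'vision_encoder'])
print(len(ckpt["vision_encoder"]))   # number of vision-encoder tensors
```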
cosmos1/models/POST_TRAINING.md ADDED
@@ -0,0 +1,23 @@
1
+ # Cosmos Post-training
2
+
3
+ In the [Cosmos paper](https://research.nvidia.com/publication/2025-01_cosmos-world-foundation-model-platform-physical-ai), we discuss several post-training examples of Cosmos pre-trained World Foundation Models (WFMs) for various Physical AI tasks, including
4
+
5
+ - General Post-Training: Fine-tune the WFM to generate a target distribution of videos based on the custom dataset. The target distribution could include a specific camera spec or a specific domain such as a factory.
6
+ - Instruction Control: Post-trains models for robotic manipulation to predict videos based on textual instructions, enabling robots to visually simulate tasks like folding clothes or picking up objects.
7
+ - Action Control: Post-trains models for robotic manipulation to predict the next visual frame based on action vectors, simulating robotic tasks like object handling or movement planning.
8
+ - Camera Control: Adds camera pose conditioning to generate 3D-consistent video simulations from single images, enabling joystick-like navigation in virtual environments.
9
+ - Multi-View Generation: Post-trains models for autonomous vehicles to generate synchronized multi-view videos from text prompts, simulating driving scenarios with multiple camera perspectives.
10
+ - Multi-View Generation with Vehicle Trajectory Control: Extends multi-view generation by incorporating trajectory inputs, enabling precise simulation of driving environments for autonomous vehicles, adhering to specified paths.
11
+
12
+ Except for instruction control, where the WFM is post-trained on a dataset of instruction-video pairs, all other cases require minor modifications of the network architectures. Post-training tasks will be supported by the NeMo Framework. In this initial release, we provide post-training scripts for the general post-training of both diffusion and autoregressive WFMs. Scripts for the other post-training tasks will be provided in a future release.
13
+
14
+ ## Post-training Support Matrix
15
+
16
+ | Post-training Task | Diffusion WFM | Autoregressive WFM |
17
+ |---------------------|---------------|--------------------|
18
+ | General post-training | [Supported](../models/diffusion/nemo/post_training/README.md) | [Supported](../models/autoregressive/nemo/post_training/README.md) |
19
+ | Instruction control | Coming soon | Coming soon |
20
+ | Action control | Coming soon | Coming soon |
21
+ | Camera control | Coming soon | Coming soon |
22
+ | Multi-view generation | Coming soon | Coming soon |
23
+ | Multi-view generation with vehicle trajectory control | Coming soon | Coming soon |
cosmos1/models/autoregressive/README.md ADDED
@@ -0,0 +1,427 @@
1
+ # Cosmos Autoregressive-based World Foundation Models
2
+
3
+ ## Table of Contents
4
+ - [Getting Started](#getting-started)
5
+ - [Set Up Docker Environment](#set-up-docker-environment)
6
+ - [Download Checkpoints](#download-checkpoints)
7
+ - [Usage](#usage)
8
+ - [Model Types](#model-types)
9
+ - [Single and Batch Generation](#single-and-batch-generation)
10
+ - [Sample Commands](#sample-commands)
11
+ - [Base Models (4B/12B)](#base-basepy-4b-and-12b)
12
+ - [Video2World Models (5B/13B)](#video2world-video2worldpy-5b-and-13b)
13
+ - [Arguments](#arguments)
14
+ - [Common Parameters](#common-parameters)
15
+ - [Base Specific Parameters](#base-specific-parameters)
16
+ - [Video2World Specific Parameters](#video2world-specific-parameters)
17
+ - [Safety Features](#safety-features)
18
+
19
+ This page details the steps for using the Cosmos autoregressive-based world foundation models.
20
+
21
+ ## Getting Started
22
+
23
+ ### Set Up Docker Environment
24
+
25
+ Follow our [Installation Guide](../../../INSTALL.md) to set up the Docker environment. All commands on this page should be run inside Docker.
26
+
27
+ ### Download Checkpoints
28
+
29
+ 1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token. Set the access token to 'Read' permission (default is 'Fine-grained').
30
+
31
+ 2. Log in to Hugging Face with the access token:
32
+
33
+ ```bash
34
+ huggingface-cli login
35
+ ```
36
+
37
+ 3. Download the Cosmos model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-6751e884dc10e013a0a0d8e6):
38
+
39
+ ```bash
40
+ PYTHONPATH=$(pwd) python cosmos1/scripts/download_autoregressive.py --model_sizes 4B 5B 12B 13B
41
+ ```
42
+
43
+ 4. The downloaded files should be in the following structure:
44
+
45
+ ```
46
+ checkpoints/
47
+ ├── Cosmos-1.0-Autoregressive-4B
48
+ │ ├── model.pt
49
+ │ └── config.json
50
+ ├── Cosmos-1.0-Autoregressive-5B-Video2World
51
+ │ ├── model.pt
52
+ │ └── config.json
53
+ ├── Cosmos-1.0-Autoregressive-12B
54
+ │ ├── model.pt
55
+ │ └── config.json
56
+ ├── Cosmos-1.0-Autoregressive-13B-Video2World
57
+ │ ├── model.pt
58
+ │ └── config.json
59
+ ├── Cosmos-1.0-Tokenizer-CV8x8x8
60
+ │ ├── decoder.jit
61
+ │ ├── encoder.jit
62
+ │ └── mean_std.pt
63
+ ├── Cosmos-1.0-Tokenizer-DV8x16x16
64
+ │ ├── decoder.jit
65
+ │ └── encoder.jit
66
+ ├── Cosmos-1.0-Diffusion-7B-Decoder-DV8x16x16ToCV8x8x8
67
+ │ ├── aux_vars.pt
68
+ │ └── model.pt
69
+ └── Cosmos-1.0-Guardrail
70
+ ├── aegis/
71
+ ├── blocklist/
72
+ ├── face_blur_filter/
73
+ └── video_content_safety_filter/
74
+ ```
75
+
76
+ ## Usage
77
+
78
+
79
+ ### Model Types
80
+
81
+ There are two model types available for autoregressive world generation:
82
+
83
+ 1. **Base**: Supports world generation from image/video input
84
+
85
+ * Models: `Cosmos-1.0-Autoregressive-4B` and `Cosmos-1.0-Autoregressive-12B`
86
+ * Inference script: [base.py](/cosmos1/models/autoregressive/inference/base.py)
87
+
88
+ 2. **Video2World**: Supports world generation from image/video input and text input
89
+
90
+ * Models: `Cosmos-1.0-Autoregressive-5B-Video2World` and `Cosmos-1.0-Autoregressive-13B-Video2World`
91
+ * Inference script: [video2world.py](/cosmos1/models/autoregressive/inference/video2world.py)
92
+
93
+ Our models now support video extension up to 33 frames. Starting from either a single image or a 9-frame video input, they can generate the remaining frames to reach the 33-frame length (generating 32 or 24 frames, respectively).
94
+
95
+ We have evaluated all eight possible configurations (4 models × 2 vision input types: image or video) using 100 test videos on physical AI topics. Below are the failure rates for each configuration:
96
+
97
+ | Model | Image input | Video input (9 frames) |
98
+ |:------------------------------------------|:--------------:|:-------------------------:|
99
+ | Cosmos-1.0-Autoregressive-4B | 15% | 1% |
100
+ | Cosmos-1.0-Autoregressive-5B-Video2World | 7% | 2% |
101
+ | Cosmos-1.0-Autoregressive-12B | 2% | 1% |
102
+ | Cosmos-1.0-Autoregressive-13B-Video2World | 3% | 0% |
103
+
104
+ We define failure cases as videos with severe distortions, such as:
105
+
106
+ * Sudden appearance of large unexpected objects
107
+ * Video degrading to a single solid color
108
+
109
+ Note that the following are not considered failures in our analysis:
110
+
111
+ * Static video frames
112
+ * Minor object distortions or artifacts
113
+
114
+ ### Single and Batch Generation
115
+
116
+ We support both single and batch video generation.
117
+
118
+ For generating a single video, `base` mode requires the input argument `--input_image_or_video_path` (image/video input), while `video2world` mode requires both `--input_image_or_video_path` (image/video input) and `--prompt` (text input).
119
+
120
+ Note that our model only works with 1024x640 resolution videos. If the input image/video is not in this resolution, it will be resized and cropped.
121
+
122
+ For generating a batch of videos, both `base` and `video2world` require `--batch_input_path` (path to a JSONL file). For `base`, the JSONL file should contain one visual input per line in the following format, where each line must contain a "visual_input" field:
123
+
124
+ ```json
125
+ {"visual_input": "path/to/video1.mp4"}
126
+ {"visual_input": "path/to/video2.mp4"}
127
+ ```
128
+
129
+ For `video2world`, each line in the JSONL file must contain both "prompt" and "visual_input" fields:
130
+
131
+ ```json
132
+ {"prompt": "prompt1", "visual_input": "path/to/video1.mp4"}
133
+ {"prompt": "prompt2", "visual_input": "path/to/video2.mp4"}
134
+ ```
135
+
136
+ ### Sample Commands
137
+
138
+ There are two main demo scripts for autoregressive world generation: `base.py` and `video2world.py`. Below you will find sample commands for single and batch generation, as well as commands for running with low-memory GPUs using model offloading. We also provide a memory usage table comparing different offloading strategies to help with configuration.
139
+
140
+ #### Base (base.py): 4B and 12B
141
+
142
+ Generates world from image/video input.
143
+
144
+ The `input_type` argument can be either `video` or `image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples.
145
+
146
+ Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `image`.
147
+
148
+ ##### Single Generation
149
+
150
+ ```bash
151
+ # Example using 4B model
152
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
153
+ --input_type=video \
154
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
155
+ --video_save_name=Cosmos-1.0-Autoregressive-4B \
156
+ --ar_model_dir=Cosmos-1.0-Autoregressive-4B \
157
+ --top_p=0.8 \
158
+ --temperature=1.0
159
+
160
+ # Example for low-memory GPUs using 4B model with model offloading
161
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
162
+ --input_type=video \
163
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
164
+ --video_save_name=Cosmos-1.0-Autoregressive-4B \
165
+ --ar_model_dir=Cosmos-1.0-Autoregressive-4B \
166
+ --top_p=0.8 \
167
+ --temperature=1.0 \
168
+ --offload_guardrail_models \
169
+ --offload_diffusion_decoder \
170
+ --offload_ar_model \
171
+ --offload_tokenizer
172
+
173
+ # Example using 12B model
174
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
175
+ --input_type=video \
176
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
177
+ --video_save_name=Cosmos-1.0-Autoregressive-12B \
178
+ --ar_model_dir=Cosmos-1.0-Autoregressive-12B \
179
+ --top_p=0.9 \
180
+ --temperature=1.0
181
+
182
+ # Example for low-memory GPUs using 12B model with model offloading
183
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
184
+ --input_type=video \
185
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
186
+ --video_save_name=Cosmos-1.0-Autoregressive-12B \
187
+ --ar_model_dir=Cosmos-1.0-Autoregressive-12B \
188
+ --top_p=0.9 \
189
+ --temperature=1.0 \
190
+ --offload_guardrail_models \
191
+ --offload_diffusion_decoder \
192
+ --offload_ar_model \
193
+ --offload_tokenizer
194
+ ```
195
+
196
+ ##### Batch Generation
197
+
198
+ ```bash
199
+ # Example using 4B model
200
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
201
+ --input_type=video \
202
+ --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \
203
+ --video_save_folder=outputs/Cosmos-1.0-Autoregressive-4B \
204
+ --ar_model_dir=Cosmos-1.0-Autoregressive-4B \
205
+ --top_p=0.8 \
206
+ --temperature=1.0
207
+
208
+ # Example using 12B model
209
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/base.py \
210
+ --input_type=video \
211
+ --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/base.jsonl \
212
+ --video_save_folder=outputs/Cosmos-1.0-Autoregressive-12B \
213
+ --ar_model_dir=Cosmos-1.0-Autoregressive-12B \
214
+ --top_p=0.9 \
215
+ --temperature=1.0
216
+ ```
217
+
218
+ ##### Example Output
219
+
220
+ Here is an example output video generated using base.py with image input, using `Cosmos-1.0-Autoregressive-12B`:
221
+
222
+ <video src="https://github.com/user-attachments/assets/634403a5-1873-42d7-8dd0-eb7fb4ac8cf4">
223
+ Your browser does not support the video tag.
224
+ </video>
225
+
226
+ The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The image is from [BDD dataset](http://bdd-data.berkeley.edu/).
227
+
228
+ Here is an example output video generated using base.py with 9-frame video input, using `Cosmos-1.0-Autoregressive-12B`:
229
+
230
+ <video src="https://github.com/user-attachments/assets/1a3ff099-87d7-41e8-b149-a25cfcd4f40b">
231
+ Your browser does not support the video tag.
232
+ </video>
233
+
234
+ The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`.
235
+
236
+ ##### Inference Time and GPU Memory Usage
237
+
238
+ These numbers may vary based on system specifications and are provided for reference only.
239
+
240
+ | Offloading Strategy | Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B |
241
+ |-------------|---------|---------|
242
+ | No offloading | 31.3 GB | 47.5 GB |
243
+ | Guardrails | 28.9 GB | 45.2 GB |
244
+ | Guardrails & Diffusion decoder | 28.5 GB | 43.1 GB |
245
+ | Guardrails & Diffusion decoder & Tokenizer | 27.3 GB | 42.9 GB |
246
+ | Guardrails & Diffusion decoder & Tokenizer & AR model | 18.7 GB | 27.4 GB |
247
+
248
+ End-to-end inference runtime on one H100 without offloading and after model initialization:
249
+
250
+ | Cosmos-1.0-Autoregressive-4B | Cosmos-1.0-Autoregressive-12B |
251
+ |---------|---------|
252
+ | ~62 seconds | ~119 seconds |
253
+
254
+ #### Video2World (video2world.py): 5B and 13B
255
+
256
+ Generates world from image/video and text input.
257
+
258
+ The `input_type` argument can be either `text_and_video` or `text_and_image`. We have tuned the sampling parameters `top_p` and `temperature` to achieve the best performance. Please use the provided values in the command examples.
259
+
260
+ Note that the command examples below all use video input. If you want to use image input, please change the `input_type` to `text_and_image`.
261
+
262
+ ##### Single Generation
263
+
264
+ ```bash
265
+ # Example using 5B model
266
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
267
+ --input_type=text_and_video \
268
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
269
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
270
+ --video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \
271
+ --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
272
+ --top_p=0.7 \
273
+ --temperature=1.0
274
+
275
+ # Example for low-memory GPUs using 5B model with model offloading
276
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
277
+ --input_type=text_and_video \
278
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
279
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
280
+ --video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \
281
+ --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
282
+ --top_p=0.7 \
283
+ --temperature=1.0 \
284
+ --offload_guardrail_models \
285
+ --offload_diffusion_decoder \
286
+ --offload_ar_model \
287
+ --offload_tokenizer \
288
+ --offload_text_encoder_model
289
+
290
+ # Example using 13B model
291
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
292
+ --input_type=text_and_video \
293
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
294
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
295
+ --video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \
296
+ --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
297
+ --top_p=0.8 \
298
+ --temperature=1.0 \
299
+ --offload_guardrail_models
300
+
301
+ # Example for low-memory GPUs using 13B model with model offloading
302
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
303
+ --input_type=text_and_video \
304
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
305
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
306
+ --video_save_name=Cosmos-1.0-Autoregressive-13B-Video2World \
307
+ --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
308
+ --top_p=0.8 \
309
+ --temperature=1.0 \
310
+ --offload_guardrail_models \
311
+ --offload_diffusion_decoder \
312
+ --offload_ar_model \
313
+ --offload_tokenizer \
314
+ --offload_text_encoder_model
315
+ ```
316
+
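+ To switch any of the commands above to image input, only `input_type` and the input path change. Here is a sketch based on the 5B example, assuming the bundled `input.jpg` asset; all other flags stay the same:
+
+ ```bash
+ # Sketch: image-input variant of the 5B single-generation example
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
+ --input_type=text_and_image \
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.jpg \
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
+ --video_save_name=Cosmos-1.0-Autoregressive-5B-Video2World \
+ --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
+ --top_p=0.7 \
+ --temperature=1.0
+ ```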
317
+ ##### Batch Generation
318
+
319
+ ```bash
320
+ # Example using 5B model
321
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
322
+ --input_type=text_and_video \
323
+ --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \
324
+ --video_save_folder=outputs/Cosmos-1.0-Autoregressive-5B-Video2World \
325
+ --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
326
+ --top_p=0.7 \
327
+ --temperature=1.0
328
+
329
+ # Example using 13B model
330
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
331
+ --input_type=text_and_video \
332
+ --batch_input_path=cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl \
333
+ --video_save_folder=outputs/Cosmos-1.0-Autoregressive-13B-Video2World \
334
+ --ar_model_dir=Cosmos-1.0-Autoregressive-13B-Video2World \
335
+ --top_p=0.8 \
336
+ --temperature=1.0 \
337
+ --offload_guardrail_models
338
+ ```
339
+
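+ The batch examples above read one JSON object per line from a JSONL file. The exact schema is defined by the bundled `video2world.jsonl` asset; the sketch below only illustrates the general shape, and the field names `prompt` and `visual_input` are assumptions rather than a confirmed specification:
+
+ ```bash
+ # Sketch of assembling a custom batch file. Field names are illustrative assumptions --
+ # inspect cosmos1/models/autoregressive/assets/v1p0/batch_inputs/video2world.jsonl for the real schema.
+ cat > my_video2world_batch.jsonl << 'EOF'
+ {"prompt": "A video recorded from a moving vehicle's perspective, capturing roads and buildings.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/input.mp4"}
+ {"prompt": "A driving video capturing a serene urban street on a sunny day.", "visual_input": "cosmos1/models/autoregressive/assets/v1p0/input.mp4"}
+ EOF
+ ```
+
+ The resulting file is then passed via `--batch_input_path`, exactly as in the commands above.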
340
+ ##### Example Output
341
+
342
+ Here is an example output video generated using video2world.py with image input and the `Cosmos-1.0-Autoregressive-13B-Video2World` model:
343
+
344
+ <video src="https://github.com/user-attachments/assets/869f3b81-fabd-462e-a545-c04cdd9c1d22">
345
+ Your browser does not support the video tag.
346
+ </video>
347
+
348
+ The input image used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.jpg`. The prompt for generating the video is:
349
+
350
+ ```
351
+ A driving video captures a serene urban street scene on a sunny day. The camera is mounted on the dashboard of a moving vehicle, providing a first-person perspective as it travels down a two-lane road. The street is lined with parked cars on both sides, predominantly black and silver sedans and SUVs. The road is flanked by a mix of residential and commercial buildings, with a prominent red-brick building on the left side, featuring multiple windows and a flat roof. The sky is clear with a few scattered clouds, casting soft shadows on the street. Trees with lush green foliage line the right side of the road, providing a natural contrast to the urban environment. The camera remains steady, maintaining a consistent forward motion, suggesting a leisurely drive. Traffic is light, with a few vehicles moving in the opposite direction, including a black sedan and a yellow taxi. Street signs are visible, including a no-parking sign on the right. The overall atmosphere is calm and peaceful, with no pedestrians visible, emphasizing the focus on the drive and the surrounding urban landscape.
352
+ ```
353
+
354
+ Here is an example output video generated using video2world.py with 9-frame video input and the `Cosmos-1.0-Autoregressive-13B-Video2World` model:
355
+
356
+ <video src="https://github.com/user-attachments/assets/81840e1c-624b-4b01-9240-ab7db3722e58">
357
+ Your browser does not support the video tag.
358
+ </video>
359
+
360
+ The input video used to generate this video can be found in `cosmos1/models/autoregressive/assets/v1p0/input.mp4`. The prompt for generating the video is:
361
+
362
+ ```
363
+ A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions.
364
+ ```
365
+
366
+ ##### Inference Time and GPU Memory Usage
367
+
368
+ These numbers may vary based on system specifications and are provided for reference only.
369
+
370
+ | Offloading Strategy | Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World |
371
+ |-------------|---------|---------|
372
+ | No offloading | 66.2 GB | > 80 GB |
373
+ | Guardrails | 58.7 GB | 76.6 GB |
374
+ | Guardrails & T5 encoder | 41.3 GB | 58.0 GB |
375
+ | Guardrails & T5 encoder & Diffusion decoder | 29.0 GB | 46.9 GB |
376
+ | Guardrails & T5 encoder & Diffusion decoder & Tokenizer | 28.8 GB | 46.7 GB |
377
+ | Guardrails & T5 encoder & Diffusion decoder & Tokenizer & AR model | 21.1 GB | 30.9 GB |
378
+
379
+ End-to-end inference runtime on one H100, measured after model initialization, with no offloading for the 5B model and guardrail offloading for the 13B model:
380
+
381
+ | Cosmos-1.0-Autoregressive-5B-Video2World | Cosmos-1.0-Autoregressive-13B-Video2World |
382
+ |---------|---------|
383
+ | ~73 seconds | ~150 seconds |
384
+
385
+ ### Arguments
386
+
387
+ #### Common Parameters
388
+
389
+ | Parameter | Description | Default |
390
+ |-----------|-------------|---------|
391
+ | `--checkpoint_dir` | Directory containing model weights | "checkpoints" |
392
+ | `--video_save_name` | Output video filename for single video generation | "output" |
393
+ | `--video_save_folder` | Folder where all output videos are stored | "outputs/" |
394
+ | `--input_image_or_video_path` | Input image or video path. Required for single video generation | None |
395
+ | `--batch_input_path` | Folder containing input images or videos (Base), or a JSONL file as in the batch generation examples above (Video2World). Required for batch video generation | None |
396
+ | `--num_input_frames` | Number of input frames to use for Video2World prediction | 9 |
397
+ | `--temperature` | Temperature used while sampling | 1.0 (we recommend the values provided in the sample commands) |
398
+ | `--top_p` | Top-p value for top-p sampling | 0.8 (we recommend the values provided in the sample commands) |
399
+ | `--seed` | Random seed | 0 |
400
+ | `--disable_diffusion_decoder` | When set to True, use the discrete tokenizer to decode discrete tokens to video. Otherwise, use the diffusion decoder to decode the video | False |
401
+ | `--offload_guardrail_models` | Offload guardrail models after inference, used for low-memory GPUs | False |
402
+ | `--offload_diffusion_decoder` | Offload diffusion decoder after inference, used for low-memory GPUs | False |
403
+ | `--offload_ar_model` | Offload AR model after inference, used for low-memory GPUs | False |
404
+ | `--offload_prompt_upsampler` | Offload prompt upsampler after inference, used for low-memory GPUs | False |
405
+
406
+ #### Base Specific Parameters
407
+
408
+ | Parameter | Description | Default |
409
+ |-----------|-------------|---------|
410
+ | `--ar_model_dir` | Directory containing the AR model weights | "Cosmos-1.0-Autoregressive-4B" |
411
+ | `--input_type` | Input type, either `video` or `image` | "video" |
412
+
413
+ #### Video2World Specific Parameters
414
+
415
+ | Parameter | Description | Default |
416
+ |-----------|-------------|---------|
417
+ | `--ar_model_dir` | Directory containing the AR model weights | "Cosmos-1.0-Autoregressive-4B" |
418
+ | `--input_type` | Input type, either `text_and_video` or `text_and_image` | "text_and_video" |
419
+ | `--prompt` | Text prompt for single video generation. Required for single video generation | None |
420
+ | `--input_prompts_path` | Path to JSONL file for batch video generation. Required for batch video generation | None |
421
+ | `--offload_text_encoder_model` | Offload text encoder after inference, used for low-memory GPUs | False |
422
+
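+ The common and script-specific parameters compose freely. As a sketch (the flag combination and values are illustrative, not taken from the original examples), here is a 5B Video2World run that pins the random seed and redirects its output:
+
+ ```bash
+ # Sketch: combining common flags (--seed, --video_save_folder) with the 5B Video2World example
+ CUDA_VISIBLE_DEVICES=0 PYTHONPATH=$(pwd) python cosmos1/models/autoregressive/inference/video2world.py \
+ --input_type=text_and_video \
+ --input_image_or_video_path=cosmos1/models/autoregressive/assets/v1p0/input.mp4 \
+ --prompt="A video recorded from a moving vehicle's perspective, capturing roads, buildings, landscapes, and changing weather and lighting conditions." \
+ --video_save_name=seed42-run \
+ --video_save_folder=outputs/Cosmos-1.0-Autoregressive-5B-Video2World \
+ --ar_model_dir=Cosmos-1.0-Autoregressive-5B-Video2World \
+ --top_p=0.7 \
+ --temperature=1.0 \
+ --seed=42
+ ```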
423
+ ### Safety Features
424
+
425
+ The model uses a built-in safety guardrail system that cannot be disabled. Generating human faces is not allowed, and any human faces that appear in the output are blurred by the guardrail.
426
+
427
+ For more information, check out the [Cosmos Guardrail Documentation](../guardrail/README.md).
cosmos1/models/autoregressive/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.