Add diffusers support
#1
by dn6 HF Staff - opened
- README.md +43 -1
- __init__.py +57 -0
- modular_blocks.py +1188 -0
- modular_config.json +7 -0
- modular_model_index.json +76 -0
- transformer/__init__.py +16 -0
- transformer/config.json +59 -0
- transformer/diffusion_pytorch_model.safetensors +3 -0
- transformer/model.py +1100 -0
- vae/__init__.py +18 -0
- vae/ae_model.py +368 -0
- vae/config.json +13 -0
- vae/diffusion_pytorch_model.safetensors +3 -0
README.md
CHANGED
````diff
@@ -79,6 +79,48 @@ This checkpoint is intended to be used with Overworld’s interactive runtime stack:
 - Play on our official desktop client, [Biome](https://over.world/install)
 - Use our [world_engine](https://github.com/Wayfarer-Labs/world_engine) inference library to build your own applications
 
+### Diffusers (Modular Pipeline)
+
+This model can also be used with [Modular Diffusers](https://huggingface.co/docs/diffusers/main/en/modular_diffusers):
+
+```python
+import torch
+from diffusers.modular_pipelines import ModularPipeline
+from diffusers.utils import load_image, export_to_video
+
+pipe = ModularPipeline.from_pretrained(
+    "Overworld/Waypoint-1.5-1B", trust_remote_code=True
+)
+pipe.load_components(
+    device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True
+)
+pipe.transformer.apply_inference_patches()
+pipe.transformer.compile(fullgraph=True, mode="max-autotune", dynamic=False)
+
+# Seed the world with an image
+image = load_image("https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_18.png").resize((1024, 512))
+state = pipe(
+    image=image,
+    prompt="An explorable world",
+    button=set(),
+    mouse=(0.0, 0.0),
+    output_type="pil",
+)
+
+# Generate subsequent frames with controller inputs
+state.values["image"] = None
+frames = []
+for _ in range(150):
+    state = pipe(
+        state,
+        button={87},  # W key (walk forward)
+        mouse=(0.0, 0.0),
+        output_type="pil",
+    )
+    frames.extend(state.values["images"])  # "images" is a list of PIL images
+
+export_to_video(frames, "waypoint-v1-5.mp4", fps=60)
+```
 
 ### Recommended setup
 
@@ -129,4 +171,4 @@ Please see our blog post, ["Engineering Safety for Interactive World Models"](ht
 
 - [Website](http://over.world/)
 - [Discord](https://discord.gg/MEmQa7Wux4)
-- [X/Twitter](https://x.com/overworld_ai)
+- [X/Twitter](https://x.com/overworld_ai)
````
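The button IDs in the README example are key codes that match `ord` of the uppercase character (`ord("W") == 87`), so multi-key inputs read more easily through a tiny helper. A sketch only; the `keys` helper is made up and not part of the PR:

```python
def keys(*chars: str) -> set[int]:
    """Map characters to the key-code button IDs the pipeline expects."""
    return {ord(c.upper()) for c in chars}

# e.g. walk forward while strafing left
state = pipe(state, button=keys("w", "a"), mouse=(0.0, 0.0), output_type="pil")
```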
__init__.py
ADDED
@@ -0,0 +1,57 @@
```python
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""
WorldEngine Modular Pipeline

A Diffusers-compatible modular pipeline for frame-by-frame world model generation.
Supports text and controller (mouse + button + scroll) conditioning.
"""

from .modular_blocks import (
    WorldEngineBlocks,
    AUTO_BLOCKS,
    WorldEngineTextEncoderStep,
    WorldEngineControllerEncoderStep,
    WorldEngineBeforeDenoiseStep,
    WorldEngineSetTimestepsStep,
    WorldEnginePrepareLatentsStep,
    WorldEngineSetupKVCacheStep,
    WorldEngineDenoiseLoop,
    WorldEngineDecodeStep,
    StaticKVCache,
    LayerKVCache,
)
from .transformer import WorldModel
from .vae import ChunkedStreamingTAEHV

__version__ = "0.1.0"

__all__ = [
    "WorldEngineBlocks",
    "AUTO_BLOCKS",
    "WorldEngineTextEncoderStep",
    "WorldEngineControllerEncoderStep",
    "WorldEngineBeforeDenoiseStep",
    "WorldEngineSetTimestepsStep",
    "WorldEnginePrepareLatentsStep",
    "WorldEngineSetupKVCacheStep",
    "WorldEngineDenoiseLoop",
    "WorldEngineDecodeStep",
    "WorldModel",
    "ChunkedStreamingTAEHV",
    "StaticKVCache",
    "LayerKVCache",
]
```
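Since the package exports the block classes directly, the pipeline can also be assembled locally from the blocks instead of via `ModularPipeline.from_pretrained`. A sketch, assuming the repo files are importable and hedging on the exact `init_pipeline` call:

```python
import torch
from modular_blocks import WorldEngineBlocks

blocks = WorldEngineBlocks()
pipe = blocks.init_pipeline("Overworld/Waypoint-1.5-1B")  # loads component specs from the repo
pipe.load_components(device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True)
```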
modular_blocks.py
ADDED
@@ -0,0 +1,1188 @@
```python
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""WorldEngine modular pipeline blocks.

All pipeline step classes for text encoding, controller encoding,
KV cache setup, latent preparation, denoising, and decoding.
"""

import html

import numpy as np
import PIL.Image
import regex as re
import torch
from torch import nn, Tensor
from tensordict import TensorDict
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask
from transformers import AutoTokenizer, UMT5EncoderModel

from diffusers import AutoModel
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import is_ftfy_available, logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    InsertableDict,
    OutputParam,
)

if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)


# ---------------------------------------------------------------------------
# Text encoding helpers
# ---------------------------------------------------------------------------

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text


# ---------------------------------------------------------------------------
# Block mask construction
# ---------------------------------------------------------------------------

def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    T and L must be exact multiples of the sparse block size; written must be
    block-aligned (each block is either all True or all False).

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE

    if not torch.compiler.is_compiling():
        torch._check(T % BS == 0, f"T ({T}) must be a multiple of block size ({BS})")
        torch._check(L % BS == 0, f"L ({L}) must be a multiple of block size ({BS})")

    Q_blocks = T // BS
    KV_blocks = L // BS

    written_blocks = written.view(KV_blocks, BS)
    block_any = written_blocks.any(-1)

    if not torch.compiler.is_compiling():
        assert torch.equal(block_any, written_blocks.all(-1)), "written must be block-aligned"

    # Every KV block is a full block (no partial blocks)
    full_bm = block_any[None, :].expand(Q_blocks, KV_blocks)
    full_kv_num_blocks = full_bm.sum(dim=-1, dtype=torch.int32)[None, None].contiguous()
    full_kv_indices = full_bm.argsort(dim=-1, descending=True, stable=True).to(torch.int32)[None, None].contiguous()

    # No partial blocks
    kv_num_blocks = torch.zeros((1, 1, Q_blocks), dtype=torch.int32, device=written.device)
    kv_indices = torch.zeros((1, 1, Q_blocks, KV_blocks), dtype=torch.int32, device=written.device)

    return BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=None,
        seq_lengths=(T, L),
        compute_q_blocks=False,
    )
```
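The alignment contract above is easy to get wrong, so here is a small illustrative check. This is not part of the PR, and the block size of 128 is an assumption about `_DEFAULT_SPARSE_BLOCK_SIZE`:

```python
import torch

BS = 128          # assumed _DEFAULT_SPARSE_BLOCK_SIZE
T = 256           # query tokens for one frame -> 2 query blocks
L = 1280          # KV capacity -> 10 KV blocks

# "written" must be block-aligned: whole 128-token blocks, all True or all False.
written = torch.zeros(L, dtype=torch.bool)
written[: 4 * BS] = True

# make_block_mask(T, L, written) would then mark 4 of the 10 KV blocks as
# "full" for every query block, with no partial blocks at all.
print(T // BS, L // BS, written.view(-1, BS).all(-1).sum().item())  # 2 10 4
```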
```python
# ---------------------------------------------------------------------------
# KV cache
# ---------------------------------------------------------------------------

class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus
    one extra frame (tokens_per_frame) at the tail holding the current frame.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L
        # total KV capacity: ring (L) + tail frame (tpf)
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0

        # KV buffer: [2, B, H, capacity, Dh]
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # which slots have ever been written
        # tail slice [L, L+tpf) always holds the current frame and is considered written
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)

        # _mask_written is a scratch buffer for computing block masks without cloning
        self._mask_written = nn.Buffer(torch.zeros_like(written), persistent=False)

        # Precompute indices:
        #   frame_offsets: [0, 1, ..., tpf-1]      (for ring indexing)
        #   current_idx:   [L, L+1, ..., L+tpf-1]  (tail slice)
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with f_pos [B, T] for cache slot indexing
        """
        T = self.tpf
        f_pos = pos_ids["f_pos"]

        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(f_pos.shape == (kv.size(1), T), "f_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer too long (expected L + tokens_per_frame)",
            )
            torch._check(
                (f_pos >= 0).all().item(),
                "f_pos must be non-negative during inference",
            )
            torch._check(
                ((f_pos == f_pos[:, :1]).all()).item(),
                "f_pos must be constant within frame",
            )

        frame_idx = f_pos[0, 0]

        # map frame_idx to a bucket, each bucket owns T contiguous slots
        bucket = (frame_idx + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T

        # indices in the ring for this frame: [T] in [0, L)
        ring_idx = self.frame_offsets + base

        # Always write current frame into the tail slice [L, L+T):
        # this is the "self-attention component" for the current frame.
        self.kv.index_copy_(3, self.current_idx, kv)

        write_step = (frame_idx.remainder(self.pinned_dilation) == 0)
        mask_written = self._mask_written
        mask_written.copy_(self.written)
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        # Persist current frame into the ring for future queries when unfrozen.
        if not is_frozen:
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm
```
class StaticKVCache(nn.Module):
|
| 244 |
+
"""Static KV cache with per-layer configuration for local/global attention."""
|
| 245 |
+
|
| 246 |
+
def __init__(self, config, batch_size, dtype):
|
| 247 |
+
super().__init__()
|
| 248 |
+
|
| 249 |
+
self.tpf = config.height * config.width
|
| 250 |
+
|
| 251 |
+
local_L = config.local_window * self.tpf
|
| 252 |
+
global_L = config.global_window * self.tpf
|
| 253 |
+
|
| 254 |
+
period = config.global_attn_period
|
| 255 |
+
off = getattr(config, "global_attn_offset", 0) % period
|
| 256 |
+
self.layers = nn.ModuleList(
|
| 257 |
+
[
|
| 258 |
+
LayerKVCache(
|
| 259 |
+
batch_size,
|
| 260 |
+
getattr(config, "n_kv_heads", None) or config.n_heads,
|
| 261 |
+
global_L if ((layer_idx - off) % period == 0) else local_L,
|
| 262 |
+
config.d_model // config.n_heads,
|
| 263 |
+
dtype,
|
| 264 |
+
self.tpf,
|
| 265 |
+
(
|
| 266 |
+
config.global_pinned_dilation
|
| 267 |
+
if ((layer_idx - off) % period == 0)
|
| 268 |
+
else 1
|
| 269 |
+
),
|
| 270 |
+
)
|
| 271 |
+
for layer_idx in range(config.n_layers)
|
| 272 |
+
]
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
self._is_frozen = True
|
| 276 |
+
|
| 277 |
+
def reset(self):
|
| 278 |
+
for layer in self.layers:
|
| 279 |
+
layer.reset()
|
| 280 |
+
self._is_frozen = True
|
| 281 |
+
|
| 282 |
+
@torch.inference_mode()
|
| 283 |
+
def get_state(self):
|
| 284 |
+
"""Captures a world state to continue via load_state."""
|
| 285 |
+
layers = [(layer.kv.detach().clone(), layer.written.detach().clone()) for layer in self.layers]
|
| 286 |
+
return {"_is_frozen": self._is_frozen, "layers": layers}
|
| 287 |
+
|
| 288 |
+
@torch.inference_mode()
|
| 289 |
+
def load_state(self, state):
|
| 290 |
+
"""Loads a world state object saved via get_state."""
|
| 291 |
+
self._is_frozen = bool(state.get("_is_frozen", True))
|
| 292 |
+
for layer, (kv, written) in zip(self.layers, state["layers"]):
|
| 293 |
+
layer.kv.copy_(kv)
|
| 294 |
+
layer.written.copy_(written)
|
| 295 |
+
|
| 296 |
+
def set_frozen(self, is_frozen: bool):
|
| 297 |
+
self._is_frozen = is_frozen
|
| 298 |
+
|
| 299 |
+
def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
|
| 300 |
+
kv = torch.stack([k, v], dim=0)
|
| 301 |
+
return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
# ---------------------------------------------------------------------------
|
| 305 |
+
# Pipeline step: Text Encoder
|
| 306 |
+
# ---------------------------------------------------------------------------
|
| 307 |
+
|
| 308 |
+
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
|
| 309 |
+
"""Encodes text prompts using UMT5-XL for conditioning."""
|
| 310 |
+
|
| 311 |
+
model_name = "world_engine"
|
| 312 |
+
|
| 313 |
+
@property
|
| 314 |
+
def description(self) -> str:
|
| 315 |
+
return (
|
| 316 |
+
"Text Encoder step that generates text embeddings to guide frame generation"
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
@property
|
| 320 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 321 |
+
return [
|
| 322 |
+
ComponentSpec("text_encoder", UMT5EncoderModel),
|
| 323 |
+
ComponentSpec("tokenizer", AutoTokenizer),
|
| 324 |
+
]
|
| 325 |
+
|
| 326 |
+
@property
|
| 327 |
+
def inputs(self) -> list[InputParam]:
|
| 328 |
+
return [
|
| 329 |
+
InputParam(
|
| 330 |
+
"prompt",
|
| 331 |
+
description="The prompt or prompts to guide the frame generation",
|
| 332 |
+
),
|
| 333 |
+
InputParam(
|
| 334 |
+
"prompt_embeds",
|
| 335 |
+
type_hint=torch.Tensor,
|
| 336 |
+
description="Pre-computed text embeddings",
|
| 337 |
+
),
|
| 338 |
+
InputParam(
|
| 339 |
+
"prompt_pad_mask",
|
| 340 |
+
type_hint=torch.Tensor,
|
| 341 |
+
description="Padding mask for prompt embeddings",
|
| 342 |
+
),
|
| 343 |
+
]
|
| 344 |
+
|
| 345 |
+
@property
|
| 346 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 347 |
+
return [
|
| 348 |
+
OutputParam(
|
| 349 |
+
"prompt_embeds",
|
| 350 |
+
type_hint=torch.Tensor,
|
| 351 |
+
kwargs_type="denoiser_input_fields",
|
| 352 |
+
description="Text embeddings used to guide frame generation",
|
| 353 |
+
),
|
| 354 |
+
OutputParam(
|
| 355 |
+
"prompt_pad_mask",
|
| 356 |
+
type_hint=torch.Tensor,
|
| 357 |
+
kwargs_type="denoiser_input_fields",
|
| 358 |
+
description="Padding mask for prompt embeddings",
|
| 359 |
+
),
|
| 360 |
+
]
|
| 361 |
+
|
| 362 |
+
@staticmethod
|
| 363 |
+
def check_inputs(block_state):
|
| 364 |
+
if block_state.prompt is not None and (
|
| 365 |
+
not isinstance(block_state.prompt, str)
|
| 366 |
+
and not isinstance(block_state.prompt, list)
|
| 367 |
+
):
|
| 368 |
+
raise ValueError(
|
| 369 |
+
f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
|
| 370 |
+
)
|
| 371 |
+
|
| 372 |
+
@staticmethod
|
| 373 |
+
def encode_prompt(
|
| 374 |
+
components,
|
| 375 |
+
prompt: str | list[str],
|
| 376 |
+
device: torch.device,
|
| 377 |
+
max_sequence_length: int = 512,
|
| 378 |
+
):
|
| 379 |
+
dtype = components.text_encoder.dtype
|
| 380 |
+
|
| 381 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 382 |
+
prompt = [prompt_clean(p) for p in prompt]
|
| 383 |
+
|
| 384 |
+
text_inputs = components.tokenizer(
|
| 385 |
+
prompt,
|
| 386 |
+
padding="max_length",
|
| 387 |
+
max_length=max_sequence_length,
|
| 388 |
+
truncation=True,
|
| 389 |
+
return_attention_mask=True,
|
| 390 |
+
return_tensors="pt",
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
| 394 |
+
attention_mask = text_inputs.attention_mask.to(device)
|
| 395 |
+
|
| 396 |
+
prompt_embeds = components.text_encoder(
|
| 397 |
+
text_input_ids, attention_mask
|
| 398 |
+
).last_hidden_state
|
| 399 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
| 400 |
+
|
| 401 |
+
# Zero out padding
|
| 402 |
+
prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
|
| 403 |
+
prompt_embeds
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
# Create padding mask (True where padded)
|
| 407 |
+
prompt_pad_mask = attention_mask.eq(0)
|
| 408 |
+
|
| 409 |
+
return prompt_embeds, prompt_pad_mask
|
| 410 |
+
|
| 411 |
+
@torch.no_grad()
|
| 412 |
+
def __call__(
|
| 413 |
+
self, components: ModularPipeline, state: PipelineState
|
| 414 |
+
) -> PipelineState:
|
| 415 |
+
block_state = self.get_block_state(state)
|
| 416 |
+
self.check_inputs(block_state)
|
| 417 |
+
|
| 418 |
+
device = components._execution_device
|
| 419 |
+
if block_state.prompt_embeds is None:
|
| 420 |
+
block_state.prompt = block_state.prompt or "An explorable world"
|
| 421 |
+
(
|
| 422 |
+
block_state.prompt_embeds,
|
| 423 |
+
block_state.prompt_pad_mask,
|
| 424 |
+
) = self.encode_prompt(components, block_state.prompt, device)
|
| 425 |
+
block_state.prompt_embeds = block_state.prompt_embeds.contiguous()
|
| 426 |
+
|
| 427 |
+
if block_state.prompt_pad_mask is None:
|
| 428 |
+
block_state.prompt_pad_mask = torch.zeros(
|
| 429 |
+
block_state.prompt_embeds.shape[:2],
|
| 430 |
+
dtype=torch.bool,
|
| 431 |
+
device=device,
|
| 432 |
+
)
|
| 433 |
+
|
| 434 |
+
self.set_block_state(state, block_state)
|
| 435 |
+
return components, state
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
# ---------------------------------------------------------------------------
|
| 439 |
+
# Pipeline step: Controller Encoder
|
| 440 |
+
# ---------------------------------------------------------------------------
|
| 441 |
+
|
| 442 |
+
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
|
| 443 |
+
"""Encodes controller inputs (mouse + buttons + scroll) for conditioning."""
|
| 444 |
+
|
| 445 |
+
model_name = "world_engine"
|
| 446 |
+
|
| 447 |
+
@property
|
| 448 |
+
def description(self) -> str:
|
| 449 |
+
return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"
|
| 450 |
+
|
| 451 |
+
@property
|
| 452 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 453 |
+
return [] # Controller embedding is part of transformer
|
| 454 |
+
|
| 455 |
+
@property
|
| 456 |
+
def expected_configs(self) -> list[ConfigSpec]:
|
| 457 |
+
return [ConfigSpec("n_buttons", 256)]
|
| 458 |
+
|
| 459 |
+
@property
|
| 460 |
+
def inputs(self) -> list[InputParam]:
|
| 461 |
+
return [
|
| 462 |
+
InputParam(
|
| 463 |
+
"button",
|
| 464 |
+
type_hint=set[int],
|
| 465 |
+
default=set(),
|
| 466 |
+
description="Set of pressed button IDs",
|
| 467 |
+
),
|
| 468 |
+
InputParam(
|
| 469 |
+
"mouse",
|
| 470 |
+
type_hint=tuple[float, float],
|
| 471 |
+
default=(0.0, 0.0),
|
| 472 |
+
description="Mouse velocity (x, y)",
|
| 473 |
+
),
|
| 474 |
+
InputParam(
|
| 475 |
+
"scroll",
|
| 476 |
+
type_hint=int,
|
| 477 |
+
default=0,
|
| 478 |
+
description="Scroll wheel direction (-1, 0, 1)",
|
| 479 |
+
),
|
| 480 |
+
InputParam(
|
| 481 |
+
"button_tensor",
|
| 482 |
+
type_hint=torch.Tensor,
|
| 483 |
+
kwargs_type="denoiser_input_fields",
|
| 484 |
+
description="One-hot encoded button tensor",
|
| 485 |
+
),
|
| 486 |
+
InputParam(
|
| 487 |
+
"mouse_tensor",
|
| 488 |
+
type_hint=torch.Tensor,
|
| 489 |
+
kwargs_type="denoiser_input_fields",
|
| 490 |
+
description="Mouse velocity tensor",
|
| 491 |
+
),
|
| 492 |
+
InputParam(
|
| 493 |
+
"scroll_tensor",
|
| 494 |
+
type_hint=torch.Tensor,
|
| 495 |
+
kwargs_type="denoiser_input_fields",
|
| 496 |
+
description="Scroll wheel sign tensor",
|
| 497 |
+
),
|
| 498 |
+
]
|
| 499 |
+
|
| 500 |
+
@property
|
| 501 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 502 |
+
return [
|
| 503 |
+
OutputParam(
|
| 504 |
+
"button_tensor",
|
| 505 |
+
type_hint=torch.Tensor,
|
| 506 |
+
kwargs_type="denoiser_input_fields",
|
| 507 |
+
description="One-hot encoded button tensor",
|
| 508 |
+
),
|
| 509 |
+
OutputParam(
|
| 510 |
+
"mouse_tensor",
|
| 511 |
+
type_hint=torch.Tensor,
|
| 512 |
+
kwargs_type="denoiser_input_fields",
|
| 513 |
+
description="Mouse velocity tensor",
|
| 514 |
+
),
|
| 515 |
+
OutputParam(
|
| 516 |
+
"scroll_tensor",
|
| 517 |
+
type_hint=torch.Tensor,
|
| 518 |
+
kwargs_type="denoiser_input_fields",
|
| 519 |
+
description="Scroll wheel sign tensor",
|
| 520 |
+
),
|
| 521 |
+
]
|
| 522 |
+
|
| 523 |
+
@torch.no_grad()
|
| 524 |
+
def __call__(
|
| 525 |
+
self, components: ModularPipeline, state: PipelineState
|
| 526 |
+
) -> PipelineState:
|
| 527 |
+
block_state = self.get_block_state(state)
|
| 528 |
+
device = components._execution_device
|
| 529 |
+
dtype = components.transformer.dtype
|
| 530 |
+
|
| 531 |
+
n_buttons = components.config.n_buttons
|
| 532 |
+
|
| 533 |
+
# Create or reuse button tensor [1, 1, n_buttons]
|
| 534 |
+
if block_state.button_tensor is None:
|
| 535 |
+
block_state.button_tensor = torch.zeros(
|
| 536 |
+
(1, 1, n_buttons), device=device, dtype=dtype
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# Update button tensor in-place (avoid dynamic shapes for torch.compile)
|
| 540 |
+
block_state.button_tensor.zero_()
|
| 541 |
+
if block_state.button:
|
| 542 |
+
for btn_id in block_state.button:
|
| 543 |
+
if 0 <= btn_id < n_buttons:
|
| 544 |
+
block_state.button_tensor[0, 0, btn_id] = 1.0
|
| 545 |
+
|
| 546 |
+
# Create or reuse mouse tensor [1, 1, 2]
|
| 547 |
+
if block_state.mouse_tensor is None:
|
| 548 |
+
block_state.mouse_tensor = torch.zeros(
|
| 549 |
+
(1, 1, 2), device=device, dtype=dtype
|
| 550 |
+
)
|
| 551 |
+
|
| 552 |
+
# Update mouse tensor in-place
|
| 553 |
+
mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
|
| 554 |
+
block_state.mouse_tensor[0, 0, 0] = mouse[0]
|
| 555 |
+
block_state.mouse_tensor[0, 0, 1] = mouse[1]
|
| 556 |
+
|
| 557 |
+
# Create or reuse scroll tensor [1, 1, 1]
|
| 558 |
+
if block_state.scroll_tensor is None:
|
| 559 |
+
block_state.scroll_tensor = torch.zeros(
|
| 560 |
+
(1, 1, 1), device=device, dtype=dtype
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
# Update scroll tensor in-place (sign of scroll value: -1, 0, or 1)
|
| 564 |
+
scroll = block_state.scroll if block_state.scroll is not None else 0
|
| 565 |
+
block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)
|
| 566 |
+
|
| 567 |
+
self.set_block_state(state, block_state)
|
| 568 |
+
return components, state
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
# ---------------------------------------------------------------------------
|
| 572 |
+
# Pipeline step: Set Timesteps
|
| 573 |
+
# ---------------------------------------------------------------------------
|
| 574 |
+
|
| 575 |
+
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
|
| 576 |
+
"""Sets up the scheduler sigmas for rectified flow denoising."""
|
| 577 |
+
|
| 578 |
+
model_name = "world_engine"
|
| 579 |
+
|
| 580 |
+
@property
|
| 581 |
+
def description(self) -> str:
|
| 582 |
+
return "Sets up scheduler sigmas for rectified flow denoising"
|
| 583 |
+
|
| 584 |
+
@property
|
| 585 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 586 |
+
return []
|
| 587 |
+
|
| 588 |
+
@property
|
| 589 |
+
def expected_configs(self) -> list[ConfigSpec]:
|
| 590 |
+
return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]
|
| 591 |
+
|
| 592 |
+
@property
|
| 593 |
+
def inputs(self) -> list[InputParam]:
|
| 594 |
+
return [
|
| 595 |
+
InputParam(
|
| 596 |
+
"scheduler_sigmas",
|
| 597 |
+
type_hint=list[float],
|
| 598 |
+
description="Custom scheduler sigmas (overrides config)",
|
| 599 |
+
),
|
| 600 |
+
InputParam(
|
| 601 |
+
"frame_timestamp",
|
| 602 |
+
type_hint=torch.Tensor,
|
| 603 |
+
description="Current frame timestamp",
|
| 604 |
+
),
|
| 605 |
+
]
|
| 606 |
+
|
| 607 |
+
@property
|
| 608 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 609 |
+
return [
|
| 610 |
+
OutputParam(
|
| 611 |
+
"scheduler_sigmas",
|
| 612 |
+
type_hint=torch.Tensor,
|
| 613 |
+
description="Tensor of scheduler sigmas for denoising",
|
| 614 |
+
),
|
| 615 |
+
OutputParam(
|
| 616 |
+
"frame_timestamp",
|
| 617 |
+
type_hint=torch.Tensor,
|
| 618 |
+
description="Current frame timestamp (unscaled counter)",
|
| 619 |
+
),
|
| 620 |
+
OutputParam(
|
| 621 |
+
"ts_mult",
|
| 622 |
+
type_hint=int,
|
| 623 |
+
description="Timestamp multiplier (base_fps // latent_fps)",
|
| 624 |
+
),
|
| 625 |
+
]
|
| 626 |
+
|
| 627 |
+
@torch.no_grad()
|
| 628 |
+
def __call__(
|
| 629 |
+
self, components: ModularPipeline, state: PipelineState
|
| 630 |
+
) -> PipelineState:
|
| 631 |
+
block_state = self.get_block_state(state)
|
| 632 |
+
device = components._execution_device
|
| 633 |
+
dtype = components.transformer.dtype
|
| 634 |
+
|
| 635 |
+
# Use provided sigmas or get from config
|
| 636 |
+
sigmas = block_state.scheduler_sigmas
|
| 637 |
+
if sigmas is None:
|
| 638 |
+
sigmas = components.config.scheduler_sigmas
|
| 639 |
+
block_state.scheduler_sigmas = torch.tensor(
|
| 640 |
+
sigmas, device=device, dtype=dtype
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
frame_ts = block_state.frame_timestamp
|
| 644 |
+
if frame_ts is None:
|
| 645 |
+
frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
|
| 646 |
+
elif isinstance(frame_ts, int):
|
| 647 |
+
frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)
|
| 648 |
+
|
| 649 |
+
# Compute ts_mult: ratio of base_fps to latent_fps
|
| 650 |
+
t_cfg = components.transformer.config
|
| 651 |
+
base_fps = getattr(t_cfg, "base_fps", 60)
|
| 652 |
+
inference_fps = getattr(t_cfg, "inference_fps", base_fps)
|
| 653 |
+
temporal_compression = getattr(t_cfg, "temporal_compression", 1)
|
| 654 |
+
latent_fps = inference_fps / temporal_compression
|
| 655 |
+
ts_mult = int(base_fps) // int(latent_fps)
|
| 656 |
+
block_state.ts_mult = ts_mult
|
| 657 |
+
block_state.frame_timestamp = frame_ts
|
| 658 |
+
|
| 659 |
+
self.set_block_state(state, block_state)
|
| 660 |
+
return components, state
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
# ---------------------------------------------------------------------------
|
| 664 |
+
# Pipeline step: Setup KV Cache
|
| 665 |
+
# ---------------------------------------------------------------------------
|
| 666 |
+
|
| 667 |
+
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
|
| 668 |
+
"""Initializes or reuses the KV cache for autoregressive generation."""
|
| 669 |
+
|
| 670 |
+
model_name = "world_engine"
|
| 671 |
+
|
| 672 |
+
@property
|
| 673 |
+
def description(self) -> str:
|
| 674 |
+
return "Initializes or reuses KV cache for autoregressive frame generation"
|
| 675 |
+
|
| 676 |
+
@property
|
| 677 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 678 |
+
return []
|
| 679 |
+
|
| 680 |
+
@property
|
| 681 |
+
def inputs(self) -> list[InputParam]:
|
| 682 |
+
return [
|
| 683 |
+
InputParam(
|
| 684 |
+
"kv_cache",
|
| 685 |
+
type_hint=StaticKVCache | None,
|
| 686 |
+
description="Existing KV cache (will be reused if provided)",
|
| 687 |
+
),
|
| 688 |
+
InputParam(
|
| 689 |
+
"reset_cache",
|
| 690 |
+
type_hint=bool,
|
| 691 |
+
default=False,
|
| 692 |
+
description="If True, reset the KV cache even if one exists",
|
| 693 |
+
),
|
| 694 |
+
]
|
| 695 |
+
|
| 696 |
+
@property
|
| 697 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 698 |
+
return [
|
| 699 |
+
OutputParam(
|
| 700 |
+
"kv_cache",
|
| 701 |
+
type_hint=StaticKVCache,
|
| 702 |
+
description="KV cache for transformer attention",
|
| 703 |
+
),
|
| 704 |
+
]
|
| 705 |
+
|
| 706 |
+
@torch.no_grad()
|
| 707 |
+
def __call__(
|
| 708 |
+
self, components: ModularPipeline, state: PipelineState
|
| 709 |
+
) -> PipelineState:
|
| 710 |
+
block_state = self.get_block_state(state)
|
| 711 |
+
device = components._execution_device
|
| 712 |
+
dtype = components.transformer.dtype
|
| 713 |
+
|
| 714 |
+
# Create or reuse KV cache
|
| 715 |
+
if block_state.kv_cache is None:
|
| 716 |
+
block_state.kv_cache = StaticKVCache(
|
| 717 |
+
components.transformer.config,
|
| 718 |
+
batch_size=1,
|
| 719 |
+
dtype=dtype,
|
| 720 |
+
).to(device)
|
| 721 |
+
elif block_state.reset_cache:
|
| 722 |
+
block_state.kv_cache.reset()
|
| 723 |
+
|
| 724 |
+
self.set_block_state(state, block_state)
|
| 725 |
+
return components, state
|
| 726 |
+
|
| 727 |
+
|
| 728 |
+
# ---------------------------------------------------------------------------
|
| 729 |
+
# Pipeline step: Prepare Latents
|
| 730 |
+
# ---------------------------------------------------------------------------
|
| 731 |
+
|
| 732 |
+
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
|
| 733 |
+
"""Prepares latents for frame generation, optionally encoding an input image."""
|
| 734 |
+
|
| 735 |
+
model_name = "world_engine"
|
| 736 |
+
|
| 737 |
+
@property
|
| 738 |
+
def description(self) -> str:
|
| 739 |
+
return (
|
| 740 |
+
"Prepares latents for frame generation. If an image is provided on the "
|
| 741 |
+
"first frame, encodes it and caches it as context. Always creates fresh "
|
| 742 |
+
"random noise for the actual denoising."
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
@property
|
| 746 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 747 |
+
return [
|
| 748 |
+
ComponentSpec(
|
| 749 |
+
"image_processor",
|
| 750 |
+
VaeImageProcessor,
|
| 751 |
+
config=FrozenDict(
|
| 752 |
+
{
|
| 753 |
+
"vae_scale_factor": 16,
|
| 754 |
+
"do_normalize": False,
|
| 755 |
+
"do_convert_rgb": False,
|
| 756 |
+
}
|
| 757 |
+
),
|
| 758 |
+
default_creation_method="from_config",
|
| 759 |
+
),
|
| 760 |
+
]
|
| 761 |
+
|
| 762 |
+
@property
|
| 763 |
+
def expected_configs(self) -> list[ConfigSpec]:
|
| 764 |
+
return [
|
| 765 |
+
ConfigSpec("channels", 16),
|
| 766 |
+
ConfigSpec("height", 16),
|
| 767 |
+
ConfigSpec("width", 16),
|
| 768 |
+
ConfigSpec("patch", [2, 2]),
|
| 769 |
+
ConfigSpec("vae_scale_factor", 16),
|
| 770 |
+
]
|
| 771 |
+
|
| 772 |
+
@property
|
| 773 |
+
def inputs(self) -> list[InputParam]:
|
| 774 |
+
return [
|
| 775 |
+
InputParam(
|
| 776 |
+
"image",
|
| 777 |
+
type_hint=PIL.Image.Image | torch.Tensor,
|
| 778 |
+
description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
|
| 779 |
+
),
|
| 780 |
+
InputParam(
|
| 781 |
+
"latents",
|
| 782 |
+
type_hint=torch.Tensor,
|
| 783 |
+
description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
|
| 784 |
+
),
|
| 785 |
+
InputParam(
|
| 786 |
+
"use_random_latents",
|
| 787 |
+
type_hint=bool,
|
| 788 |
+
default=True,
|
| 789 |
+
description="If True, always generate fresh random latents. If False, use provided latents.",
|
| 790 |
+
),
|
| 791 |
+
InputParam(
|
| 792 |
+
"kv_cache",
|
| 793 |
+
description="KV cache to update",
|
| 794 |
+
),
|
| 795 |
+
InputParam(
|
| 796 |
+
"frame_timestamp",
|
| 797 |
+
type_hint=torch.Tensor,
|
| 798 |
+
description="Current frame timestamp",
|
| 799 |
+
),
|
| 800 |
+
InputParam(
|
| 801 |
+
"prompt_embeds",
|
| 802 |
+
type_hint=torch.Tensor,
|
| 803 |
+
description="Prompt embeddings for cache pass",
|
| 804 |
+
),
|
| 805 |
+
InputParam(
|
| 806 |
+
"prompt_pad_mask",
|
| 807 |
+
type_hint=torch.Tensor,
|
| 808 |
+
description="Prompt padding mask",
|
| 809 |
+
),
|
| 810 |
+
InputParam(
|
| 811 |
+
"button_tensor",
|
| 812 |
+
type_hint=torch.Tensor,
|
| 813 |
+
description="Button tensor for cache pass",
|
| 814 |
+
),
|
| 815 |
+
InputParam(
|
| 816 |
+
"mouse_tensor",
|
| 817 |
+
type_hint=torch.Tensor,
|
| 818 |
+
description="Mouse tensor for cache pass",
|
| 819 |
+
),
|
| 820 |
+
InputParam(
|
| 821 |
+
"scroll_tensor",
|
| 822 |
+
type_hint=torch.Tensor,
|
| 823 |
+
description="Scroll tensor for cache pass",
|
| 824 |
+
),
|
| 825 |
+
InputParam(
|
| 826 |
+
"generator",
|
| 827 |
+
type_hint=torch.Generator,
|
| 828 |
+
default=None,
|
| 829 |
+
description="torch Generator for deterministic output",
|
| 830 |
+
),
|
| 831 |
+
InputParam(
|
| 832 |
+
"ts_mult",
|
| 833 |
+
required=True,
|
| 834 |
+
type_hint=int,
|
| 835 |
+
description="Timestamp multiplier (base_fps // latent_fps)",
|
| 836 |
+
),
|
| 837 |
+
]
|
| 838 |
+
|
| 839 |
+
@property
|
| 840 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 841 |
+
return [
|
| 842 |
+
OutputParam(
|
| 843 |
+
"latents",
|
| 844 |
+
type_hint=torch.Tensor,
|
| 845 |
+
description="Latent tensor for denoising [1, 1, C, H, W]",
|
| 846 |
+
),
|
| 847 |
+
]
|
| 848 |
+
|
| 849 |
+
@staticmethod
|
| 850 |
+
def _cache_pass(
|
| 851 |
+
transformer,
|
| 852 |
+
x,
|
| 853 |
+
frame_timestamp,
|
| 854 |
+
frame_idx,
|
| 855 |
+
prompt_emb,
|
| 856 |
+
prompt_pad_mask,
|
| 857 |
+
mouse,
|
| 858 |
+
button,
|
| 859 |
+
scroll,
|
| 860 |
+
kv_cache,
|
| 861 |
+
):
|
| 862 |
+
"""Cache pass to persist frame in KV cache."""
|
| 863 |
+
kv_cache.set_frozen(False)
|
| 864 |
+
transformer(
|
| 865 |
+
x=x,
|
| 866 |
+
sigma=x.new_zeros((x.size(0), x.size(1))),
|
| 867 |
+
frame_timestamp=frame_timestamp,
|
| 868 |
+
frame_idx=frame_idx,
|
| 869 |
+
prompt_emb=prompt_emb,
|
| 870 |
+
prompt_pad_mask=prompt_pad_mask,
|
| 871 |
+
mouse=mouse,
|
| 872 |
+
button=button,
|
| 873 |
+
scroll=scroll,
|
| 874 |
+
kv_cache=kv_cache,
|
| 875 |
+
)
|
| 876 |
+
|
| 877 |
+
@torch.inference_mode()
|
| 878 |
+
def __call__(
|
| 879 |
+
self, components: ModularPipeline, state: PipelineState
|
| 880 |
+
) -> PipelineState:
|
| 881 |
+
block_state = self.get_block_state(state)
|
| 882 |
+
device = components._execution_device
|
| 883 |
+
dtype = components.transformer.dtype
|
| 884 |
+
|
| 885 |
+
# Get latent shape info
|
| 886 |
+
channels = components.config.channels
|
| 887 |
+
height = components.config.height # patch grid height
|
| 888 |
+
width = components.config.width # patch grid width
|
| 889 |
+
patch = components.config.patch
|
| 890 |
+
vae_scale_factor = components.config.vae_scale_factor
|
| 891 |
+
|
| 892 |
+
pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
|
| 893 |
+
latent_H = height * pH
|
| 894 |
+
latent_W = width * pW
|
| 895 |
+
shape = (1, 1, channels, latent_H, latent_W)
|
| 896 |
+
|
| 897 |
+
# Pixel dimensions for image preprocessing
|
| 898 |
+
pixel_H = latent_H * vae_scale_factor
|
| 899 |
+
pixel_W = latent_W * vae_scale_factor
|
| 900 |
+
|
| 901 |
+
if block_state.image is not None:
|
| 902 |
+
image = block_state.image
|
| 903 |
+
# Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
|
| 904 |
+
image = components.image_processor.preprocess(
|
| 905 |
+
image,
|
| 906 |
+
height=pixel_H,
|
| 907 |
+
width=pixel_W,
|
| 908 |
+
)
|
| 909 |
+
# Convert to [H, W, 3] uint8 for VAE encoder
|
| 910 |
+
image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)
|
| 911 |
+
|
| 912 |
+
assert image.dtype == torch.uint8, (
|
| 913 |
+
f"Expected uint8 image, got {image.dtype}"
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
# Temporal VAE expects [T, H, W, C]; repeat the single frame
|
| 917 |
+
t_down = getattr(components.vae, "t_downscale", 1)
|
| 918 |
+
if t_down > 1:
|
| 919 |
+
image = image.unsqueeze(0).expand(t_down, -1, -1, -1)
|
| 920 |
+
|
| 921 |
+
latents = components.vae.encode(image)
|
| 922 |
+
latents = latents.unsqueeze(1)
|
| 923 |
+
|
| 924 |
+
# Run cache pass to persist encoded frame
|
| 925 |
+
ts_mult = block_state.ts_mult
|
| 926 |
+
self._cache_pass(
|
| 927 |
+
components.transformer,
|
| 928 |
+
latents,
|
| 929 |
+
block_state.frame_timestamp * ts_mult,
|
| 930 |
+
block_state.frame_timestamp,
|
| 931 |
+
block_state.prompt_embeds,
|
| 932 |
+
block_state.prompt_pad_mask,
|
| 933 |
+
block_state.mouse_tensor,
|
| 934 |
+
block_state.button_tensor,
|
| 935 |
+
block_state.scroll_tensor,
|
| 936 |
+
block_state.kv_cache,
|
| 937 |
+
)
|
| 938 |
+
block_state.frame_timestamp.add_(1)
|
| 939 |
+
|
| 940 |
+
# Generate latents based on use_random_latents flag
|
| 941 |
+
if block_state.use_random_latents or block_state.latents is None:
|
| 942 |
+
block_state.latents = torch.randn(
|
| 943 |
+
shape, device=device, dtype=torch.bfloat16
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
self.set_block_state(state, block_state)
|
| 947 |
+
return components, state
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
# ---------------------------------------------------------------------------
|
| 951 |
+
# Pipeline step: Before Denoise (sequential wrapper)
|
| 952 |
+
# ---------------------------------------------------------------------------
|
| 953 |
+
|
| 954 |
+
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
|
| 955 |
+
"""Sequential pipeline that prepares all inputs for denoising."""
|
| 956 |
+
|
| 957 |
+
block_classes = [
|
| 958 |
+
WorldEngineSetTimestepsStep,
|
| 959 |
+
WorldEngineSetupKVCacheStep,
|
| 960 |
+
WorldEnginePrepareLatentsStep,
|
| 961 |
+
]
|
| 962 |
+
block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]
|
| 963 |
+
|
| 964 |
+
@property
|
| 965 |
+
def description(self) -> str:
|
| 966 |
+
return (
|
| 967 |
+
"Before denoise step that prepares inputs for denoising:\n"
|
| 968 |
+
" - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
|
| 969 |
+
" - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
|
| 970 |
+
" - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
|
| 971 |
+
)
|
| 972 |
+
|
| 973 |
+
|
| 974 |
+
# ---------------------------------------------------------------------------
|
| 975 |
+
# Pipeline step: Denoise Loop
|
| 976 |
+
# ---------------------------------------------------------------------------
|
| 977 |
+
|
| 978 |
+
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
|
| 979 |
+
"""Denoises latents using rectified flow and updates KV cache."""
|
| 980 |
+
|
| 981 |
+
model_name = "world_engine"
|
| 982 |
+
|
| 983 |
+
@property
|
| 984 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 985 |
+
return [ComponentSpec("transformer", AutoModel)]
|
| 986 |
+
|
| 987 |
+
@property
|
| 988 |
+
def description(self) -> str:
|
| 989 |
+
return (
|
| 990 |
+
"Denoises latents using rectified flow (x = x + dsigma * v) "
|
| 991 |
+
"and updates KV cache for autoregressive generation."
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
@property
|
| 995 |
+
def inputs(self) -> list[InputParam]:
|
| 996 |
+
return [
|
| 997 |
+
InputParam("scheduler_sigmas", required=True, type_hint=torch.Tensor, description="Scheduler sigmas for denoising"),
|
| 998 |
+
InputParam("latents", required=True, type_hint=torch.Tensor, description="Initial noisy latents [1, 1, C, H, W]"),
|
| 999 |
+
InputParam("kv_cache", required=True, description="KV cache for transformer attention"),
|
| 1000 |
+
InputParam("frame_timestamp", required=True, type_hint=torch.Tensor, description="Current frame timestamp"),
|
| 1001 |
+
InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Text embeddings for conditioning"),
|
| 1002 |
+
InputParam("prompt_pad_mask", type_hint=torch.Tensor, description="Padding mask for prompt embeddings"),
|
| 1003 |
+
InputParam("button_tensor", required=True, type_hint=torch.Tensor, description="One-hot encoded button tensor"),
|
| 1004 |
+
InputParam("mouse_tensor", required=True, type_hint=torch.Tensor, description="Mouse velocity tensor"),
|
| 1005 |
+
InputParam("scroll_tensor", required=True, type_hint=torch.Tensor, description="Scroll wheel sign tensor"),
|
| 1006 |
+
InputParam("ts_mult", required=True, type_hint=int, description="Timestamp multiplier (base_fps // latent_fps)"),
|
| 1007 |
+
]
|
| 1008 |
+
|
| 1009 |
+
@property
|
| 1010 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 1011 |
+
return [
|
| 1012 |
+
OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
|
| 1013 |
+
]
|
| 1014 |
+
|
| 1015 |
+
@staticmethod
|
| 1016 |
+
def _denoise_pass(
|
| 1017 |
+
transformer, x, sigmas, frame_timestamp, frame_idx,
|
| 1018 |
+
prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
|
| 1019 |
+
):
|
| 1020 |
+
"""Denoising loop using rectified flow."""
|
| 1021 |
+
kv_cache.set_frozen(True)
|
| 1022 |
+
sigma = x.new_empty((x.size(0), x.size(1)))
|
| 1023 |
+
for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
|
| 1024 |
+
v = transformer(
|
| 1025 |
+
x=x, sigma=sigma.fill_(step_sig),
|
| 1026 |
+
frame_timestamp=frame_timestamp, frame_idx=frame_idx,
|
| 1027 |
+
prompt_emb=prompt_emb, prompt_pad_mask=prompt_pad_mask,
|
| 1028 |
+
mouse=mouse, button=button, scroll=scroll,
|
| 1029 |
+
kv_cache=kv_cache,
|
| 1030 |
+
)
|
| 1031 |
+
x = x + step_dsig * v
|
| 1032 |
+
return x
|
| 1033 |
+
|
| 1034 |
+
@staticmethod
|
| 1035 |
+
def _cache_pass(
|
| 1036 |
+
transformer, x, frame_timestamp, frame_idx,
|
| 1037 |
+
prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
|
| 1038 |
+
):
|
| 1039 |
+
"""Cache pass to persist frame for next generation."""
|
| 1040 |
+
kv_cache.set_frozen(False)
|
| 1041 |
+
transformer(
|
| 1042 |
+
x=x, sigma=x.new_zeros((x.size(0), x.size(1))),
|
| 1043 |
+
frame_timestamp=frame_timestamp, frame_idx=frame_idx,
|
| 1044 |
+
prompt_emb=prompt_emb, prompt_pad_mask=prompt_pad_mask,
|
| 1045 |
+
mouse=mouse, button=button, scroll=scroll,
|
| 1046 |
+
kv_cache=kv_cache,
|
| 1047 |
+
)
|
| 1048 |
+
|
| 1049 |
+
@torch.inference_mode()
|
| 1050 |
+
def __call__(
|
| 1051 |
+
self, components: ModularPipeline, state: PipelineState
|
| 1052 |
+
) -> PipelineState:
|
| 1053 |
+
block_state = self.get_block_state(state)
|
| 1054 |
+
ts_mult = block_state.ts_mult
|
| 1055 |
+
block_state.latents = self._denoise_pass(
|
| 1056 |
+
components.transformer,
|
| 1057 |
+
block_state.latents,
|
| 1058 |
+
block_state.scheduler_sigmas,
|
| 1059 |
+
block_state.frame_timestamp * ts_mult,
|
| 1060 |
+
block_state.frame_timestamp,
|
| 1061 |
+
block_state.prompt_embeds,
|
| 1062 |
+
block_state.prompt_pad_mask,
|
| 1063 |
+
block_state.mouse_tensor,
|
| 1064 |
+
block_state.button_tensor,
|
| 1065 |
+
block_state.scroll_tensor,
|
| 1066 |
+
block_state.kv_cache,
|
| 1067 |
+
).clone()
|
| 1068 |
+
|
| 1069 |
+
self._cache_pass(
|
| 1070 |
+
components.transformer,
|
| 1071 |
+
block_state.latents,
|
| 1072 |
+
block_state.frame_timestamp * ts_mult,
|
| 1073 |
+
block_state.frame_timestamp,
|
| 1074 |
+
block_state.prompt_embeds,
|
| 1075 |
+
block_state.prompt_pad_mask,
|
| 1076 |
+
block_state.mouse_tensor,
|
| 1077 |
+
block_state.button_tensor,
|
| 1078 |
+
block_state.scroll_tensor,
|
| 1079 |
+
block_state.kv_cache,
|
| 1080 |
+
)
|
| 1081 |
+
block_state.frame_timestamp.add_(1)
|
| 1082 |
+
|
| 1083 |
+
self.set_block_state(state, block_state)
|
| 1084 |
+
return components, state
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
# ---------------------------------------------------------------------------
|
| 1088 |
+
# Pipeline step: Decode
|
| 1089 |
+
# ---------------------------------------------------------------------------
|
| 1090 |
+
|
| 1091 |
+
class WorldEngineDecodeStep(ModularPipelineBlocks):
|
| 1092 |
+
"""Decodes denoised latents back to RGB image using VAE."""
|
| 1093 |
+
|
| 1094 |
+
model_name = "world_engine"
|
| 1095 |
+
|
| 1096 |
+
@property
|
| 1097 |
+
def expected_components(self) -> list[ComponentSpec]:
|
| 1098 |
+
return [
|
| 1099 |
+
ComponentSpec("vae", AutoModel),
|
| 1100 |
+
ComponentSpec(
|
| 1101 |
+
"image_processor",
|
| 1102 |
+
VaeImageProcessor,
|
| 1103 |
+
config=FrozenDict(
|
| 1104 |
+
{
|
| 1105 |
+
"vae_scale_factor": 16,
|
| 1106 |
+
"do_normalize": False,
|
| 1107 |
+
"do_convert_rgb": True,
|
| 1108 |
+
}
|
| 1109 |
+
),
|
| 1110 |
+
default_creation_method="from_config",
|
| 1111 |
+
),
|
| 1112 |
+
]
|
| 1113 |
+
|
| 1114 |
+
@property
|
| 1115 |
+
def description(self) -> str:
|
| 1116 |
+
return "Decodes denoised latents to RGB image using the VAE decoder"
|
| 1117 |
+
|
| 1118 |
+
@property
|
| 1119 |
+
def inputs(self) -> list[InputParam]:
|
| 1120 |
+
return [
|
| 1121 |
+
InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised latent tensor [1, 1, C, H, W]"),
|
| 1122 |
+
InputParam("output_type", default="pil", description="The output format for the generated images (pil, latent, pt, or np)"),
|
| 1123 |
+
]
|
| 1124 |
+
|
| 1125 |
+
@property
|
| 1126 |
+
def intermediate_outputs(self) -> list[OutputParam]:
|
| 1127 |
+
return [
|
| 1128 |
+
OutputParam(
|
| 1129 |
+
"images",
|
| 1130 |
+
type_hint=PIL.Image.Image | torch.Tensor | np.ndarray,
|
| 1131 |
+
description="Decoded RGB image in requested output format",
|
| 1132 |
+
),
|
| 1133 |
+
]
|
| 1134 |
+
|
| 1135 |
+
@torch.no_grad()
|
| 1136 |
+
def __call__(
|
| 1137 |
+
self, components: ModularPipeline, state: PipelineState
|
| 1138 |
+
) -> PipelineState:
|
| 1139 |
+
block_state = self.get_block_state(state)
|
| 1140 |
+
latents = block_state.latents
|
| 1141 |
+
output_type = block_state.output_type or "pil"
|
| 1142 |
+
|
| 1143 |
+
if output_type == "latent":
|
| 1144 |
+
block_state.images = latents
|
| 1145 |
+
else:
|
| 1146 |
+
# Decode to image
|
| 1147 |
+
frames = components.vae.decode(latents.squeeze(1))
|
| 1148 |
+
|
| 1149 |
+
if frames.dim() == 3:
|
| 1150 |
+
# Single frame [H, W, C] — wrap so the loop below works uniformly
|
| 1151 |
+
frames = frames.unsqueeze(0)
|
| 1152 |
+
|
| 1153 |
+
# Postprocess based on output_type
|
| 1154 |
+
if output_type == "pt":
|
| 1155 |
+
block_state.images = frames
|
| 1156 |
+
elif output_type == "np":
|
| 1157 |
+
block_state.images = frames.cpu().numpy()
|
| 1158 |
+
else: # "pil"
|
| 1159 |
+
block_state.images = [
|
| 1160 |
+
PIL.Image.fromarray(f.cpu().numpy()) for f in frames
|
| 1161 |
+
]
|
| 1162 |
+
|
| 1163 |
+
# Clear latents so next frame generates fresh random noise
|
| 1164 |
+
block_state.latents = None
|
| 1165 |
+
self.set_block_state(state, block_state)
|
| 1166 |
+
return components, state
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
# ---------------------------------------------------------------------------
|
| 1170 |
+
# Top-level block registry
|
| 1171 |
+
# ---------------------------------------------------------------------------
|
| 1172 |
+
|
| 1173 |
+
AUTO_BLOCKS = InsertableDict(
|
| 1174 |
+
[
|
| 1175 |
+
("text_encoder", WorldEngineTextEncoderStep),
|
| 1176 |
+
("controller_encoder", WorldEngineControllerEncoderStep),
|
| 1177 |
+
("before_denoise", WorldEngineBeforeDenoiseStep),
|
| 1178 |
+
("denoise", WorldEngineDenoiseLoop),
|
| 1179 |
+
("decode", WorldEngineDecodeStep),
|
| 1180 |
+
]
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
|
| 1184 |
+
class WorldEngineBlocks(SequentialPipelineBlocks):
|
| 1185 |
+
"""Sequential pipeline blocks for WorldEngine frame generation."""
|
| 1186 |
+
|
| 1187 |
+
block_classes = list(AUTO_BLOCKS.values())
|
| 1188 |
+
block_names = list(AUTO_BLOCKS.keys())
|
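`WorldEngineBlocks` above is the entry point that `modular_config.json` (next file) maps the `ModularPipelineBlocks` auto class to. A minimal sketch of assembling it by hand, assuming the standard Modular Diffusers `init_pipeline` helper (the exact signature may differ between diffusers versions):

```python
from modular_blocks import WorldEngineBlocks

# Five sequential steps: text_encoder -> controller_encoder -> before_denoise
# -> denoise -> decode; component specs come from modular_model_index.json.
blocks = WorldEngineBlocks()
pipe = blocks.init_pipeline("Overworld/Waypoint-1.5-1B")
```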
modular_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_class_name": "WorldEngineBlocks",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
    "ModularPipelineBlocks": "modular_blocks.WorldEngineBlocks"
  }
}
modular_model_index.json
ADDED
@@ -0,0 +1,76 @@
{
  "_blocks_class_name": "WorldEngineBlocks",
  "_class_name": "ModularPipeline",
  "_diffusers_version": "0.36.0.dev0",
  "channels": 32,
  "height": 16,
  "width": 32,
  "patch": [
    2,
    2
  ],
  "vae_scale_factor": 16,
  "n_buttons": 256,
  "tokens_per_frame": 512,
  "scheduler_sigmas": [
    1.0,
    0.9,
    0.75,
    0.3,
    0.0
  ],
  "transformer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "Overworld/Waypoint-1.5-1B",
      "subfolder": "transformer",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "vae": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "Overworld/Waypoint-1.5-1B",
      "subfolder": "vae",
      "type_hint": [
        "diffusers",
        "AutoModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "text_encoder": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "google/umt5-xl",
      "type_hint": [
        "transformers",
        "UMT5EncoderModel"
      ],
      "revision": null,
      "variant": null
    }
  ],
  "tokenizer": [
    null,
    null,
    {
      "pretrained_model_name_or_path": "google/umt5-xl",
      "type_hint": [
        "transformers",
        "AutoTokenizer"
      ],
      "revision": null,
      "variant": null
    }
  ]
}
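Each component entry above is a triple whose last element is the loading spec (library, class, repo, optional subfolder); the two leading `null`s are unused placeholders here. A sketch of resolving two entries by hand, using only values taken verbatim from this index:

```python
from diffusers import AutoModel
from transformers import AutoTokenizer, UMT5EncoderModel

# "transformer": diffusers.AutoModel, custom code in this repo's subfolder
transformer = AutoModel.from_pretrained(
    "Overworld/Waypoint-1.5-1B", subfolder="transformer", trust_remote_code=True
)
# "text_encoder" / "tokenizer": stock UMT5 from google/umt5-xl
text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xl")
tokenizer = AutoTokenizer.from_pretrained("google/umt5-xl")
```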
transformer/__init__.py
ADDED
@@ -0,0 +1,16 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from .model import WorldModel
transformer/config.json
ADDED
@@ -0,0 +1,59 @@
{
  "_class_name": "WorldModel",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
    "AutoModel": "model.WorldModel"
  },
  "d_model": 2048,
  "n_heads": 32,
  "n_kv_heads": 16,
  "n_layers": 24,
  "mlp_ratio": 4,
  "channels": 32,
  "height": 16,
  "width": 32,
  "patch": [
    2,
    2
  ],
  "tokens_per_frame": 512,
  "n_frames": 512,
  "local_window": 16,
  "global_window": 128,
  "global_attn_period": 4,
  "global_pinned_dilation": 8,
  "global_attn_offset": -1,
  "value_residual": true,
  "gated_attn": false,
  "n_buttons": 256,
  "ctrl_conditioning": true,
  "ctrl_conditioning_period": 3,
  "ctrl_cond_dropout": 0.0,
  "prompt_conditioning": null,
  "prompt_conditioning_period": 3,
  "prompt_embedding_dim": 2048,
  "prompt_cond_dropout": 0.0,
  "noise_conditioning": "wan",
  "base_fps": 15,
  "causal": true,
  "mlp_gradient_checkpointing": true,
  "block_gradient_checkpointing": true,
  "rope_impl": "ortho",
  "moe": false,
  "moe_top_k": 2,
  "moe_n_experts": 8,
  "moe_mlp_ratio": null,
  "gated_linear": false,
  "temporal_compression": 4,
  "inference_fps": 60,
  "taehv_ae": true,
  "rope_nyquist_frac": 0.8,
  "rope_theta": 10000.0,
  "scheduler_sigmas": [
    1.0,
    0.9,
    0.75,
    0.3,
    0.0
  ]
}
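Reading the geometry fields together (an inference from the values above and the `vae_scale_factor` in `modular_model_index.json`, not something this file documents): `height` and `width` describe the per-frame token grid after the `(2, 2)` patchify, which makes the sizes line up:

```python
height, width = 16, 32           # token grid per frame
patch = (2, 2)                   # patchify factor -> 32 x 64 latent grid
vae_scale_factor = 16            # from modular_model_index.json
assert height * width == 512     # tokens_per_frame above
assert height * patch[0] * vae_scale_factor == 512   # pixel frame height
assert width * patch[1] * vae_scale_factor == 1024   # pixel frame width
```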
transformer/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:467ea80d76f63a82a2f5ba80d1170c27dc1060e60c3ede341105bff89b468700
size 7443340184
transformer/model.py
ADDED
@@ -0,0 +1,1100 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""WorldModel transformer for frame generation.

Single-file model containing all building blocks: nn primitives, attention,
RoPE, quantization, inference caching, and the top-level WorldModel.
"""

import warnings

import einops as eo
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from tensordict import TensorDict
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

try:
    from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling
    import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter  # noqa
    HAS_FBGEMM = True
except ImportError:
    HAS_FBGEMM = False

# ---------------------------------------------------------------------------
# NN primitives
# ---------------------------------------------------------------------------

class NoCastModule(torch.nn.Module):
    """Module that prevents dtype casting during .to() calls."""

    def _apply(self, fn):
        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype)

    def to(self, *args, **kwargs):
        warn_cast = False

        if args and isinstance(args[0], torch.Tensor):
            ref, *rest = args
            args = (ref.device, *rest)
            # `or` on a multi-element Parameter would call Tensor.__bool__ and
            # raise, so fall back to buffers explicitly.
            base = next(self.parameters(), None)
            if base is None:
                base = next(self.buffers(), None)
            if base is not None and ref.dtype is not base.dtype:
                warn_cast = True

        if kwargs.pop("dtype", None) is not None:
            warn_cast = True

        args = tuple(a for a in args if not isinstance(a, torch.dtype))

        if warn_cast:
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )

        return super().to(*args, **kwargs)


def rms_norm(x: torch.Tensor) -> torch.Tensor:
    """Root mean square layer normalization."""
    return F.rms_norm(x, (x.size(-1),))


class MLP(nn.Module):
    """Simple MLP with SiLU activation."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        return self.fc2(F.silu(self.fc1(x)))


class AdaLN(nn.Module):
    """Adaptive Layer Normalization."""

    def __init__(self, dim):
        super().__init__()
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        b, n, d = cond.shape
        _, nm, _ = x.shape
        m = nm // n

        y = F.silu(cond)
        ab = self.fc(y)                # [b, n, 2d]
        ab = ab.view(b, n, 1, 2 * d)   # [b, n, 1, 2d]
        ab = ab.expand(-1, -1, m, -1)  # [b, n, m, 2d]
        ab = ab.reshape(b, nm, 2 * d)  # [b, nm, 2d]

        a, b_ = ab.chunk(2, dim=-1)    # [b, nm, d] each
        x = rms_norm(x) * (1 + a) + b_
        return x


def ada_rmsnorm(x, scale, bias):
    """Adaptive RMS normalization with scale and bias."""
    x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=scale.size(1))
    y4 = rms_norm(x4) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
    return eo.rearrange(y4, "b n m d -> b (n m) d")


def ada_gate(x, gate):
    """Apply gating to x with per-frame gates."""
    x4 = eo.rearrange(x, "b (n m) d -> b n m d", n=gate.size(1))
    return eo.rearrange(x4 * gate.unsqueeze(2), "b n m d -> b (n m) d")


class NoiseConditioner(NoCastModule):
    """Sigma -> logSNR -> Fourier Features -> Dense embedding."""

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        half = fourier_dim // 2
        self.freq = nn.Buffer(
            torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
            persistent=False,
        )
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        assert self.freq.dtype == torch.float32
        orig_dtype, shape = s.dtype, s.shape

        with torch.autocast("cuda", enabled=False):
            s = s.reshape(-1).float()
            s = s * 1000

            phase = s[:, None] * self.freq[None, :]
            emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
            emb = emb * 2**0.5
            emb = self.mlp(emb)

        return emb.to(orig_dtype).view(*shape, -1)

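# Shape sketch for the per-frame helpers above (sizes illustrative, not from
# this file): tokens arrive as x of shape [B, N*M, D] for N frames of M tokens
# each, while scale/bias/gate are per-frame [B, N, D]. ada_rmsnorm and
# ada_gate broadcast each frame's modulation across its M tokens, e.g.:
#
#   x = torch.randn(1, 2 * 512, 2048)       # 2 frames x 512 tokens
#   scale = bias = torch.zeros(1, 2, 2048)  # per-frame conditioning
#   y = ada_rmsnorm(x, scale, bias)         # zero scale/bias -> plain rms_norm
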
# ---------------------------------------------------------------------------
# Attention
# ---------------------------------------------------------------------------

class OrthoRoPEAngles(NoCastModule):
    """Computes RoPE angles on the fly each forward pass."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        d_head = config.d_model // config.n_heads
        torch._assert(d_head % 8 == 0, "d_head must be divisible by 8")
        d_xy, d_t = d_head // 8, d_head // 4

        nyq = float(getattr(config, "rope_nyquist_frac", 0.8))
        max_freq = min(self.config.height, self.config.width) * nyq
        n = (d_xy + 1) // 2
        xy = (torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy]

        theta = float(getattr(config, "rope_theta", 10000.0))
        inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t))
        inv_t = inv_t.repeat_interleave(2)

        self.register_buffer("xy", xy, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        if not torch.compiler.is_compiling():
            torch._assert(
                (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width),
                f"pos_ids out of bounds, {self.config.height}, {self.config.width}",
            )

        x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0
        y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0
        t = pos_ids["t_pos"].float()

        freqs = torch.cat(
            (x.unsqueeze(-1) * self.xy, y.unsqueeze(-1) * self.xy, t.unsqueeze(-1) * self.inv_t),
            dim=-1,
        )
        return freqs.cos()[:, None], freqs.sin()[:, None]


class OrthoRoPE(NoCastModule):
    """Applies precomputed RoPE angles to input tensors."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        cos, sin = rope_angles
        x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
        y0 = x0 * cos - x1 * sin
        y1 = x1 * cos + x0 * sin
        return torch.cat((y0, y1), dim=-1).type_as(x)


class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0

        self.enable_gqa = self.n_heads != self.n_kv_heads

        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.rope = OrthoRoPE(config)

        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(
                self.n_heads, self.n_heads, bias=False
            )
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1


class MergedQKVAttn(Attn):
    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        self.load_state_dict(
            src.state_dict(), strict=False
        )
        self.train(src.training)

        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head

        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )

        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        from torch.nn.attention.flex_attention import flex_attention

        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)

        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.value_residual:
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1


class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning."""

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_head = config.d_model // config.n_heads
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.v_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )

        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        self.out_proj.weight.detach().zero_()

    def forward(self, x, context, context_pad_mask=None):
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        q, k = rms_norm(q), rms_norm(k)
        out = flex_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)

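# Cache-aware attention note: Attn defers all masking to the cache layer.
# kv_cache.upsert() hands back the keys/values to attend over plus a
# flex_attention BlockMask, so the same block serves frozen denoise passes and
# cache-advancing passes (see WorldEngineDenoiseLoop in modular_blocks.py).
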
# ---------------------------------------------------------------------------
# Inference caching
# ---------------------------------------------------------------------------

def _bf16_u16(x: Tensor) -> Tensor:
    return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF


class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT."""

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        levels = torch.tensor(sigmas, device=device, dtype=torch.bfloat16)
        bits = _bf16_u16(levels)
        if torch.unique(bits).numel() != bits.numel():
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        with torch.no_grad():
            table = (
                base(levels[:, None]).squeeze(1).to(torch.bfloat16).contiguous()
            )

        lut = torch.full((65536,), -1, device=device, dtype=torch.int32)
        lut[bits] = torch.arange(bits.numel(), device=device, dtype=torch.int32)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(bits.numel(), device=device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        idx = self.lut[_bf16_u16(sigma)]
        idx = torch.where(idx >= 0, idx, self.oob)
        return self.table[idx.to(torch.int64)]


class CachedCondHead(nn.Module):
    """bf16 cond -> cached conditioning; invalid cond => OOB index error."""

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        S, D = table.shape

        with torch.no_grad():
            emb = table[:, None, :]
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )

        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )

        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)
        g = self.cache[:, idx.to(torch.int64)]
        return tuple(g.unbind(0))

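# LUT sketch for the two caches above: bf16 has 16 bits, so _bf16_u16 maps any
# bf16 sigma to an index into a 65536-entry lookup table. Only the fixed
# scheduler_sigmas rows are populated; any other sigma resolves to the
# out-of-bounds row and fails with an index error. E.g. (hypothetical name):
#
#   sigma = torch.tensor([0.75], dtype=torch.bfloat16)  # a scheduled level
#   emb = cached_emb(sigma)  # pure table lookup, no MLP at inference time
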
# ---------------------------------------------------------------------------
# Quantization
# ---------------------------------------------------------------------------

QUANTS = [None]

try:
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
    QUANTS.append("nvfp4")
except ImportError:
    pass


@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
    )


@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor, b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor, b_sf_T: torch.Tensor, alpha: torch.Tensor,
) -> torch.Tensor:
    return torch.empty(
        (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
    )


class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization."""

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0

        self.weight = nn.Parameter(lin.weight.detach().clone())
        self._weight_fp4_T = None
        self._weight_scales_T = None
        self._alpha = None
        self._dummy_scale = None
        self._weight_global_sf = None

        with torch.no_grad():
            self._dummy_scale = torch.full((1,), 1.0, device=self.weight.device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            self._weight_global_sf = 1.0 / weight_amax
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16, self._weight_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        assert self.weight.is_cuda
        lazy_x = torch.zeros((1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16)
        fp4_linear(lazy_x, self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))


class FP8W8A8Linear(nn.Module):
    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        f8 = torch.float8_e4m3fn
        inv = 1.0 / float(torch.finfo(f8).max)
        self._inv = inv
        w = lin.weight.detach()
        ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
        wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()
        self.register_buffer("wT", wf8.t())
        self.register_buffer("ws", ws)
        if lin.bias is None:
            self.bias = None
        else:
            self.register_buffer("bias", lin.bias.detach().to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.shape
        x2 = x.reshape(-1, s[-1])
        xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
        xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
        y = torch._scaled_mm(
            xf8, self.wT, xs, self.ws,
            bias=self.bias, out_dtype=torch.float16, use_fast_accum=True,
        )
        return y.reshape(*s[:-1], self.out_features).to(x.dtype)


class FP8Linear(nn.Module):
    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        self.bias = (
            nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
            if lin.bias is not None else None
        )
        w_amax = lin.weight.data.abs().amax()
        w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
        self.register_buffer("w_amax", w_amax)
        self.register_buffer("weightT", w.t())
        self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
        result = torch._scaled_mm(
            x_fp8, self.weightT,
            bias=self.bias, scale_a=self.dummy_scale, scale_b=self.w_amax,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )
        return result.reshape(x.shape[:-1] + (-1,))


def quantize_model(model: nn.Module, quant: str):
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        w = getattr(m, "weight", None)
        if not isinstance(m, nn.Linear):
            return False
        if getattr(w, "dtype", None) != torch.bfloat16:
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]

    for name, child in model.named_children():
        if eligible(child):
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model

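# Usage sketch (hypothetical call site; needs CUDA, and flashinfer for "nvfp4"):
#
#   model = quantize_model(model, "fp8")
#
# Eligible bf16 nn.Linear layers (both dims divisible by 32) are swapped in
# place; every other module is recursed into and left untouched.
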
# ---------------------------------------------------------------------------
# Inference patches
# ---------------------------------------------------------------------------

def patch_cached_noise_conditioning(model) -> None:
    cached_denoise_step_emb = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached_denoise_step_emb
    for blk in model.transformer.blocks:
        blk.attn_cond_head = CachedCondHead(blk.attn_cond_head, cached_denoise_step_emb)
        blk.mlp_cond_head = CachedCondHead(blk.mlp_cond_head, cached_denoise_step_emb)


def patch_Attn_merge_qkv(model) -> None:
    for name, mod in list(model.named_modules()):
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn):
            model.set_submodule(name, MergedQKVAttn(mod, model.config))


def _apply_inference_patches(model) -> None:
    patch_cached_noise_conditioning(model)
    patch_Attn_merge_qkv(model)

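# Patch note: _apply_inference_patches(model) is meant to run once on a loaded
# model, before any torch.compile. Merging Q/K/V collapses three projections
# into a single GEMM, and the cached conditioning is sound only because
# inference sticks to the fixed scheduler_sigmas baked into the LUTs above.
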
# ---------------------------------------------------------------------------
# Model components
# ---------------------------------------------------------------------------

class CFG(nn.Module):
    def __init__(self, d_model: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))

    def forward(
        self, x: torch.Tensor, is_conditioned: bool | None = None
    ) -> torch.Tensor:
        B, L, _ = x.shape
        null = self.null_emb.expand(B, L, -1)

        if self.training or is_conditioned is None:
            if self.dropout == 0.0:
                return x
            drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
            return torch.where(drop, null, x)

        return x if is_conditioned else null


class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into model dimension."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        assert len(mouse.shape) == 3
        x = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(x)


class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens via split linear projections."""

    def __init__(self, d_model: int):
        super().__init__()
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        B, _, D = x.shape
        L = cond.shape[1]
        x = x.reshape(B, L, -1, D)
        return self.fc2(F.silu(self.fc1_x(x) + self.fc1_c(cond).unsqueeze(2))).flatten(
            1, 2
        )


class MoEWithoutFBGEMM(nn.Module):
    """MoE implementation using torch grouped_mm (no fbgemm dependency)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k
        d_intermediate = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate))

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig_shape = x.shape
        x = x.reshape(-1, orig_shape[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))

        logits_fp32 = logits.float()
        scores, expert = logits.topk(self.top_k, dim=-1, sorted=False)
        weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype)

        expert = expert.flatten()
        expert_sorted, sort_idx = expert.sort()
        expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype)
        offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32)

        src = sort_idx // self.top_k
        x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0))
        h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets)
        h[-1].zero_()

        if self.config.gated_linear:
            gate_act, up = h.chunk(2, dim=-1)
            h = F.silu(gate_act) * up
        else:
            h = F.silu(h)

        y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1]
        y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1)
        return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape)


class MoE(nn.Module):
    """MoE implementation using fbgemm optimized kernels."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k)
        d_int = int(config.d_model * moe_mlp_ratio)

        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int))

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig = x.shape
        x = x.reshape(-1, orig[-1])
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))

        logits32 = logits.float()
        token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k)

        E = self.expert_in_proj.size(0)
        offs = token_counts[:E].cumsum(0).to(torch.int32)

        src = src.to(torch.long)
        expert_sorted = expert_sorted.to(torch.long)
        logZ = logits32.logsumexp(-1)
        w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype)

        xg = x.index_select(0, torch.cat((src, src[:1]), 0))
        h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs)
        if self.config.gated_linear:
            ga, up = h.chunk(2, -1)
            h = F.silu(ga) * up
        else:
            h = F.silu(h)

        yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1]
        out = torch.zeros_like(x)
        torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src)
        return out.reshape(orig)


class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> one Linear per condition (n_cond outputs)."""

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        self.bias_in = (
            nn.Parameter(torch.zeros(d_model)) if noise_conditioning == "wan" else None
        )
        self.cond_proj = nn.ModuleList(
            [nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)]
        )

    def forward(self, cond):
        cond = cond + self.bias_in if self.bias_in is not None else cond
        h = F.silu(cond)
        return tuple(p(h) for p in self.cond_proj)

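# Routing note for both MoE variants above: expert weights are the full-softmax
# probabilities of the selected experts (logit minus logsumexp over all
# experts), not renormalized over the top-k, so a token's k weights need not
# sum to 1.
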
# ---------------------------------------------------------------------------
# Transformer blocks
# ---------------------------------------------------------------------------

class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP."""

    def __init__(
        self, d_model, n_heads, mlp_ratio, layer_idx,
        prompt_conditioning, prompt_conditioning_period, prompt_embedding_dim,
        ctrl_conditioning_period, noise_conditioning, config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        if getattr(config, "moe", False):
            self.dit_mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config)
        else:
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)

        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        do_ctrl_cond = ctrl_conditioning_period is not None and layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        s0, b0, g0 = self.attn_cond_head(cond)
        s1, b1, g1 = self.mlp_cond_head(cond)

        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        x = ada_gate(self.dit_mlp(ada_rmsnorm(x, s1, b1)), g1) + x

        return x, v


class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with shared parameters."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        rope_angles = self.rope_angles(pos_ids)
        v = None
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x

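# Layer layout with the shipped transformer/config.json (n_layers=24,
# ctrl_conditioning_period=3, prompt_conditioning=null): layers 0, 3, 6, ...
# carry a ctrl MLPFusion branch, and no prompt cross-attention is instantiated.
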
# ---------------------------------------------------------------------------
|
| 932 |
+
# Top-level model
|
| 933 |
+
# ---------------------------------------------------------------------------
|
| 934 |
+
|
| 935 |
+
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]

    @register_to_config
    def __init__(
        self,
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        channels: int = 32,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        noise_conditioning: str = "wan",
        scheduler_sigmas: list[float] | None = [
            1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            D, C, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)

        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)
    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ):
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Position ids carry four axes: frame index, timestamp, and the
        # spatial y/x token coordinates within the frame.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        cond = self.denoise_step_emb(sigma)

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Patchify -> token sequence -> transformer -> unpatchify.
        D = self.config.d_model
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(x, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=Hp, wp=Wp)
        x = self.unpatchify(x)
        x = x.view(B, N, C, H, W)

        return x
    def get_active_parameters(self) -> int:
        """Total parameter count minus MoE expert weights that are inactive per token."""
        total = sum(p.numel() for p in self.parameters())
        c = self.config
        if getattr(c, "moe", False):
            # Each token routes to top_k of n_experts experts, so the remaining
            # experts' MLP weights do not count as active.
            moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k
            hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts)
            total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2)
        return total

    def quantize(self, quant_type: str):
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        _apply_inference_patches(self)
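
For orientation, here is a minimal single-step sketch of driving this forward pass directly. The latent shape follows the default config (a 16×16 token grid with a (2, 2) patch implies 32×32 latents); the `sigma`, `frame_timestamp`, and controller tensor shapes below are assumptions rather than a documented API, and the modular pipeline in the README remains the supported entry point:

```python
import torch

model = WorldModel().eval()  # randomly initialised; load real weights in practice

x = torch.randn(1, 1, 32, 32, 32)                      # noisy latent [B, N, C, H, W]
sigma = torch.tensor([[0.8]])                          # noise level (assumed shape [B, N])
frame_timestamp = torch.zeros(1, 1, dtype=torch.long)  # timestamp of the current frame
button = torch.zeros(1, 1, 256)                        # button state (assumed multi-hot)
mouse = torch.zeros(1, 1, 2)                           # mouse delta (assumed [dx, dy])

with torch.no_grad():
    pred = model(x, sigma, frame_timestamp, button=button, mouse=mouse)
# pred: [1, 1, 32, 32, 32], the denoised latent for this frame
```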
vae/__init__.py
ADDED
@@ -0,0 +1,18 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from .ae_model import ChunkedStreamingTAEHV

__all__ = ["ChunkedStreamingTAEHV"]
vae/ae_model.py
ADDED
@@ -0,0 +1,368 @@
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Streaming TAEHV autoencoder for WorldEngine wp-1.5 temporal-compressed latent decoding."""

from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


# ---------------------------------------------------------------------------
# Building blocks (mirror the taehv library)
# ---------------------------------------------------------------------------

# A queued tensor plus the index of the next block to apply to it.
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def _conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    """Residual block that also conditions on the previous frame's activations."""

    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(
            _conv(n_in * 2, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
        )
        self.skip = (
            nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        )
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))


class TPool(nn.Module):
    """Temporal pooling: folds `stride` consecutive frames into channels, then mixes them with a 1x1 conv."""

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    """Temporal growth: the inverse of TPool, expanding each frame into `stride` frames."""

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)
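
# NOTE (shape sketch): with stride=2, TPool folds pairs of consecutive frames
# into channels and halves the frame count:
#   [2T, C, H, W] -> reshape -> [T, 2C, H, W] -> conv1x1 -> [T, C, H, W]
# TGrow is the inverse, doubling the frame count:
#   [T, C, H, W] -> conv1x1 -> [T, 2C, H, W] -> reshape -> [2T, C, H, W]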


# ---------------------------------------------------------------------------
# Sequential streaming helpers
# ---------------------------------------------------------------------------

def _sequential_single_step(model, memory, work_queue):
    """Process the work queue until an output frame is produced or the queue is empty."""
    while work_queue:
        # Depth-first: new work is pushed to the front of the queue, so a frame
        # flows through the whole network before the next frame starts.
        xt, i = work_queue.pop(0)
        if i == len(model):
            return xt.unsqueeze(1)
        b = model[i]
        if isinstance(b, MemBlock):
            # `memory[i]` holds the previous frame's input to this block.
            if memory[i] is None:
                xt_new = b(xt, xt * 0)
            else:
                xt_new = b(xt, memory[i])
            memory[i] = xt
            work_queue.insert(0, TWorkItem(xt_new, i + 1))
        elif isinstance(b, TPool):
            # Buffer frames until a full group of `stride` is available.
            if memory[i] is None:
                memory[i] = []
            memory[i].append(xt)
            if len(memory[i]) == b.stride:
                N, C, H, W = xt.shape
                xt = b(torch.cat(memory[i], 1).view(N * b.stride, C, H, W))
                memory[i] = []
                work_queue.insert(0, TWorkItem(xt, i + 1))
        elif isinstance(b, TGrow):
            xt = b(xt)
            NT, C, H, W = xt.shape
            # Reversed so the earliest output frame ends up at the queue head.
            for xt_next in reversed(
                xt.view(NT // b.stride, b.stride * C, H, W).chunk(b.stride, 1)
            ):
                work_queue.insert(0, TWorkItem(xt_next, i + 1))
        else:
            xt = b(xt)
            work_queue.insert(0, TWorkItem(xt, i + 1))
    return None


def _apply_parallel(model, x):
    """Apply the model with parallelization over the time axis. x: NTCHW."""
    N, T, C, H, W = x.shape
    x = x.reshape(N * T, C, H, W)
    for b in model:
        if isinstance(b, MemBlock):
            NT, C, H, W = x.shape
            T = NT // N
            _x = x.reshape(N, T, C, H, W)
            block_memory = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(
                x.shape
            )
            x = b(x, block_memory)
        else:
            x = b(x)
    NT, C, H, W = x.shape
    T = NT // N
    return x.view(N, T, C, H, W)
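
# NOTE: the zero-padded temporal shift in _apply_parallel hands every MemBlock
# the previous frame's activations in one batched pass; the streaming path
# reconstructs the same inputs one frame at a time via its `memory` list, so
# the two code paths should agree.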


# ---------------------------------------------------------------------------
# ChunkedStreamingTAEHV
# ---------------------------------------------------------------------------

class ChunkedStreamingTAEHV(ModelMixin, ConfigMixin):
    """Streaming TAEHV autoencoder for temporal-compressed latent decoding.

    Owns the encoder/decoder weights directly so diffusers can load them
    from safetensors. Provides a streaming interface that processes one
    temporal chunk at a time, maintaining internal state across calls.
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        latent_channels: int = 32,
        patch_size: int = 2,
        image_channels: int = 3,
        encoder_time_downscale: tuple[bool, ...] = (True, True, False),
        decoder_time_upscale: tuple[bool, ...] = (False, True, True),
        decoder_space_upscale: tuple[bool, ...] = (True, True, True),
    ):
        super().__init__()

        in_ch = image_channels * patch_size ** 2

        self.encoder = nn.Sequential(
            _conv(in_ch, 64), nn.ReLU(inplace=True),
            TPool(64, 2 if encoder_time_downscale[0] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[1] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[2] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            _conv(64, latent_channels),
        )

        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(),
            _conv(latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1),
            _conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1),
            _conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1),
            _conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True),
            _conv(n_f[3], image_channels * patch_size ** 2),
        )

        # Computed properties
        self.t_downscale = 2 ** sum(
            t.stride == 2 for t in self.encoder if isinstance(t, TPool)
        )
        self.t_upscale = 2 ** sum(
            t.stride == 2 for t in self.decoder if isinstance(t, TGrow)
        )
        self.frames_to_trim = self.t_upscale - 1
        self.patch_size = patch_size

        # Streaming state (initialised on first use / reset)
        self._encoder_work_queue: list[TWorkItem] = []
        self._encoder_memory: list = [None] * len(self.encoder)
        self._decoder_work_queue: list[TWorkItem] = []
        self._decoder_memory: list = [None] * len(self.decoder)
        self._n_frames_encoded: int = 0
        self._n_frames_decoded: int = 0
        self._last_encoder_input_frame: Tensor | None = None
    # ------------------------------------------------------------------
    # Streaming state management
    # ------------------------------------------------------------------

    def reset(self):
        """Reset streaming state for a new sequence."""
        self._encoder_work_queue = []
        self._encoder_memory = [None] * len(self.encoder)
        self._decoder_work_queue = []
        self._decoder_memory = [None] * len(self.decoder)
        self._n_frames_encoded = 0
        self._n_frames_decoded = 0
        self._last_encoder_input_frame = None

    # ------------------------------------------------------------------
    # Pre/post processing
    # ------------------------------------------------------------------

    def _preprocess_input_frames(self, x: Tensor) -> Tensor:
        if self.patch_size > 1:
            x = F.pixel_unshuffle(x, self.patch_size)
        return x

    def _postprocess_output_frames(self, x: Tensor) -> Tensor:
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x.clamp_(0, 1)

    # ------------------------------------------------------------------
    # Streaming encode / decode (one chunk at a time)
    # ------------------------------------------------------------------

    def _streaming_encode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed an input frame and try to produce an encoder output.

        Args:
            x: N1CHW RGB frame tensor with values in [0, 1], or None.

        Returns:
            N1CHW latent tensor, or None if not enough input has accumulated.
        """
        if x is not None:
            self._last_encoder_input_frame = x[:, -1:]
            x = self._preprocess_input_frames(x)
            self._encoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
            self._n_frames_encoded += x.shape[1]
        return _sequential_single_step(
            self.encoder, self._encoder_memory, self._encoder_work_queue
        )

    def _streaming_decode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed a latent and try to produce a decoded frame.

        Args:
            x: N1CHW latent tensor, or None to retrieve the next pending frame.

        Returns:
            N1CHW decoded RGB frame tensor, or None.
        """
        if x is not None:
            self._decoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
        while True:
            xt = _sequential_single_step(
                self.decoder, self._decoder_memory, self._decoder_work_queue
            )
            if xt is None:
                return None
            self._n_frames_decoded += 1
            if self._n_frames_decoded <= self.frames_to_trim:
                continue
            return self._postprocess_output_frames(xt)

    def _flush_decoder(self) -> list[Tensor]:
        """Drain all remaining decoded frames from the decoder."""
        frames = []
        while (frame := self._streaming_decode_step()) is not None:
            frames.append(frame)
        return frames
    # ------------------------------------------------------------------
    # Pipeline-facing encode / decode
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def encode(self, img: Tensor) -> Tensor:
        """Encode a chunk of frames to a single latent.

        Args:
            img: [T, H, W, C] uint8 where T == t_downscale

        Returns:
            latent: [B, C, h, w]
        """
        assert img.dim() == 4 and img.shape[-1] == 3, "Expected [T, H, W, C] RGB uint8"

        if img.shape[0] != self.t_downscale:
            raise ValueError(
                f"Expected {self.t_downscale} frames, got {img.shape[0]}"
            )

        rgb = (
            img.unsqueeze(0)
            .to(device=self.device, dtype=self.dtype)
            .permute(0, 1, 4, 2, 3)
            .contiguous()
            .div(255)
        )

        latent = self._streaming_encode_step(rgb)
        if latent is None:
            raise RuntimeError("Expected a latent after a full chunk")

        return latent.squeeze(1)

    @torch.inference_mode()
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to frames.

        Args:
            latent: [B, C, h, w]

        Returns:
            frames: [T, H, W, C] uint8
        """
        assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor"

        z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype)

        if self._n_frames_decoded == 0:
            # Warm up on the first latent: the initial `frames_to_trim` raw
            # outputs are discarded inside _streaming_decode_step.
            for _ in range(self.frames_to_trim):
                self._streaming_decode_step(z)
                self._flush_decoder()

        first = self._streaming_decode_step(z)
        assert first is not None, "Expected decoded output after a latent"
        frames = [first, *self._flush_decoder()]

        decoded = torch.cat(frames, dim=1)
        decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8)
        return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3]
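
A minimal streaming round-trip sketch: with the default config, `t_downscale == t_upscale == 4`, so one latent chunk covers four frames; the 256×256 frame size here is arbitrary, and the model is randomly initialised (load the checkpoint weights for real use):

```python
import torch

ae = ChunkedStreamingTAEHV().eval()
ae.reset()  # fresh streaming state for a new sequence

# Encode one temporal chunk: t_downscale frames -> one latent.
chunk = torch.randint(0, 256, (4, 256, 256, 3), dtype=torch.uint8)  # [T, H, W, C]
latent = ae.encode(chunk)                                           # [1, 32, 16, 16]

# Decode it back; subsequent latents from the same sequence reuse the state.
frames = ae.decode(latent)                                          # [T', H, W, C] uint8
```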
vae/config.json
ADDED
@@ -0,0 +1,13 @@
{
  "_class_name": "ChunkedStreamingTAEHV",
  "_diffusers_version": "0.36.0.dev0",
  "auto_map": {
    "AutoModel": "ae_model.ChunkedStreamingTAEHV"
  },
  "latent_channels": 32,
  "patch_size": 2,
  "image_channels": 3,
  "encoder_time_downscale": [true, true, false],
  "decoder_time_upscale": [false, true, true],
  "decoder_space_upscale": [true, true, true]
}
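
These fields map one-to-one onto the `ChunkedStreamingTAEHV` constructor, and the `auto_map` entry lets diffusers resolve the remote class, so the VAE should also be loadable on its own — a sketch, assuming a recent diffusers release with `AutoModel`:

```python
from diffusers import AutoModel

vae = AutoModel.from_pretrained(
    "Overworld/Waypoint-1.5-1B", subfolder="vae", trust_remote_code=True
)
```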
vae/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b52e245bb86c62e159f50338e2e8f422d4b6f98b467164939c1c031c7d61352e
size 22755856