Add diffusers support

#1
by dn6 HF Staff - opened
README.md CHANGED
@@ -79,6 +79,48 @@ This checkpoint is intended to be used with Overworld’s interactive runtime st
79
  - Play on our official desktop client, [Biome](https://over.world/install)
80
  - Use our [world_engine](https://github.com/Wayfarer-Labs/world_engine) inference library to build your own applications
81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  ### Recommended setup
84
 
@@ -129,4 +171,4 @@ Please see our blog post, ["Engineering Safety for Interactive World Models"](ht
129
 
130
  - [Website](http://over.world/)
131
  - [Discord](https://discord.gg/MEmQa7Wux4)
132
- - [X/Twitter](https://x.com/overworld_ai)
 
79
  - Play on our official desktop client, [Biome](https://over.world/install)
80
  - Use our [world_engine](https://github.com/Wayfarer-Labs/world_engine) inference library to build your own applications
81
 
82
+ ### Diffusers (Modular Pipeline)
83
+
84
+ This model can also be used with [Modular Diffusers](https://huggingface.co/docs/diffusers/main/en/modular_diffusers):
85
+
86
+ ```python
87
+ import torch
88
+ from diffusers.modular_pipelines import ModularPipeline
89
+ from diffusers.utils import load_image, export_to_video
90
+
91
+ pipe = ModularPipeline.from_pretrained(
92
+ "Overworld/Waypoint-1.5-1B", trust_remote_code=True
93
+ )
94
+ pipe.load_components(
95
+ device_map="cuda", torch_dtype=torch.bfloat16, trust_remote_code=True
96
+ )
97
+ pipe.transformer.apply_inference_patches()
98
+ pipe.transformer.compile(fullgraph=True, mode="max-autotune", dynamic=False)
99
+
100
+ # Seed the world with an image
101
+ image = load_image("https://huggingface.co/spaces/Overworld/waypoint-1-small/resolve/main/starter_18.png").resize((1024, 512))
102
+ state = pipe(
103
+ image=image,
104
+ prompt="An explorable world",
105
+ button=set(),
106
+ mouse=(0.0, 0.0),
107
+ output_type="pil",
108
+ )
109
+
110
+ # Generate subsequent frames with controller inputs
111
+ state.values["image"] = None
112
+ frames = []
113
+ for _ in range(150):
114
+ state = pipe(
115
+ state,
116
+ button={87}, # W key (walk forward)
117
+ mouse=(0.0, 0.0),
118
+ output_type="pil",
119
+ )
120
+ frames.extend(state.values["images"]) # list of PIL images
121
+
122
+ export_to_video(frames, "waypoint-v1-5.mp4", fps=60)
123
+ ```
124
 
125
  ### Recommended setup
126
 
 
171
 
172
  - [Website](http://over.world/)
173
  - [Discord](https://discord.gg/MEmQa7Wux4)
174
+ - [X/Twitter](https://x.com/overworld_ai)
__init__.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ """
17
+ WorldEngine Modular Pipeline
18
+
19
+ A Diffusers-compatible modular pipeline for frame-by-frame world model generation.
20
+ Supports text and controller (mouse + button + scroll) conditioning.
21
+ """
22
+
23
+ from .modular_blocks import (
24
+ WorldEngineBlocks,
25
+ AUTO_BLOCKS,
26
+ WorldEngineTextEncoderStep,
27
+ WorldEngineControllerEncoderStep,
28
+ WorldEngineBeforeDenoiseStep,
29
+ WorldEngineSetTimestepsStep,
30
+ WorldEnginePrepareLatentsStep,
31
+ WorldEngineSetupKVCacheStep,
32
+ WorldEngineDenoiseLoop,
33
+ WorldEngineDecodeStep,
34
+ StaticKVCache,
35
+ LayerKVCache,
36
+ )
37
+ from .transformer import WorldModel
38
+ from .vae import ChunkedStreamingTAEHV
39
+
40
+ __version__ = "0.1.0"
41
+
42
+ __all__ = [
43
+ "WorldEngineBlocks",
44
+ "AUTO_BLOCKS",
45
+ "WorldEngineTextEncoderStep",
46
+ "WorldEngineControllerEncoderStep",
47
+ "WorldEngineBeforeDenoiseStep",
48
+ "WorldEngineSetTimestepsStep",
49
+ "WorldEnginePrepareLatentsStep",
50
+ "WorldEngineSetupKVCacheStep",
51
+ "WorldEngineDenoiseLoop",
52
+ "WorldEngineDecodeStep",
53
+ "WorldModel",
54
+ "ChunkedStreamingTAEHV",
55
+ "StaticKVCache",
56
+ "LayerKVCache",
57
+ ]
modular_blocks.py ADDED
@@ -0,0 +1,1188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ """WorldEngine modular pipeline blocks.
17
+
18
+ All pipeline step classes for text encoding, controller encoding,
19
+ KV cache setup, latent preparation, denoising, and decoding.
20
+ """
21
+
22
+ import html
23
+
24
+ import numpy as np
25
+ import PIL.Image
26
+ import regex as re
27
+ import torch
28
+ from torch import nn, Tensor
29
+ from tensordict import TensorDict
30
+ from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask
31
+ from transformers import AutoTokenizer, UMT5EncoderModel
32
+
33
+ from diffusers import AutoModel
34
+ from diffusers.configuration_utils import FrozenDict
35
+ from diffusers.image_processor import VaeImageProcessor
36
+ from diffusers.utils import is_ftfy_available, logging
37
+ from diffusers.modular_pipelines import (
38
+ ModularPipelineBlocks,
39
+ ModularPipeline,
40
+ PipelineState,
41
+ SequentialPipelineBlocks,
42
+ )
43
+ from diffusers.modular_pipelines.modular_pipeline_utils import (
44
+ ComponentSpec,
45
+ ConfigSpec,
46
+ InputParam,
47
+ InsertableDict,
48
+ OutputParam,
49
+ )
50
+
51
+ if is_ftfy_available():
52
+ import ftfy
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # Text encoding helpers
60
+ # ---------------------------------------------------------------------------
61
+
62
+ def basic_clean(text):
63
+ text = ftfy.fix_text(text)
64
+ text = html.unescape(html.unescape(text))
65
+ return text.strip()
66
+
67
+
68
+ def whitespace_clean(text):
69
+ text = re.sub(r"\s+", " ", text)
70
+ text = text.strip()
71
+ return text
72
+
73
+
74
+ def prompt_clean(text):
75
+ text = whitespace_clean(basic_clean(text))
76
+ return text
77
+
78
+
79
+ # ---------------------------------------------------------------------------
80
+ # Block mask construction
81
+ # ---------------------------------------------------------------------------
82
+
83
+ def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
84
+ """
85
+ Create a block mask for flex_attention.
86
+
87
+ T and L must be exact multiples of the sparse block size; written must be
88
+ block-aligned (each block is either all True or all False).
89
+
90
+ Args:
91
+ T: Q length for this frame
92
+ L: KV capacity == written.numel()
93
+ written: [L] bool, True where there is valid KV data
94
+ """
95
+ BS = _DEFAULT_SPARSE_BLOCK_SIZE
96
+
97
+ if not torch.compiler.is_compiling():
98
+ torch._check(T % BS == 0, f"T ({T}) must be a multiple of block size ({BS})")
99
+ torch._check(L % BS == 0, f"L ({L}) must be a multiple of block size ({BS})")
100
+
101
+ Q_blocks = T // BS
102
+ KV_blocks = L // BS
103
+
104
+ written_blocks = written.view(KV_blocks, BS)
105
+ block_any = written_blocks.any(-1)
106
+
107
+ if not torch.compiler.is_compiling():
108
+ assert torch.equal(block_any, written_blocks.all(-1)), "written must be block-aligned"
109
+
110
+ # Every KV block is a full block (no partial blocks)
111
+ full_bm = block_any[None, :].expand(Q_blocks, KV_blocks)
112
+ full_kv_num_blocks = full_bm.sum(dim=-1, dtype=torch.int32)[None, None].contiguous()
113
+ full_kv_indices = full_bm.argsort(dim=-1, descending=True, stable=True).to(torch.int32)[None, None].contiguous()
114
+
115
+ # No partial blocks
116
+ kv_num_blocks = torch.zeros((1, 1, Q_blocks), dtype=torch.int32, device=written.device)
117
+ kv_indices = torch.zeros((1, 1, Q_blocks, KV_blocks), dtype=torch.int32, device=written.device)
118
+
119
+ return BlockMask.from_kv_blocks(
120
+ kv_num_blocks,
121
+ kv_indices,
122
+ full_kv_num_blocks,
123
+ full_kv_indices,
124
+ BLOCK_SIZE=BS,
125
+ mask_mod=None,
126
+ seq_lengths=(T, L),
127
+ compute_q_blocks=False,
128
+ )
129
+
130
+
131
+ # ---------------------------------------------------------------------------
132
+ # KV cache
133
+ # ---------------------------------------------------------------------------
134
+
135
+ class LayerKVCache(nn.Module):
136
+ """
137
+ Ring-buffer KV cache with fixed capacity L (tokens) for history plus
138
+ one extra frame (tokens_per_frame) at the tail holding the current frame.
139
+ """
140
+
141
+ def __init__(
142
+ self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
143
+ ):
144
+ super().__init__()
145
+ self.tpf = tokens_per_frame
146
+ self.L = L
147
+ # total KV capacity: ring (L) + tail frame (tpf)
148
+ self.capacity = L + self.tpf
149
+ self.pinned_dilation = pinned_dilation
150
+ self.num_buckets = (L // self.tpf) // self.pinned_dilation
151
+ assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0
152
+
153
+ # KV buffer: [2, B, H, capacity, Dh]
154
+ self.kv = nn.Buffer(
155
+ torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
156
+ persistent=False,
157
+ )
158
+
159
+ # which slots have ever been written
160
+ # tail slice [L, L+tpf) always holds the current frame and is considered written
161
+ written = torch.zeros(self.capacity, dtype=torch.bool)
162
+ written[L:] = True
163
+ self.written = nn.Buffer(written, persistent=False)
164
+
165
+ # _mask_written is a scratch buffer for computing block masks without cloning
166
+ self._mask_written = nn.Buffer(torch.zeros_like(written), persistent=False)
167
+
168
+ # Precompute indices:
169
+ # frame_offsets: [0, 1, ..., tpf-1] (for ring indexing)
170
+ # current_idx: [L, L+1, ..., L+tpf-1] (tail slice)
171
+ self.frame_offsets = nn.Buffer(
172
+ torch.arange(self.tpf, dtype=torch.long), persistent=False
173
+ )
174
+ self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)
175
+
176
+ def reset(self):
177
+ self.kv.zero_()
178
+ self.written.zero_()
179
+ self.written[self.L :].fill_(True)
180
+
181
+ def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
182
+ """
183
+ Args:
184
+ kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
185
+ pos_ids: TensorDict with f_pos [B, T] for cache slot indexing
186
+ """
187
+ T = self.tpf
188
+ f_pos = pos_ids["f_pos"]
189
+
190
+ if not torch.compiler.is_compiling():
191
+ torch._check(
192
+ kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
193
+ )
194
+ torch._check(f_pos.shape == (kv.size(1), T), "f_pos must be [B, T]")
195
+ torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
196
+ torch._check(
197
+ self.L % self.tpf == 0,
198
+ f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
199
+ )
200
+ torch._check(
201
+ self.kv.size(3) == self.capacity,
202
+ "KV buffer too long (expected L + tokens_per_frame)",
203
+ )
204
+ torch._check(
205
+ (f_pos >= 0).all().item(),
206
+ "f_pos must be non-negative during inference",
207
+ )
208
+ torch._check(
209
+ ((f_pos == f_pos[:, :1]).all()).item(),
210
+ "f_pos must be constant within frame",
211
+ )
212
+
213
+ frame_idx = f_pos[0, 0]
214
+
215
+ # map frame_idx to a bucket, each bucket owns T contiguous slots
216
+ bucket = (frame_idx + (self.pinned_dilation - 1)) // self.pinned_dilation
217
+ slot = bucket % self.num_buckets
218
+ base = slot * T
219
+
220
+ # indices in the ring for this frame: [T] in [0, L)
221
+ ring_idx = self.frame_offsets + base
222
+
223
+ # Always write current frame into the tail slice [L, L+T):
224
+ # this is the "self-attention component" for the current frame.
225
+ self.kv.index_copy_(3, self.current_idx, kv)
226
+
227
+ write_step = (frame_idx.remainder(self.pinned_dilation) == 0)
228
+ mask_written = self._mask_written
229
+ mask_written.copy_(self.written)
230
+ mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
231
+ bm = make_block_mask(T, self.capacity, mask_written)
232
+
233
+ # Persist current frame into the ring for future queries when unfrozen.
234
+ if not is_frozen:
235
+ dst = torch.where(write_step, ring_idx, self.current_idx)
236
+ self.kv.index_copy_(3, dst, kv)
237
+ self.written[dst] = True
238
+
239
+ k, v = self.kv.unbind(0)
240
+ return k, v, bm
241
+
242
+
243
+ class StaticKVCache(nn.Module):
244
+ """Static KV cache with per-layer configuration for local/global attention."""
245
+
246
+ def __init__(self, config, batch_size, dtype):
247
+ super().__init__()
248
+
249
+ self.tpf = config.height * config.width
250
+
251
+ local_L = config.local_window * self.tpf
252
+ global_L = config.global_window * self.tpf
253
+
254
+ period = config.global_attn_period
255
+ off = getattr(config, "global_attn_offset", 0) % period
256
+ self.layers = nn.ModuleList(
257
+ [
258
+ LayerKVCache(
259
+ batch_size,
260
+ getattr(config, "n_kv_heads", None) or config.n_heads,
261
+ global_L if ((layer_idx - off) % period == 0) else local_L,
262
+ config.d_model // config.n_heads,
263
+ dtype,
264
+ self.tpf,
265
+ (
266
+ config.global_pinned_dilation
267
+ if ((layer_idx - off) % period == 0)
268
+ else 1
269
+ ),
270
+ )
271
+ for layer_idx in range(config.n_layers)
272
+ ]
273
+ )
274
+
275
+ self._is_frozen = True
276
+
277
+ def reset(self):
278
+ for layer in self.layers:
279
+ layer.reset()
280
+ self._is_frozen = True
281
+
282
+ @torch.inference_mode()
283
+ def get_state(self):
284
+ """Captures a world state to continue via load_state."""
285
+ layers = [(layer.kv.detach().clone(), layer.written.detach().clone()) for layer in self.layers]
286
+ return {"_is_frozen": self._is_frozen, "layers": layers}
287
+
288
+ @torch.inference_mode()
289
+ def load_state(self, state):
290
+ """Loads a world state object saved via get_state."""
291
+ self._is_frozen = bool(state.get("_is_frozen", True))
292
+ for layer, (kv, written) in zip(self.layers, state["layers"]):
293
+ layer.kv.copy_(kv)
294
+ layer.written.copy_(written)
295
+
296
+ def set_frozen(self, is_frozen: bool):
297
+ self._is_frozen = is_frozen
298
+
299
+ def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
300
+ kv = torch.stack([k, v], dim=0)
301
+ return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)
302
+
303
+
304
+ # ---------------------------------------------------------------------------
305
+ # Pipeline step: Text Encoder
306
+ # ---------------------------------------------------------------------------
307
+
308
+ class WorldEngineTextEncoderStep(ModularPipelineBlocks):
309
+ """Encodes text prompts using UMT5-XL for conditioning."""
310
+
311
+ model_name = "world_engine"
312
+
313
+ @property
314
+ def description(self) -> str:
315
+ return (
316
+ "Text Encoder step that generates text embeddings to guide frame generation"
317
+ )
318
+
319
+ @property
320
+ def expected_components(self) -> list[ComponentSpec]:
321
+ return [
322
+ ComponentSpec("text_encoder", UMT5EncoderModel),
323
+ ComponentSpec("tokenizer", AutoTokenizer),
324
+ ]
325
+
326
+ @property
327
+ def inputs(self) -> list[InputParam]:
328
+ return [
329
+ InputParam(
330
+ "prompt",
331
+ description="The prompt or prompts to guide the frame generation",
332
+ ),
333
+ InputParam(
334
+ "prompt_embeds",
335
+ type_hint=torch.Tensor,
336
+ description="Pre-computed text embeddings",
337
+ ),
338
+ InputParam(
339
+ "prompt_pad_mask",
340
+ type_hint=torch.Tensor,
341
+ description="Padding mask for prompt embeddings",
342
+ ),
343
+ ]
344
+
345
+ @property
346
+ def intermediate_outputs(self) -> list[OutputParam]:
347
+ return [
348
+ OutputParam(
349
+ "prompt_embeds",
350
+ type_hint=torch.Tensor,
351
+ kwargs_type="denoiser_input_fields",
352
+ description="Text embeddings used to guide frame generation",
353
+ ),
354
+ OutputParam(
355
+ "prompt_pad_mask",
356
+ type_hint=torch.Tensor,
357
+ kwargs_type="denoiser_input_fields",
358
+ description="Padding mask for prompt embeddings",
359
+ ),
360
+ ]
361
+
362
+ @staticmethod
363
+ def check_inputs(block_state):
364
+ if block_state.prompt is not None and (
365
+ not isinstance(block_state.prompt, str)
366
+ and not isinstance(block_state.prompt, list)
367
+ ):
368
+ raise ValueError(
369
+ f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
370
+ )
371
+
372
+ @staticmethod
373
+ def encode_prompt(
374
+ components,
375
+ prompt: str | list[str],
376
+ device: torch.device,
377
+ max_sequence_length: int = 512,
378
+ ):
379
+ dtype = components.text_encoder.dtype
380
+
381
+ prompt = [prompt] if isinstance(prompt, str) else prompt
382
+ prompt = [prompt_clean(p) for p in prompt]
383
+
384
+ text_inputs = components.tokenizer(
385
+ prompt,
386
+ padding="max_length",
387
+ max_length=max_sequence_length,
388
+ truncation=True,
389
+ return_attention_mask=True,
390
+ return_tensors="pt",
391
+ )
392
+
393
+ text_input_ids = text_inputs.input_ids.to(device)
394
+ attention_mask = text_inputs.attention_mask.to(device)
395
+
396
+ prompt_embeds = components.text_encoder(
397
+ text_input_ids, attention_mask
398
+ ).last_hidden_state
399
+ prompt_embeds = prompt_embeds.to(dtype=dtype)
400
+
401
+ # Zero out padding
402
+ prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
403
+ prompt_embeds
404
+ )
405
+
406
+ # Create padding mask (True where padded)
407
+ prompt_pad_mask = attention_mask.eq(0)
408
+
409
+ return prompt_embeds, prompt_pad_mask
410
+
411
+ @torch.no_grad()
412
+ def __call__(
413
+ self, components: ModularPipeline, state: PipelineState
414
+ ) -> PipelineState:
415
+ block_state = self.get_block_state(state)
416
+ self.check_inputs(block_state)
417
+
418
+ device = components._execution_device
419
+ if block_state.prompt_embeds is None:
420
+ block_state.prompt = block_state.prompt or "An explorable world"
421
+ (
422
+ block_state.prompt_embeds,
423
+ block_state.prompt_pad_mask,
424
+ ) = self.encode_prompt(components, block_state.prompt, device)
425
+ block_state.prompt_embeds = block_state.prompt_embeds.contiguous()
426
+
427
+ if block_state.prompt_pad_mask is None:
428
+ block_state.prompt_pad_mask = torch.zeros(
429
+ block_state.prompt_embeds.shape[:2],
430
+ dtype=torch.bool,
431
+ device=device,
432
+ )
433
+
434
+ self.set_block_state(state, block_state)
435
+ return components, state
436
+
437
+
438
+ # ---------------------------------------------------------------------------
439
+ # Pipeline step: Controller Encoder
440
+ # ---------------------------------------------------------------------------
441
+
442
+ class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
443
+ """Encodes controller inputs (mouse + buttons + scroll) for conditioning."""
444
+
445
+ model_name = "world_engine"
446
+
447
+ @property
448
+ def description(self) -> str:
449
+ return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"
450
+
451
+ @property
452
+ def expected_components(self) -> list[ComponentSpec]:
453
+ return [] # Controller embedding is part of transformer
454
+
455
+ @property
456
+ def expected_configs(self) -> list[ConfigSpec]:
457
+ return [ConfigSpec("n_buttons", 256)]
458
+
459
+ @property
460
+ def inputs(self) -> list[InputParam]:
461
+ return [
462
+ InputParam(
463
+ "button",
464
+ type_hint=set[int],
465
+ default=set(),
466
+ description="Set of pressed button IDs",
467
+ ),
468
+ InputParam(
469
+ "mouse",
470
+ type_hint=tuple[float, float],
471
+ default=(0.0, 0.0),
472
+ description="Mouse velocity (x, y)",
473
+ ),
474
+ InputParam(
475
+ "scroll",
476
+ type_hint=int,
477
+ default=0,
478
+ description="Scroll wheel direction (-1, 0, 1)",
479
+ ),
480
+ InputParam(
481
+ "button_tensor",
482
+ type_hint=torch.Tensor,
483
+ kwargs_type="denoiser_input_fields",
484
+ description="One-hot encoded button tensor",
485
+ ),
486
+ InputParam(
487
+ "mouse_tensor",
488
+ type_hint=torch.Tensor,
489
+ kwargs_type="denoiser_input_fields",
490
+ description="Mouse velocity tensor",
491
+ ),
492
+ InputParam(
493
+ "scroll_tensor",
494
+ type_hint=torch.Tensor,
495
+ kwargs_type="denoiser_input_fields",
496
+ description="Scroll wheel sign tensor",
497
+ ),
498
+ ]
499
+
500
+ @property
501
+ def intermediate_outputs(self) -> list[OutputParam]:
502
+ return [
503
+ OutputParam(
504
+ "button_tensor",
505
+ type_hint=torch.Tensor,
506
+ kwargs_type="denoiser_input_fields",
507
+ description="One-hot encoded button tensor",
508
+ ),
509
+ OutputParam(
510
+ "mouse_tensor",
511
+ type_hint=torch.Tensor,
512
+ kwargs_type="denoiser_input_fields",
513
+ description="Mouse velocity tensor",
514
+ ),
515
+ OutputParam(
516
+ "scroll_tensor",
517
+ type_hint=torch.Tensor,
518
+ kwargs_type="denoiser_input_fields",
519
+ description="Scroll wheel sign tensor",
520
+ ),
521
+ ]
522
+
523
+ @torch.no_grad()
524
+ def __call__(
525
+ self, components: ModularPipeline, state: PipelineState
526
+ ) -> PipelineState:
527
+ block_state = self.get_block_state(state)
528
+ device = components._execution_device
529
+ dtype = components.transformer.dtype
530
+
531
+ n_buttons = components.config.n_buttons
532
+
533
+ # Create or reuse button tensor [1, 1, n_buttons]
534
+ if block_state.button_tensor is None:
535
+ block_state.button_tensor = torch.zeros(
536
+ (1, 1, n_buttons), device=device, dtype=dtype
537
+ )
538
+
539
+ # Update button tensor in-place (avoid dynamic shapes for torch.compile)
540
+ block_state.button_tensor.zero_()
541
+ if block_state.button:
542
+ for btn_id in block_state.button:
543
+ if 0 <= btn_id < n_buttons:
544
+ block_state.button_tensor[0, 0, btn_id] = 1.0
545
+
546
+ # Create or reuse mouse tensor [1, 1, 2]
547
+ if block_state.mouse_tensor is None:
548
+ block_state.mouse_tensor = torch.zeros(
549
+ (1, 1, 2), device=device, dtype=dtype
550
+ )
551
+
552
+ # Update mouse tensor in-place
553
+ mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
554
+ block_state.mouse_tensor[0, 0, 0] = mouse[0]
555
+ block_state.mouse_tensor[0, 0, 1] = mouse[1]
556
+
557
+ # Create or reuse scroll tensor [1, 1, 1]
558
+ if block_state.scroll_tensor is None:
559
+ block_state.scroll_tensor = torch.zeros(
560
+ (1, 1, 1), device=device, dtype=dtype
561
+ )
562
+
563
+ # Update scroll tensor in-place (sign of scroll value: -1, 0, or 1)
564
+ scroll = block_state.scroll if block_state.scroll is not None else 0
565
+ block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)
566
+
567
+ self.set_block_state(state, block_state)
568
+ return components, state
569
+
570
+
571
+ # ---------------------------------------------------------------------------
572
+ # Pipeline step: Set Timesteps
573
+ # ---------------------------------------------------------------------------
574
+
575
+ class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
576
+ """Sets up the scheduler sigmas for rectified flow denoising."""
577
+
578
+ model_name = "world_engine"
579
+
580
+ @property
581
+ def description(self) -> str:
582
+ return "Sets up scheduler sigmas for rectified flow denoising"
583
+
584
+ @property
585
+ def expected_components(self) -> list[ComponentSpec]:
586
+ return []
587
+
588
+ @property
589
+ def expected_configs(self) -> list[ConfigSpec]:
590
+ return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]
591
+
592
+ @property
593
+ def inputs(self) -> list[InputParam]:
594
+ return [
595
+ InputParam(
596
+ "scheduler_sigmas",
597
+ type_hint=list[float],
598
+ description="Custom scheduler sigmas (overrides config)",
599
+ ),
600
+ InputParam(
601
+ "frame_timestamp",
602
+ type_hint=torch.Tensor,
603
+ description="Current frame timestamp",
604
+ ),
605
+ ]
606
+
607
+ @property
608
+ def intermediate_outputs(self) -> list[OutputParam]:
609
+ return [
610
+ OutputParam(
611
+ "scheduler_sigmas",
612
+ type_hint=torch.Tensor,
613
+ description="Tensor of scheduler sigmas for denoising",
614
+ ),
615
+ OutputParam(
616
+ "frame_timestamp",
617
+ type_hint=torch.Tensor,
618
+ description="Current frame timestamp (unscaled counter)",
619
+ ),
620
+ OutputParam(
621
+ "ts_mult",
622
+ type_hint=int,
623
+ description="Timestamp multiplier (base_fps // latent_fps)",
624
+ ),
625
+ ]
626
+
627
+ @torch.no_grad()
628
+ def __call__(
629
+ self, components: ModularPipeline, state: PipelineState
630
+ ) -> PipelineState:
631
+ block_state = self.get_block_state(state)
632
+ device = components._execution_device
633
+ dtype = components.transformer.dtype
634
+
635
+ # Use provided sigmas or get from config
636
+ sigmas = block_state.scheduler_sigmas
637
+ if sigmas is None:
638
+ sigmas = components.config.scheduler_sigmas
639
+ block_state.scheduler_sigmas = torch.tensor(
640
+ sigmas, device=device, dtype=dtype
641
+ )
642
+
643
+ frame_ts = block_state.frame_timestamp
644
+ if frame_ts is None:
645
+ frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
646
+ elif isinstance(frame_ts, int):
647
+ frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)
648
+
649
+ # Compute ts_mult: ratio of base_fps to latent_fps
650
+ t_cfg = components.transformer.config
651
+ base_fps = getattr(t_cfg, "base_fps", 60)
652
+ inference_fps = getattr(t_cfg, "inference_fps", base_fps)
653
+ temporal_compression = getattr(t_cfg, "temporal_compression", 1)
654
+ latent_fps = inference_fps / temporal_compression
655
+ ts_mult = int(base_fps) // int(latent_fps)
656
+ block_state.ts_mult = ts_mult
657
+ block_state.frame_timestamp = frame_ts
658
+
659
+ self.set_block_state(state, block_state)
660
+ return components, state
661
+
662
+
663
+ # ---------------------------------------------------------------------------
664
+ # Pipeline step: Setup KV Cache
665
+ # ---------------------------------------------------------------------------
666
+
667
+ class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
668
+ """Initializes or reuses the KV cache for autoregressive generation."""
669
+
670
+ model_name = "world_engine"
671
+
672
+ @property
673
+ def description(self) -> str:
674
+ return "Initializes or reuses KV cache for autoregressive frame generation"
675
+
676
+ @property
677
+ def expected_components(self) -> list[ComponentSpec]:
678
+ return []
679
+
680
+ @property
681
+ def inputs(self) -> list[InputParam]:
682
+ return [
683
+ InputParam(
684
+ "kv_cache",
685
+ type_hint=StaticKVCache | None,
686
+ description="Existing KV cache (will be reused if provided)",
687
+ ),
688
+ InputParam(
689
+ "reset_cache",
690
+ type_hint=bool,
691
+ default=False,
692
+ description="If True, reset the KV cache even if one exists",
693
+ ),
694
+ ]
695
+
696
+ @property
697
+ def intermediate_outputs(self) -> list[OutputParam]:
698
+ return [
699
+ OutputParam(
700
+ "kv_cache",
701
+ type_hint=StaticKVCache,
702
+ description="KV cache for transformer attention",
703
+ ),
704
+ ]
705
+
706
+ @torch.no_grad()
707
+ def __call__(
708
+ self, components: ModularPipeline, state: PipelineState
709
+ ) -> PipelineState:
710
+ block_state = self.get_block_state(state)
711
+ device = components._execution_device
712
+ dtype = components.transformer.dtype
713
+
714
+ # Create or reuse KV cache
715
+ if block_state.kv_cache is None:
716
+ block_state.kv_cache = StaticKVCache(
717
+ components.transformer.config,
718
+ batch_size=1,
719
+ dtype=dtype,
720
+ ).to(device)
721
+ elif block_state.reset_cache:
722
+ block_state.kv_cache.reset()
723
+
724
+ self.set_block_state(state, block_state)
725
+ return components, state
726
+
727
+
728
+ # ---------------------------------------------------------------------------
729
+ # Pipeline step: Prepare Latents
730
+ # ---------------------------------------------------------------------------
731
+
732
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image.

    When an image is provided (first frame), it is preprocessed, VAE-encoded,
    and written into the KV cache as context via a cache pass. Fresh noise is
    then created for the actual denoising, unless ``use_random_latents`` is
    False and latents were supplied by the caller.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        # Defaults only — the repository's modular_model_index.json may
        # override these (e.g. channels/width differ there).
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=PIL.Image.Image | torch.Tensor,
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
            InputParam(
                "ts_mult",
                required=True,
                type_hint=int,
                description="Timestamp multiplier (base_fps // latent_fps)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        frame_idx,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache.

        Runs the transformer at sigma=0 (clean frame) with the cache unfrozen
        so the frame's k/v get written as context for later frames.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            frame_idx=frame_idx,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode the optional seed image into the cache and prepare noise."""
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Latent geometry: (height, width) is the patch grid; multiplying by
        # the patch size yields the latent resolution.
        channels = components.config.channels
        height = components.config.height  # patch grid height
        width = components.config.width  # patch grid width
        patch = components.config.patch
        vae_scale_factor = components.config.vae_scale_factor

        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        latent_H = height * pH
        latent_W = width * pW
        shape = (1, 1, channels, latent_H, latent_W)

        # Pixel dimensions for image preprocessing
        pixel_H = latent_H * vae_scale_factor
        pixel_W = latent_W * vae_scale_factor

        if block_state.image is not None:
            image = block_state.image
            # Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
            image = components.image_processor.preprocess(
                image,
                height=pixel_H,
                width=pixel_W,
            )
            # Convert to [H, W, 3] uint8 for VAE encoder
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)

            assert image.dtype == torch.uint8, (
                f"Expected uint8 image, got {image.dtype}"
            )

            # Temporal VAE expects [T, H, W, C]; repeat the single frame
            t_down = getattr(components.vae, "t_downscale", 1)
            if t_down > 1:
                image = image.unsqueeze(0).expand(t_down, -1, -1, -1)

            # NOTE(review): `image` is not explicitly moved to `device` here;
            # presumably components.vae.encode handles placement — confirm.
            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)

            # Run cache pass to persist encoded frame
            ts_mult = block_state.ts_mult
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp * ts_mult,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            # Advance the timestamp in place (it is a tensor shared via state).
            block_state.frame_timestamp.add_(1)

        # Generate latents based on use_random_latents flag
        if block_state.use_random_latents or block_state.latents is None:
            # BUGFIX: the declared `generator` input was previously ignored,
            # so seeding had no effect on the sampled noise.
            generator = block_state.generator
            if (
                generator is not None
                and generator.device.type != torch.device(device).type
            ):
                # Sample on the generator's own device (e.g. a CPU generator
                # with a CUDA pipeline), then move — mirrors diffusers'
                # randn_tensor semantics and keeps runs reproducible.
                block_state.latents = torch.randn(
                    shape,
                    generator=generator,
                    device=generator.device,
                    dtype=torch.bfloat16,
                ).to(device)
            else:
                block_state.latents = torch.randn(
                    shape, generator=generator, device=device, dtype=torch.bfloat16
                )

        self.set_block_state(state, block_state)
        return components, state
948
+
949
+
950
+ # ---------------------------------------------------------------------------
951
+ # Pipeline step: Before Denoise (sequential wrapper)
952
+ # ---------------------------------------------------------------------------
953
+
954
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    # Sub-blocks run in this order; `block_names` must stay index-aligned
    # with `block_classes`.
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        # Human-readable summary surfaced by ModularPipeline introspection.
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
972
+
973
+
974
+ # ---------------------------------------------------------------------------
975
+ # Pipeline step: Denoise Loop
976
+ # ---------------------------------------------------------------------------
977
+
978
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Denoises latents using rectified flow and updates KV cache."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("scheduler_sigmas", required=True, type_hint=torch.Tensor, description="Scheduler sigmas for denoising"),
            InputParam("latents", required=True, type_hint=torch.Tensor, description="Initial noisy latents [1, 1, C, H, W]"),
            InputParam("kv_cache", required=True, description="KV cache for transformer attention"),
            InputParam("frame_timestamp", required=True, type_hint=torch.Tensor, description="Current frame timestamp"),
            InputParam("prompt_embeds", required=True, type_hint=torch.Tensor, description="Text embeddings for conditioning"),
            InputParam("prompt_pad_mask", type_hint=torch.Tensor, description="Padding mask for prompt embeddings"),
            InputParam("button_tensor", required=True, type_hint=torch.Tensor, description="One-hot encoded button tensor"),
            InputParam("mouse_tensor", required=True, type_hint=torch.Tensor, description="Mouse velocity tensor"),
            InputParam("scroll_tensor", required=True, type_hint=torch.Tensor, description="Scroll wheel sign tensor"),
            InputParam("ts_mult", required=True, type_hint=int, description="Timestamp multiplier (base_fps // latent_fps)"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
        ]

    @staticmethod
    def _denoise_pass(
        transformer, x, sigmas, frame_timestamp, frame_idx,
        prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
    ):
        """Denoising loop using rectified flow.

        The cache is frozen so denoising steps read context without writing
        it; `zip(sigmas, sigmas.diff())` pairs each sigma with its step size
        (one fewer step than sigmas, as intended for an Euler walk).
        """
        kv_cache.set_frozen(True)
        # Per-(batch, frame) sigma buffer, refilled in place each step.
        sigma = x.new_empty((x.size(0), x.size(1)))
        for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
            v = transformer(
                x=x, sigma=sigma.fill_(step_sig),
                frame_timestamp=frame_timestamp, frame_idx=frame_idx,
                prompt_emb=prompt_emb, prompt_pad_mask=prompt_pad_mask,
                mouse=mouse, button=button, scroll=scroll,
                kv_cache=kv_cache,
            )
            # Rectified-flow Euler update; dsigma is negative on a
            # decreasing schedule, stepping x toward the data.
            x = x + step_dsig * v
        return x

    @staticmethod
    def _cache_pass(
        transformer, x, frame_timestamp, frame_idx,
        prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
    ):
        """Cache pass to persist frame for next generation."""
        # Unfreeze so this sigma=0 (clean-frame) forward writes k/v context.
        kv_cache.set_frozen(False)
        transformer(
            x=x, sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp, frame_idx=frame_idx,
            prompt_emb=prompt_emb, prompt_pad_mask=prompt_pad_mask,
            mouse=mouse, button=button, scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Denoise the current frame, persist it in the cache, advance time.

        NOTE(review): like the other blocks here, this returns
        ``(components, state)`` despite the ``-> PipelineState`` annotation.
        """
        block_state = self.get_block_state(state)
        ts_mult = block_state.ts_mult
        # NOTE(review): .clone() presumably decouples the result from
        # compiled/static buffers reused by the transformer — confirm.
        block_state.latents = self._denoise_pass(
            components.transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            block_state.frame_timestamp * ts_mult,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        ).clone()

        # Persist the denoised frame so subsequent frames attend to it.
        self._cache_pass(
            components.transformer,
            block_state.latents,
            block_state.frame_timestamp * ts_mult,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )
        # Advance the shared frame counter in place.
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        return components, state
1085
+
1086
+
1087
+ # ---------------------------------------------------------------------------
1088
+ # Pipeline step: Decode
1089
+ # ---------------------------------------------------------------------------
1090
+
1091
class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to RGB image using VAE."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # VAE for decoding plus a config-instantiated image processor.
        processor_config = FrozenDict(
            {
                "vae_scale_factor": 16,
                "do_normalize": False,
                "do_convert_rgb": True,
            }
        )
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=processor_config,
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised latent tensor [1, 1, C, H, W]"),
            InputParam("output_type", default="pil", description="The output format for the generated images (pil, latent, pt, or np)"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "images",
                type_hint=PIL.Image.Image | torch.Tensor | np.ndarray,
                description="Decoded RGB image in requested output format",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Decode the current latents into the requested output format."""
        block_state = self.get_block_state(state)
        current_latents = block_state.latents
        requested = block_state.output_type or "pil"

        if requested == "latent":
            # Caller wants raw latents; skip VAE decoding entirely.
            block_state.images = current_latents
        else:
            decoded = components.vae.decode(current_latents.squeeze(1))

            if decoded.dim() == 3:
                # A lone [H, W, C] frame is lifted to [1, H, W, C] so every
                # branch below can treat `decoded` as a batch of frames.
                decoded = decoded.unsqueeze(0)

            if requested == "pt":
                block_state.images = decoded
            elif requested == "np":
                block_state.images = decoded.cpu().numpy()
            else:  # "pil"
                block_state.images = [
                    PIL.Image.fromarray(frame.cpu().numpy()) for frame in decoded
                ]

        # Drop the consumed latents so the next frame samples fresh noise.
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state
1167
+
1168
+
1169
+ # ---------------------------------------------------------------------------
1170
+ # Top-level block registry
1171
+ # ---------------------------------------------------------------------------
1172
+
1173
# Ordered registry of the pipeline's steps; insertion order defines the
# execution order used by WorldEngineBlocks.
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WorldEngineTextEncoderStep),
        ("controller_encoder", WorldEngineControllerEncoderStep),
        ("before_denoise", WorldEngineBeforeDenoiseStep),
        ("denoise", WorldEngineDenoiseLoop),
        ("decode", WorldEngineDecodeStep),
    ]
)
1182
+
1183
+
1184
class WorldEngineBlocks(SequentialPipelineBlocks):
    """Sequential pipeline blocks for WorldEngine frame generation."""

    # Both lists are derived from the single AUTO_BLOCKS registry so names
    # and classes cannot drift out of sync.
    block_classes = list(AUTO_BLOCKS.values())
    block_names = list(AUTO_BLOCKS.keys())
modular_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "WorldEngineBlocks",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "auto_map": {
5
+ "ModularPipelineBlocks": "modular_blocks.WorldEngineBlocks"
6
+ }
7
+ }
modular_model_index.json ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_blocks_class_name": "WorldEngineBlocks",
3
+ "_class_name": "ModularPipeline",
4
+ "_diffusers_version": "0.36.0.dev0",
5
+ "channels": 32,
6
+ "height": 16,
7
+ "width": 32,
8
+ "patch": [
9
+ 2,
10
+ 2
11
+ ],
12
+ "vae_scale_factor": 16,
13
+ "n_buttons": 256,
14
+ "tokens_per_frame": 512,
15
+ "scheduler_sigmas": [
16
+ 1.0,
17
+ 0.9,
18
+ 0.75,
19
+ 0.3,
20
+ 0.0
21
+ ],
22
+ "transformer": [
23
+ null,
24
+ null,
25
+ {
26
+ "pretrained_model_name_or_path": "Overworld/Waypoint-1.5-1B",
27
+ "subfolder": "transformer",
28
+ "type_hint": [
29
+ "diffusers",
30
+ "AutoModel"
31
+ ],
32
+ "revision": null,
33
+ "variant": null
34
+ }
35
+ ],
36
+ "vae": [
37
+ null,
38
+ null,
39
+ {
40
+ "pretrained_model_name_or_path": "Overworld/Waypoint-1.5-1B",
41
+ "subfolder": "vae",
42
+ "type_hint": [
43
+ "diffusers",
44
+ "AutoModel"
45
+ ],
46
+ "revision": null,
47
+ "variant": null
48
+ }
49
+ ],
50
+ "text_encoder": [
51
+ null,
52
+ null,
53
+ {
54
+ "pretrained_model_name_or_path": "google/umt5-xl",
55
+ "type_hint": [
56
+ "transformers",
57
+ "UMT5EncoderModel"
58
+ ],
59
+ "revision": null,
60
+ "variant": null
61
+ }
62
+ ],
63
+ "tokenizer": [
64
+ null,
65
+ null,
66
+ {
67
+ "pretrained_model_name_or_path": "google/umt5-xl",
68
+ "type_hint": [
69
+ "transformers",
70
+ "AutoTokenizer"
71
+ ],
72
+ "revision": null,
73
+ "variant": null
74
+ }
75
+ ]
76
+ }
transformer/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from .model import WorldModel
transformer/config.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "WorldModel",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "model.WorldModel"
6
+ },
7
+ "d_model": 2048,
8
+ "n_heads": 32,
9
+ "n_kv_heads": 16,
10
+ "n_layers": 24,
11
+ "mlp_ratio": 4,
12
+ "channels": 32,
13
+ "height": 16,
14
+ "width": 32,
15
+ "patch": [
16
+ 2,
17
+ 2
18
+ ],
19
+ "tokens_per_frame": 512,
20
+ "n_frames": 512,
21
+ "local_window": 16,
22
+ "global_window": 128,
23
+ "global_attn_period": 4,
24
+ "global_pinned_dilation": 8,
25
+ "global_attn_offset": -1,
26
+ "value_residual": true,
27
+ "gated_attn": false,
28
+ "n_buttons": 256,
29
+ "ctrl_conditioning": true,
30
+ "ctrl_conditioning_period": 3,
31
+ "ctrl_cond_dropout": 0.0,
32
+ "prompt_conditioning": null,
33
+ "prompt_conditioning_period": 3,
34
+ "prompt_embedding_dim": 2048,
35
+ "prompt_cond_dropout": 0.0,
36
+ "noise_conditioning": "wan",
37
+ "base_fps": 15,
38
+ "causal": true,
39
+ "mlp_gradient_checkpointing": true,
40
+ "block_gradient_checkpointing": true,
41
+ "rope_impl": "ortho",
42
+ "moe": false,
43
+ "moe_top_k": 2,
44
+ "moe_n_experts": 8,
45
+ "moe_mlp_ratio": null,
46
+ "gated_linear": false,
47
+ "temporal_compression": 4,
48
+ "inference_fps": 60,
49
+ "taehv_ae": true,
50
+ "rope_nyquist_frac": 0.8,
51
+ "rope_theta": 10000.0,
52
+ "scheduler_sigmas": [
53
+ 1.0,
54
+ 0.9,
55
+ 0.75,
56
+ 0.3,
57
+ 0.0
58
+ ]
59
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:467ea80d76f63a82a2f5ba80d1170c27dc1060e60c3ede341105bff89b468700
3
+ size 7443340184
transformer/model.py ADDED
@@ -0,0 +1,1100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ """WorldModel transformer for frame generation.
17
+
18
+ Single-file model containing all building blocks: nn primitives, attention,
19
+ RoPE, quantization, inference caching, and the top-level WorldModel.
20
+ """
21
+
22
+ import warnings
23
+
24
+ import einops as eo
25
+ import torch
26
+ from torch import nn, Tensor
27
+ import torch.nn.functional as F
28
+ from tensordict import TensorDict
29
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
30
+ from diffusers.models.modeling_utils import ModelMixin
31
+
32
+ try:
33
+ from fbgemm_gpu.experimental.gen_ai.moe import index_shuffling
34
+ import fbgemm_gpu.experimental.gen_ai.moe.gather_scatter # noqa
35
+ HAS_FBGEMM = True
36
+ except ImportError:
37
+ HAS_FBGEMM = False
38
+
39
+
40
+ # ---------------------------------------------------------------------------
41
+ # NN primitives
42
+ # ---------------------------------------------------------------------------
43
+
44
class NoCastModule(torch.nn.Module):
    """Module that prevents dtype casting during .to() calls.

    Device moves are honored; any attempted dtype change is dropped with a
    warning, so submodules with deliberately fixed precision (e.g. fp32
    buffers) keep their dtypes under `.to(...)` / `.half()` / autocast-style
    conversions routed through `_apply`.
    """

    def _apply(self, fn):
        # Wrap `fn` (from .to/.cuda/.float/...) so it may move devices but
        # never change a parameter's or buffer's dtype.
        def keep_dtype(t):
            old_dtype = t.dtype
            out = fn(t)
            if out.dtype is not old_dtype:
                warnings.warn(
                    f"{self.__class__.__name__}: requested dtype cast ignored; "
                    f"keeping {old_dtype}.",
                    stacklevel=3,
                )
                out = out.to(dtype=old_dtype)
            return out

        return super()._apply(keep_dtype)

    def to(self, *args, **kwargs):
        """Like ``nn.Module.to`` but with all dtype arguments stripped."""
        warn_cast = False

        if args and isinstance(args[0], torch.Tensor):
            # `.to(tensor)` means "match this tensor's device and dtype";
            # keep only the device part.
            ref, *rest = args
            args = (ref.device, *rest)
            # BUGFIX: this used `next(...) or next(...)`, which evaluates the
            # truth value of a Parameter and raises a RuntimeError for any
            # multi-element tensor. Fall back to the first buffer explicitly.
            base = next(self.parameters(), None)
            if base is None:
                base = next(self.buffers(), None)
            if base is not None and ref.dtype is not base.dtype:
                warn_cast = True

        # Pop the keyword dtype (if any) so it never reaches nn.Module.to.
        if kwargs.pop("dtype", None) is not None:
            warn_cast = True

        # Drop positional dtype arguments (e.g. `.to(torch.float16)`).
        args = tuple(a for a in args if not isinstance(a, torch.dtype))

        if warn_cast:
            warnings.warn(
                f"{self.__class__.__name__}.to: requested dtype cast ignored; "
                "keeping existing dtypes.",
                stacklevel=2,
            )

        return super().to(*args, **kwargs)
85
+
86
+
87
+ def rms_norm(x: torch.Tensor) -> torch.Tensor:
88
+ """Root mean square layer normalization."""
89
+ return F.rms_norm(x, (x.size(-1),))
90
+
91
+
92
class MLP(nn.Module):
    """Two-layer feed-forward network with a SiLU nonlinearity in between."""

    def __init__(self, dim_in, dim_middle, dim_out):
        super().__init__()
        # Attribute names fc1/fc2 are part of the state-dict layout; keep them.
        self.fc1 = nn.Linear(dim_in, dim_middle, bias=False)
        self.fc2 = nn.Linear(dim_middle, dim_out, bias=False)

    def forward(self, x):
        hidden = F.silu(self.fc1(x))
        return self.fc2(hidden)
102
+
103
+
104
class AdaLN(nn.Module):
    """Adaptive layer norm: per-frame scale/shift derived from a condition.

    ``cond`` carries one vector per frame; each frame's (scale, shift) pair
    is broadcast across that frame's tokens in ``x``.
    """

    def __init__(self, dim):
        super().__init__()
        # Projects the conditioning vector to concatenated (scale, shift).
        self.fc = nn.Linear(dim, 2 * dim, bias=False)

    def forward(self, x, cond):
        batch, n_frames, dim = cond.shape
        tokens_per_frame = x.size(1) // n_frames

        mod = self.fc(F.silu(cond))                 # [b, n, 2d]
        mod = mod.view(batch, n_frames, 1, 2 * dim) # add a token axis
        mod = mod.expand(-1, -1, tokens_per_frame, -1)
        mod = mod.reshape(batch, n_frames * tokens_per_frame, 2 * dim)

        scale, shift = mod.chunk(2, dim=-1)         # [b, nm, d] each
        return rms_norm(x) * (1 + scale) + shift
125
+
126
+
127
def ada_rmsnorm(x, scale, bias):
    """Adaptive RMS normalization: per-frame (scale, bias) over flat tokens.

    ``x`` is [b, n*m, d]; ``scale``/``bias`` are [b, n, d] with one row per
    frame, broadcast over each frame's m tokens.
    """
    b, nm, d = x.shape
    n = scale.size(1)
    frames = x.reshape(b, n, nm // n, d)
    modulated = rms_norm(frames) * (1 + scale.unsqueeze(2)) + bias.unsqueeze(2)
    return modulated.reshape(b, nm, d)
132
+
133
+
134
def ada_gate(x, gate):
    """Scale each frame's tokens by that frame's gate vector.

    ``x`` is [b, n*m, d]; ``gate`` is [b, n, d], broadcast over the m tokens
    belonging to each of the n frames.
    """
    b, nm, d = x.shape
    n = gate.size(1)
    gated = x.reshape(b, n, nm // n, d) * gate.unsqueeze(2)
    return gated.reshape(b, nm, d)
138
+
139
+
140
class NoiseConditioner(NoCastModule):
    """Sigma -> logSNR -> Fourier Features -> Dense embedding."""

    def __init__(self, dim, fourier_dim=512, base=10_000.0):
        super().__init__()
        assert fourier_dim % 2 == 0
        half = fourier_dim // 2
        # Log-spaced frequencies from base^0 down to base^-1, kept in fp32
        # and excluded from the state dict (persistent=False).
        self.freq = nn.Buffer(
            torch.logspace(0, -1, steps=half, base=base, dtype=torch.float32),
            persistent=False,
        )
        self.mlp = MLP(fourier_dim, dim * 4, dim)

    def forward(self, s, eps=torch.finfo(torch.float32).eps):
        # NOTE(review): `eps` is accepted but unused in this body — confirm
        # whether it was meant to guard a log somewhere.
        assert self.freq.dtype == torch.float32
        orig_dtype, shape = s.dtype, s.shape

        # Compute embeddings in fp32 regardless of autocast to preserve
        # sin/cos precision at high frequencies.
        with torch.autocast("cuda", enabled=False):
            s = s.reshape(-1).float()
            s = s * 1000  # scale sigma into the Fourier features' working range

            phase = s[:, None] * self.freq[None, :]
            emb = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)
            emb = emb * 2**0.5  # unit-variance correction for sin/cos features
            emb = self.mlp(emb)

        # Restore the caller's dtype and leading shape, adding a feature dim.
        return emb.to(orig_dtype).view(*shape, -1)
167
+
168
+
169
+ # ---------------------------------------------------------------------------
170
+ # Attention
171
+ # ---------------------------------------------------------------------------
172
+
173
class OrthoRoPEAngles(NoCastModule):
    """Computes RoPE angles on the fly each forward pass."""

    def __init__(self, config):
        super().__init__()
        self.config = config

        d_head = config.d_model // config.n_heads
        torch._assert(d_head % 8 == 0, "d_head must be divisible by 8")
        # Split the head dim between spatial (x and y, d_head/8 each) and
        # temporal (d_head/4) rotation frequencies.
        d_xy, d_t = d_head // 8, d_head // 4

        # Spatial frequencies capped at a fraction of the grid Nyquist rate.
        nyq = float(getattr(config, "rope_nyquist_frac", 0.8))
        max_freq = min(self.config.height, self.config.width) * nyq
        n = (d_xy + 1) // 2
        # repeat_interleave(2) pairs each frequency for (cos, sin) channels.
        xy = (torch.linspace(1.0, max_freq / 2, n, dtype=torch.float32) * torch.pi).repeat_interleave(2)[:d_xy]

        # Classic RoPE inverse-frequency schedule for the time axis.
        theta = float(getattr(config, "rope_theta", 10000.0))
        inv_t = 1.0 / (theta ** (torch.arange(0, d_t, 2, dtype=torch.float32) / d_t))
        inv_t = inv_t.repeat_interleave(2)

        # Non-persistent: recomputed from config, never serialized.
        self.register_buffer("xy", xy, persistent=False)
        self.register_buffer("inv_t", inv_t, persistent=False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, pos_ids):
        """Return (cos, sin) angle tensors for the given position ids.

        ``pos_ids`` is a mapping providing integer tensors under the keys
        "x_pos", "y_pos" and "t_pos".
        """
        # Data-dependent bounds check is skipped under torch.compile.
        if not torch.compiler.is_compiling():
            torch._assert(
                (pos_ids["y_pos"].max() < self.config.height) & (pos_ids["x_pos"].max() < self.config.width),
                f"pos_ids out of bounds, {self.config.height}, {self.config.width}"
            )

        # Map grid coordinates to cell centers in [-1, 1].
        x = (2.0 * pos_ids["x_pos"].float() + 1.0) / self.config.width - 1.0
        y = (2.0 * pos_ids["y_pos"].float() + 1.0) / self.config.height - 1.0
        t = pos_ids["t_pos"].float()

        freqs = torch.cat(
            (x.unsqueeze(-1) * self.xy, y.unsqueeze(-1) * self.xy, t.unsqueeze(-1) * self.inv_t),
            dim=-1,
        )
        # Insert a singleton head axis so angles broadcast across heads.
        return freqs.cos()[:, None], freqs.sin()[:, None]
213
+
214
+
215
class OrthoRoPE(NoCastModule):
    """Applies precomputed RoPE angles to input tensors."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Audio streams are not supported by this rotation layout.
        assert not getattr(self.config, "has_audio", False)

    @torch.autocast("cuda", enabled=False)
    def forward(self, x, rope_angles):
        """Rotate interleaved channel pairs of ``x`` by (cos, sin) angles.

        Pairs are read as interleaved (even, odd) channels via unfold, and
        the rotated halves are concatenated — so the output channel order
        differs from the input interleaving, but is applied consistently to
        both q and k, which is all dot-product attention requires.
        """
        cos, sin = rope_angles
        # Rotate in fp32, then cast back to the input dtype.
        x0, x1 = x.float().unfold(-1, 2, 2).unbind(-1)
        y0 = x0 * cos - x1 * sin
        y1 = x1 * cos + x0 * sin
        return torch.cat((y0, y1), dim=-1).type_as(x)
230
+
231
+
232
class Attn(nn.Module):
    """Self-attention with RoPE and optional GQA, value residual, and gated attention."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        # Used by the KV cache to address this layer's slot.
        self.layer_idx = layer_idx

        # Optional value residual: blend this layer's values with the first
        # layer's values (v1) using a learned coefficient.
        self.value_residual = getattr(config, "value_residual", False)
        if self.value_residual:
            self.v_lamb = nn.Parameter(torch.tensor(0.5))

        self.n_heads = config.n_heads
        # Fewer KV heads than query heads => grouped-query attention.
        self.n_kv_heads = getattr(config, "n_kv_heads", None) or config.n_heads
        self.d_head = config.d_model // self.n_heads
        assert config.d_model % self.n_heads == 0

        self.enable_gqa = self.n_heads != self.n_kv_heads

        self.q_proj = nn.Linear(config.d_model, self.n_heads * self.d_head, bias=False)
        self.k_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.v_proj = nn.Linear(
            config.d_model, self.n_kv_heads * self.d_head, bias=False
        )
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.rope = OrthoRoPE(config)

        # Optional per-head output gating, zero-initialized (gates start at
        # sigmoid(0) = 0.5 uniformly).
        self.gated_attn = getattr(config, "gated_attn", False)
        if self.gated_attn:
            self.gate_proj = nn.Linear(
                self.n_heads, self.n_heads, bias=False
            )
            nn.init.zeros_(self.gate_proj.weight)

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Attend over cached context plus the current tokens.

        Args:
            x: [batch, tokens, d_model] hidden states.
            pos_ids: position-id mapping consumed by the cache and RoPE.
            rope_angles: (cos, sin) tensors from OrthoRoPEAngles.
            v1: first-layer value tensor for the value residual, or None.
            kv_cache: object exposing ``upsert(k, v, pos_ids, layer_idx)``
                returning the full (k, v) plus a flex-attention block mask.

        Returns:
            (output, v1): attention output and the captured first-layer values.
        """
        # Imported lazily so importing this module never requires flex
        # attention support.
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(
            self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads, d=self.d_head
        )
        k = eo.rearrange(
            self.k_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )
        v = eo.rearrange(
            self.v_proj(x), "b t (h d) -> b h t d", h=self.n_kv_heads, d=self.d_head
        )

        if self.value_residual:
            # First call captures v1; later calls blend toward it by v_lamb.
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        # QK norm before rotation.
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        # Cache upsert returns context + current k/v and the block mask.
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            # Per-head sigmoid gates computed from the first n_heads feature
            # channels of x.
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)
        y = eo.rearrange(y, "b h t d -> b t (h d)")
        y = self.out_proj(y)
        return y, v1
298
+
299
+
300
class MergedQKVAttn(Attn):
    """Attn variant with the q/k/v projections fused into a single matmul.

    Built from an existing ``Attn`` (inference-time optimization): copies its
    weights, concatenates q_proj/k_proj/v_proj into one ``qkv_proj``, then
    deletes the separate projections. Forward math matches ``Attn``.
    """

    def __init__(self, src: Attn, config):
        super().__init__(config, src.layer_idx)
        # Match the source module's placement before copying weights.
        self.to(device=src.q_proj.weight.device, dtype=src.q_proj.weight.dtype)
        # strict=False: the fused qkv_proj has no counterpart in `src`.
        self.load_state_dict(
            src.state_dict(), strict=False
        )
        self.train(src.training)

        self.q_out = self.n_heads * self.d_head
        self.kv_out = self.n_kv_heads * self.d_head

        self.qkv_proj = nn.Linear(
            self.q_proj.in_features,
            self.q_out + 2 * self.kv_out,
            bias=False,
            device=self.q_proj.weight.device,
            dtype=self.q_proj.weight.dtype,
        )
        with torch.no_grad():
            # Row-concatenate in (q, k, v) order; forward splits in the same
            # order below.
            self.qkv_proj.weight.copy_(
                torch.cat(
                    [self.q_proj.weight, self.k_proj.weight, self.v_proj.weight], dim=0
                )
            )

        # The separate projections are redundant once fused.
        del self.q_proj, self.k_proj, self.v_proj

    def forward(self, x, pos_ids, rope_angles, v1, kv_cache):
        """Same contract as ``Attn.forward``, using the fused projection."""
        # Lazy import keeps flex attention optional at module import time.
        from torch.nn.attention.flex_attention import flex_attention

        q, k, v = self.qkv_proj(x).split((self.q_out, self.kv_out, self.kv_out), dim=-1)

        B, T = x.shape[:2]
        q = q.reshape(B, T, self.n_heads, self.d_head).transpose(1, 2)
        k = k.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)
        v = v.reshape(B, T, self.n_kv_heads, self.d_head).transpose(1, 2)

        if self.value_residual:
            # First call captures v1; later calls blend toward it by v_lamb.
            v1 = v if v1 is None else v1
            v = torch.lerp(v, v1.view_as(v), self.v_lamb)

        # QK norm before rotation, mirroring Attn.forward.
        q, k = rms_norm(q), rms_norm(k)
        q, k = self.rope(q, rope_angles), self.rope(k, rope_angles)

        # Cache upsert returns context + current k/v and the block mask.
        k, v, bm = kv_cache.upsert(k, v, pos_ids, self.layer_idx)
        y = flex_attention(q, k, v, block_mask=bm, enable_gqa=self.enable_gqa)

        if self.gated_attn:
            # Per-head sigmoid gates from the first n_heads feature channels
            # of x (matches Attn.forward).
            gates = torch.sigmoid(self.gate_proj(x[..., : self.n_heads]))
            y = y * gates.permute(0, 2, 1).unsqueeze(-1)

        y = y.transpose(1, 2).reshape(B, T, -1)
        y = self.out_proj(y)
        return y, v1
355
+
356
+
357
class CrossAttention(nn.Module):
    """Cross-attention for prompt conditioning.

    Queries come from the token stream; keys/values come from ``context``
    (e.g. a prompt embedding of dimension ``context_dim``). The output
    projection is zero-initialized so the layer starts as an identity
    contribution inside a residual branch.
    """

    def __init__(self, config, context_dim=None):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        self.d_head = config.d_model // config.n_heads
        # Inner width follows the context dimension when provided.
        self.inner_dim = context_dim or config.d_model
        assert self.inner_dim % self.d_head == 0
        self.n_heads = self.inner_dim // self.d_head
        self.q_proj = nn.Linear(config.d_model, self.inner_dim, bias=False)
        self.k_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )
        self.v_proj = nn.Linear(
            context_dim or config.d_model, self.inner_dim, bias=False
        )

        self.out_proj = nn.Linear(self.inner_dim, config.d_model, bias=False)
        # Zero-init: this block contributes nothing until trained.
        self.out_proj.weight.detach().zero_()

    def forward(self, x, context, context_pad_mask=None):
        # NOTE(review): `context_pad_mask` is accepted but never applied —
        # padded context tokens are attended to like real ones. Confirm this
        # is intentional (e.g. prompts are always unpadded at inference).
        from torch.nn.attention.flex_attention import flex_attention

        q = eo.rearrange(self.q_proj(x), "b t (h d) -> b h t d", h=self.n_heads)
        k = eo.rearrange(self.k_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        v = eo.rearrange(self.v_proj(context), "b t (h d) -> b h t d", h=self.n_heads)
        # QK-norm before attention.
        q, k = rms_norm(q), rms_norm(k)
        out = flex_attention(q, k, v)
        out = out.transpose(1, 2).contiguous().reshape(x.size(0), x.size(1), -1)
        return self.out_proj(out)
389
+
390
+
391
+ # ---------------------------------------------------------------------------
392
+ # Inference caching
393
+ # ---------------------------------------------------------------------------
394
+
395
+ def _bf16_u16(x: Tensor) -> Tensor:
396
+ return x.contiguous().view(torch.int16).to(torch.int32) & 0xFFFF
397
+
398
+
399
class CachedDenoiseStepEmb(nn.Module):
    """bf16 sigma -> bf16 embedding via 64k LUT.

    Precomputes ``base``'s embedding for every scheduler sigma once, then
    serves lookups keyed on the sigma's raw bf16 bit pattern. An unknown
    sigma maps to an out-of-range row and fails loudly on indexing.
    """

    def __init__(self, base: nn.Module, sigmas: list[float]):
        super().__init__()
        device = next(base.parameters()).device

        levels = torch.tensor(sigmas, dtype=torch.bfloat16, device=device)
        codes = _bf16_u16(levels)
        n_levels = codes.numel()
        # Two sigmas rounding to the same bf16 pattern would alias in the LUT.
        if torch.unique(codes).numel() != n_levels:
            raise ValueError(
                "scheduler_sigmas collide in bf16; caching would be ambiguous"
            )

        # Run the real embedding module once per sigma level, up front.
        with torch.no_grad():
            embedded = base(levels[:, None]).squeeze(1)
            table = embedded.to(torch.bfloat16).contiguous()

        # Dense 2^16 map: bf16 bit pattern -> row of `table` (-1 = unknown).
        lut = torch.full((65536,), -1, dtype=torch.int32, device=device)
        lut[codes] = torch.arange(n_levels, dtype=torch.int32, device=device)

        self.register_buffer("table", table, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # Sentinel one past the last row: indexing with it raises, surfacing
        # any sigma that was never cached.
        self.register_buffer(
            "oob",
            torch.tensor(n_levels, dtype=torch.int32, device=device),
            persistent=False,
        )

    def forward(self, sigma: Tensor) -> Tensor:
        if sigma.dtype is not torch.bfloat16:
            raise RuntimeError("CachedDenoiseStepEmb expects sigma bf16")
        row = self.lut[_bf16_u16(sigma)]
        row = torch.where(row >= 0, row, self.oob)
        return self.table[row.to(torch.int64)]
435
+
436
+
437
class CachedCondHead(nn.Module):
    """bf16 cond -> cached conditioning; invalid cond => OOB index error.

    Precomputes ``base``'s outputs for every cached sigma embedding. At
    inference, the incoming ``cond`` is mapped back to its sigma row by
    matching the bf16 bit pattern of a single "key" dimension that is unique
    across all cached embeddings.
    """

    def __init__(
        self, base, cached_denoise_step_emb: CachedDenoiseStepEmb, max_key_dims: int = 8
    ):
        super().__init__()
        table = cached_denoise_step_emb.table
        S, D = table.shape  # S sigma levels, D embedding channels

        # Evaluate the real head once for each cached embedding; `base`
        # returns a tuple of tensors, stacked here along a new leading dim.
        with torch.no_grad():
            emb = table[:, None, :]
            cache = (
                torch.stack([t.squeeze(1) for t in base(emb)], 0)
                .to(torch.bfloat16)
                .contiguous()
            )

        # Find the first embedding channel whose bf16 bit patterns are
        # distinct for all S sigma levels — it becomes the lookup key.
        key_dim = None
        for d in range(min(D, max_key_dims)):
            b = _bf16_u16(table[:, d])
            if torch.unique(b).numel() == S:
                key_dim = d
                key_bits = b
                break
        if key_dim is None:
            raise ValueError(
                "Could not find a unique bf16 key dim for cond->sigma mapping"
            )

        # Dense 2^16 map: key-channel bit pattern -> sigma row (-1 = unknown).
        lut = torch.full((65536,), -1, device=table.device, dtype=torch.int32)
        lut[key_bits] = torch.arange(S, device=table.device, dtype=torch.int32)

        self.key_dim = int(key_dim)
        self.register_buffer("cache", cache, persistent=False)
        self.register_buffer("lut", lut, persistent=False)
        # One-past-the-end sentinel: unknown cond values raise on indexing.
        self.register_buffer(
            "oob",
            torch.tensor(S, device=table.device, dtype=torch.int32),
            persistent=False,
        )

    def forward(self, cond: Tensor):
        if cond.dtype is not torch.bfloat16:
            raise RuntimeError("CachedCondHead expects cond bf16")
        idx = self.lut[_bf16_u16(cond[..., self.key_dim])]
        idx = torch.where(idx >= 0, idx, self.oob)
        g = self.cache[:, idx.to(torch.int64)]
        # Unbind back into the tuple shape the original head produced.
        return tuple(g.unbind(0))
486
+
487
+
488
+ # ---------------------------------------------------------------------------
489
+ # Quantization
490
+ # ---------------------------------------------------------------------------
491
+
492
# Quantization modes available in this environment; None = no quantization.
QUANTS = [None]

try:
    # Optional dependency: FlashInfer provides the NVFP4 quantize/GEMM kernels.
    from flashinfer import nvfp4_quantize, mm_fp4, SfLayout
    QUANTS.append("nvfp4")
except ImportError:
    pass
499
+
500
+
501
@torch.library.custom_op("world_engine::fp4_linear", mutates_args=())
def fp4_linear(
    a_bf16: torch.Tensor,
    b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor,
    b_sf_T: torch.Tensor,
    alpha: torch.Tensor,
) -> torch.Tensor:
    """NVFP4 linear: quantize bf16 activations on the fly, then FP4 GEMM.

    ``b_fp4_T``/``b_sf_T`` are the pre-quantized (transposed) weight and its
    scale factors; ``alpha`` rescales the accumulator back to real units.
    Registered as a custom op so torch.compile treats it as opaque.
    """
    a_fp4, a_sf = nvfp4_quantize(
        a_bf16, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
    )
    return mm_fp4(
        a_fp4, b_fp4_T, a_sf, b_sf_T, alpha, out_dtype=torch.bfloat16, backend="cutlass"
    )
515
+
516
+
517
@fp4_linear.register_fake
def _fp4_linear_fake(
    a_bf16: torch.Tensor, b_fp4_T: torch.Tensor,
    a_global_sf: torch.Tensor, b_sf_T: torch.Tensor, alpha: torch.Tensor,
) -> torch.Tensor:
    # Fake (meta) kernel: gives torch.compile the output shape/dtype of
    # fp4_linear without executing the GEMM.
    return torch.empty(
        (a_bf16.shape[0], b_fp4_T.shape[1]), device=a_bf16.device, dtype=torch.bfloat16
    )
525
+
526
+
527
class FP4Linear(nn.Module):
    """FP4 Linear layer using FlashInfer's NVFP4 quantization.

    The weight is quantized once at construction; activations are quantized
    per call inside the ``fp4_linear`` custom op. CUDA-only (asserted below).
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features = lin.in_features
        self.out_features = lin.out_features
        # Kernel requirement: both dims must be multiples of 32.
        assert self.in_features % 32 == 0 and self.out_features % 32 == 0

        self.weight = nn.Parameter(lin.weight.detach().clone())
        # Plain attributes (not buffers/params) — quantized tensors are
        # derived state, recomputed on construction.
        self._weight_fp4_T = None
        self._weight_scales_T = None
        self._alpha = None
        self._dummy_scale = None
        self._weight_global_sf = None

        with torch.no_grad():
            # Activations use a fixed unit global scale.
            self._dummy_scale = torch.full((1,), 1.0, device=self.weight.device, dtype=torch.float32)
            weight_bf16 = self.weight.to(torch.bfloat16).to(self.weight.device).contiguous()
            weight_amax = weight_bf16.float().abs().nan_to_num().max()
            # NOTE(review): global scale is 1/amax here — confirm this matches
            # the convention flashinfer's nvfp4_quantize expects.
            self._weight_global_sf = (1.0) / weight_amax
            # alpha undoes the global scales after the quantized GEMM.
            self._alpha = 1.0 / (self._weight_global_sf * self._dummy_scale)
            w_fp4, w_sf = nvfp4_quantize(
                weight_bf16, self._weight_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False,
            )
            self._weight_fp4_T = w_fp4.t()
            self._weight_scales_T = w_sf.t()

        assert self.weight.is_cuda
        # Warm-up call so kernel selection/compilation happens at init,
        # not on the first real forward.
        lazy_x = torch.zeros((1, lin.in_features), device=self.weight.device, dtype=torch.bfloat16)
        fp4_linear(lazy_x, self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten leading dims into a 2-D GEMM, restore the shape afterwards.
        x_flat = x.reshape(-1, x.shape[-1])
        y = fp4_linear(
            x_flat.to(torch.bfloat16).contiguous(),
            self._weight_fp4_T, self._dummy_scale, self._weight_scales_T, self._alpha,
        )
        return y.reshape(x.shape[:-1] + (-1,))
566
+
567
+
568
class FP8W8A8Linear(nn.Module):
    """FP8 linear with static per-tensor weight scale and dynamic per-call
    activation scale (W8A8), backed by ``torch._scaled_mm``."""

    __constants__ = ("in_features", "out_features")

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        f8 = torch.float8_e4m3fn
        # 1/max representable e4m3 value; used to map amax -> scale.
        inv = 1.0 / float(torch.finfo(f8).max)
        self._inv = inv
        w = lin.weight.detach()
        # Per-tensor weight scale, floored to avoid divide-by-zero.
        ws = (w.abs().amax() * inv).clamp_min(1e-8).float()
        wf8 = (w / ws.to(w.dtype)).to(f8).contiguous()
        self.register_buffer("wT", wf8.t())
        self.register_buffer("ws", ws)
        if lin.bias is None:
            self.bias = None
        else:
            # Bias kept in fp16, matching the kernel's out_dtype below.
            self.register_buffer("bias", lin.bias.detach().to(torch.float16))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        s = x.shape
        x2 = x.reshape(-1, s[-1])
        # Dynamic per-tensor activation scale for this call.
        xs = (x2.abs().amax() * self._inv).clamp_min(1e-8).float()
        xf8 = (x2 / xs.to(x2.dtype)).to(torch.float8_e4m3fn).contiguous()
        y = torch._scaled_mm(
            xf8, self.wT, xs, self.ws,
            bias=self.bias, out_dtype=torch.float16, use_fast_accum=True,
        )
        # Cast back to the caller's dtype.
        return y.reshape(*s[:-1], self.out_features).to(x.dtype)
597
+
598
+
599
class FP8Linear(nn.Module):
    """FP8 linear with static weight scale and *unscaled* fp8 activations.

    Activations are cast directly to e4m3 (values outside the fp8 range will
    saturate per the cast semantics — assumes inputs are already roughly
    unit-scale; confirm against the calling model).
    """

    def __init__(self, lin: nn.Linear):
        super().__init__()
        self.in_features, self.out_features = lin.in_features, lin.out_features
        # NOTE(review): bias is stored in float8_e4m3fn and passed to
        # torch._scaled_mm below; the documented bias dtypes for _scaled_mm
        # are fp16/bf16 — confirm this path is ever exercised with a bias.
        self.bias = (
            nn.Parameter(lin.bias.data.clone().to(torch.float8_e4m3fn))
            if lin.bias is not None else None
        )
        w_amax = lin.weight.data.abs().amax()
        # Weight normalized to [-1, 1] before the fp8 cast; w_amax is the
        # per-tensor scale restored by the kernel.
        w = lin.weight.data.clone().div(w_amax).to(torch.float8_e4m3fn)
        self.register_buffer("w_amax", w_amax)
        self.register_buffer("weightT", w.t())
        # Plain attribute (not a buffer): unit activation scale.
        self.dummy_scale = torch.ones((), device=lin.weight.device, dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x_fp8 = x.to(torch.float8_e4m3fn).reshape(-1, x.size(-1)).contiguous()
        result = torch._scaled_mm(
            x_fp8, self.weightT,
            bias=self.bias, scale_a=self.dummy_scale, scale_b=self.w_amax,
            out_dtype=torch.bfloat16, use_fast_accum=True,
        )
        return result.reshape(x.shape[:-1] + (-1,))
621
+
622
+
623
def quantize_model(model: nn.Module, quant: str):
    """Recursively replace eligible ``nn.Linear`` children with quantized layers.

    Args:
        model: module tree to rewrite in place.
        quant: ``"w8a8"``, ``"nvfp4"``, ``"fp8"``, or ``None`` for a no-op.

    Returns:
        The same ``model`` instance, modified in place.

    Raises:
        KeyError: if ``quant`` is not a recognized mode.
    """
    if quant is None:
        return model

    def eligible(m: nn.Module) -> bool:
        # Only plain bf16 Linear layers whose dims fit the quantized kernels
        # (both must be multiples of 32).
        if not isinstance(m, nn.Linear):
            return False
        w = getattr(m, "weight", None)
        if getattr(w, "dtype", None) != torch.bfloat16:
            return False
        o, k = w.shape
        return (o % 32 == 0) and (k % 32 == 0)

    new_linear = {"w8a8": FP8W8A8Linear, "nvfp4": FP4Linear, "fp8": FP8Linear}[quant]

    for name, child in model.named_children():
        if eligible(child):
            # Swap the leaf in place; do not recurse into the replacement.
            setattr(model, name, new_linear(child))
        else:
            quantize_model(child, quant)
    return model
641
+
642
+
643
+ # ---------------------------------------------------------------------------
644
+ # Inference patches
645
+ # ---------------------------------------------------------------------------
646
+
647
def patch_cached_noise_conditioning(model) -> None:
    """Swap the noise-conditioning modules for LUT-backed cached versions.

    Replaces ``model.denoise_step_emb`` with a ``CachedDenoiseStepEmb`` and
    every block's conditioning heads with ``CachedCondHead`` wrappers built
    from it, so all sigma-dependent work becomes table lookups.
    """
    cached = CachedDenoiseStepEmb(
        model.denoise_step_emb, model.config.scheduler_sigmas
    )
    model.denoise_step_emb = cached
    for block in model.transformer.blocks:
        block.attn_cond_head = CachedCondHead(block.attn_cond_head, cached)
        block.mlp_cond_head = CachedCondHead(block.mlp_cond_head, cached)
655
+
656
+
657
def patch_Attn_merge_qkv(model) -> None:
    """Replace every plain ``Attn`` submodule with its merged-QKV equivalent."""
    # Snapshot targets first so replacements don't perturb the traversal.
    targets = [
        (name, mod)
        for name, mod in model.named_modules()
        if isinstance(mod, Attn) and not isinstance(mod, MergedQKVAttn)
    ]
    for name, mod in targets:
        model.set_submodule(name, MergedQKVAttn(mod, model.config))
661
+
662
+
663
def _apply_inference_patches(model) -> None:
    # Apply all inference-time structural optimizations in place:
    # cached sigma-conditioning LUTs, then fused QKV attention projections.
    patch_cached_noise_conditioning(model)
    patch_Attn_merge_qkv(model)
666
+
667
+
668
+ # ---------------------------------------------------------------------------
669
+ # Model components
670
+ # ---------------------------------------------------------------------------
671
+
672
+ class CFG(nn.Module):
673
+ def __init__(self, d_model: int, dropout: float):
674
+ super().__init__()
675
+ self.dropout = dropout
676
+ self.null_emb = nn.Parameter(torch.zeros(1, 1, d_model))
677
+
678
+ def forward(
679
+ self, x: torch.Tensor, is_conditioned: bool | None = None
680
+ ) -> torch.Tensor:
681
+ B, L, _ = x.shape
682
+ null = self.null_emb.expand(B, L, -1)
683
+
684
+ if self.training or is_conditioned is None:
685
+ if self.dropout == 0.0:
686
+ return x
687
+ drop = torch.rand(B, 1, 1, device=x.device) < self.dropout
688
+ return torch.where(drop, null, x)
689
+
690
+ return x if is_conditioned else null
691
+
692
+
693
class ControllerInputEmbedding(nn.Module):
    """Embeds controller inputs (mouse + buttons) into model dimension."""

    def __init__(self, n_buttons: int, d_model: int, mlp_ratio: int = 4):
        super().__init__()
        # Input width is n_buttons plus 3 extra features for mouse and
        # scroll (presumably 2 + 1 — confirm against the caller).
        self.mlp = MLP(n_buttons + 3, d_model * mlp_ratio, d_model)

    def forward(self, mouse: Tensor, button: Tensor, scroll: Tensor):
        # Expect (batch, frames, features) for all three inputs.
        assert len(mouse.shape) == 3
        features = torch.cat((mouse, button, scroll), dim=-1)
        return self.mlp(features)
704
+
705
+
706
class MLPFusion(nn.Module):
    """Fuses per-group conditioning into tokens via split linear projections."""

    def __init__(self, d_model: int):
        super().__init__()
        # Separate first-layer projections for tokens and conditioning.
        self.fc1_x = nn.Linear(d_model, d_model, bias=False)
        self.fc1_c = nn.Linear(d_model, d_model, bias=False)
        self.fc2 = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        batch, _, dim = x.shape
        groups = cond.shape[1]
        # Regroup tokens so each conditioning vector spans tokens/groups tokens.
        tokens = x.reshape(batch, groups, -1, dim)
        shift = self.fc1_c(cond).unsqueeze(2)  # broadcast within each group
        fused = F.silu(self.fc1_x(tokens) + shift)
        return self.fc2(fused).flatten(1, 2)
722
+
723
+
724
class MoEWithoutFBGEMM(nn.Module):
    """MoE implementation using torch grouped_mm (no fbgemm dependency).

    Inference-only: forward raises under training/grad. Tokens are routed to
    the top-k experts, sorted by expert so each expert's tokens are
    contiguous, run through grouped GEMMs, then scattered back and combined
    with softmax routing weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        # Default expert width keeps total active compute equal to the dense MLP.
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or config.mlp_ratio / config.moe_top_k
        d_intermediate = int(config.d_model * moe_mlp_ratio)
        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        # Stacked per-expert weights: (E, d_int[*2 if gated], d_model).
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_intermediate * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_intermediate))

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig_shape = x.shape
        x = x.reshape(-1, orig_shape[-1])
        # Router logits may be precomputed and passed in via `gate`.
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))

        logits_fp32 = logits.float()
        scores, expert = logits.topk(self.top_k, dim=-1, sorted=False)
        # Softmax over ALL experts, evaluated only at the selected ones:
        # exp(score - logsumexp(logits)).
        weights = (scores.float() - logits_fp32.logsumexp(dim=-1, keepdim=True)).exp().to(x.dtype)

        # Sort (token, expert) assignments by expert so each expert's inputs
        # are a contiguous group for grouped_mm.
        expert = expert.flatten()
        expert_sorted, sort_idx = expert.sort()
        expert_ids = torch.arange(self.expert_in_proj.size(0), device=expert.device, dtype=expert_sorted.dtype)
        # Per-expert end offsets into the sorted assignment list.
        offsets = torch.searchsorted(expert_sorted, expert_ids, right=True).to(torch.int32)

        # Map each sorted assignment back to its source token row.
        src = sort_idx // self.top_k
        # One extra (duplicated) row is appended as padding for grouped_mm;
        # it is zeroed after the first GEMM and dropped after the second.
        x_grouped = x.index_select(0, torch.cat((src, src[:1]), dim=0))
        h = F.grouped_mm(x_grouped, self.expert_in_proj.transpose(-2, -1), offs=offsets)
        h[-1].zero_()

        if self.config.gated_linear:
            # SwiGLU-style gating: first half gates, second half carries.
            gate_act, up = h.chunk(2, dim=-1)
            h = F.silu(gate_act) * up
        else:
            h = F.silu(h)

        y_grouped = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offsets)[:-1]
        # Unsort back to (token, top_k, d_model) and combine with weights.
        y = torch.empty_like(y_grouped).index_copy_(0, sort_idx, y_grouped).view(x.size(0), self.top_k, -1)
        return (y * weights.unsqueeze(-1)).sum(dim=1).reshape(orig_shape)
770
+
771
+
772
class MoE(nn.Module):
    """MoE implementation using fbgemm optimized kernels.

    Same routing/compute scheme as ``MoEWithoutFBGEMM`` but uses fbgemm's
    ``index_shuffling`` for top-k routing + sorting and
    ``scatter_add_dense_tokens`` for the weighted scatter back.
    Inference-only.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.moe_top_k
        # Default expert width keeps active compute equal to the dense MLP.
        moe_mlp_ratio = getattr(config, "moe_mlp_ratio", None) or (config.mlp_ratio / config.moe_top_k)
        d_int = int(config.d_model * moe_mlp_ratio)

        self.router = nn.Linear(config.d_model, config.moe_n_experts, bias=False)
        # Stacked per-expert weights: (E, d_int[*2 if gated], d_model).
        self.expert_in_proj = nn.Parameter(
            torch.empty(config.moe_n_experts, d_int * (2 if config.gated_linear else 1), config.d_model)
        )
        self.expert_out_proj = nn.Parameter(torch.empty(config.moe_n_experts, config.d_model, d_int))

    def forward(self, x: torch.Tensor, gate: torch.Tensor | None = None) -> torch.Tensor:
        if self.training or torch.is_grad_enabled():
            raise NotImplementedError("inference only")

        orig = x.shape
        x = x.reshape(-1, orig[-1])
        # Router logits may be precomputed and passed in via `gate`.
        logits = self.router(x) if gate is None else gate.reshape(-1, gate.size(-1))

        logits32 = logits.float()
        # fbgemm kernel: per-expert token counts, expert ids sorted by expert,
        # and the source token row of each sorted assignment.
        token_counts, expert_sorted, src = index_shuffling(logits32, top_k=self.top_k)

        E = self.expert_in_proj.size(0)
        # Per-expert end offsets for grouped_mm.
        offs = token_counts[:E].cumsum(0).to(torch.int32)

        src = src.to(torch.long)
        expert_sorted = expert_sorted.to(torch.long)
        # Routing weight = softmax over all experts, evaluated at selections.
        logZ = logits32.logsumexp(-1)
        w = (logits32[src, expert_sorted] - logZ[src]).exp().to(x.dtype)

        # Extra duplicated row appended as grouped_mm padding; dropped below.
        xg = x.index_select(0, torch.cat((src, src[:1]), 0))
        h = F.grouped_mm(xg, self.expert_in_proj.transpose(-2, -1), offs=offs)
        if self.config.gated_linear:
            # SwiGLU-style gating.
            ga, up = h.chunk(2, -1)
            h = F.silu(ga) * up
        else:
            h = F.silu(h)

        yg = F.grouped_mm(h, self.expert_out_proj.transpose(-2, -1), offs=offs)[:-1]
        # Weighted scatter-add back to token order.
        out = torch.zeros_like(x)
        torch.ops.fbgemm.scatter_add_dense_tokens(out, (yg * w.unsqueeze(-1)).contiguous(), src)
        return out.reshape(orig)
819
+
820
+
821
class CondHead(nn.Module):
    """Per-layer conditioning head: bias_in -> SiLU -> Linear -> chunk(n_cond)."""

    def __init__(self, d_model: int, noise_conditioning: str = "wan", n_cond: int = 3):
        super().__init__()
        # The "wan" scheme adds a learned bias before the nonlinearity.
        if noise_conditioning == "wan":
            self.bias_in = nn.Parameter(torch.zeros(d_model))
        else:
            self.bias_in = None
        layers = [nn.Linear(d_model, d_model, bias=False) for _ in range(n_cond)]
        self.cond_proj = nn.ModuleList(layers)

    def forward(self, cond):
        if self.bias_in is not None:
            cond = cond + self.bias_in
        hidden = F.silu(cond)
        # One independent projection per conditioning output.
        return tuple(proj(hidden) for proj in self.cond_proj)
837
+
838
+
839
+ # ---------------------------------------------------------------------------
840
+ # Transformer blocks
841
+ # ---------------------------------------------------------------------------
842
+
843
class WorldDiTBlock(nn.Module):
    """Single transformer block with self-attention, optional cross-attention, and MLP.

    Uses adaptive RMSNorm (scale/bias/gate from the sigma conditioning) around
    both the attention and MLP sub-blocks. Prompt cross-attention and
    controller fusion are inserted only on layers matching their periods.
    """

    def __init__(
        self, d_model, n_heads, mlp_ratio, layer_idx,
        prompt_conditioning, prompt_conditioning_period, prompt_embedding_dim,
        ctrl_conditioning_period, noise_conditioning, config,
    ):
        super().__init__()
        self.config = config
        self.attn = Attn(config, layer_idx)
        if getattr(config, "moe", False):
            # Prefer fbgemm kernels when available.
            self.dit_mlp = MoE(config) if HAS_FBGEMM else MoEWithoutFBGEMM(config)
        else:
            self.dit_mlp = MLP(d_model, d_model * mlp_ratio, d_model)
        # Separate (scale, bias, gate) heads for attention and MLP branches.
        self.attn_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)
        self.mlp_cond_head = CondHead(d_model, noise_conditioning, n_cond=3)

        # Prompt cross-attention only every `prompt_conditioning_period` layers.
        do_prompt_cond = (
            prompt_conditioning is not None
            and layer_idx % prompt_conditioning_period == 0
        )
        self.prompt_cross_attn = (
            CrossAttention(config, prompt_embedding_dim) if do_prompt_cond else None
        )
        # Controller fusion only every `ctrl_conditioning_period` layers.
        do_ctrl_cond = ctrl_conditioning_period is not None and layer_idx % ctrl_conditioning_period == 0
        self.ctrl_mlpfusion = MLPFusion(d_model) if do_ctrl_cond else None

    def forward(self, x, pos_ids, rope_angles, cond, ctx, v, kv_cache=None):
        # (scale, bias, gate) for the attention and MLP branches.
        s0, b0, g0 = self.attn_cond_head(cond)
        s1, b1, g1 = self.mlp_cond_head(cond)

        # Self-attention with adaptive norm + gated residual.
        residual = x
        x = ada_rmsnorm(x, s0, b0)
        x, v = self.attn(x, pos_ids, rope_angles, v, kv_cache=kv_cache)
        x = ada_gate(x, g0) + residual

        if self.prompt_cross_attn is not None:
            x = (
                self.prompt_cross_attn(
                    rms_norm(x),
                    context=rms_norm(ctx["prompt_emb"]),
                    context_pad_mask=ctx["prompt_pad_mask"],
                )
                + x
            )

        if self.ctrl_mlpfusion is not None:
            x = self.ctrl_mlpfusion(rms_norm(x), rms_norm(ctx["ctrl_emb"])) + x

        # MLP with adaptive norm + gated residual.
        x = ada_gate(self.dit_mlp(ada_rmsnorm(x, s1, b1)), g1) + x

        return x, v
896
+
897
+
898
class WorldDiT(nn.Module):
    """Stack of WorldDiTBlocks with shared parameters.

    Computes RoPE angles once per forward and threads a value-residual tensor
    (the first block's attention values) through every block.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(
            [
                WorldDiTBlock(
                    d_model=config.d_model,
                    n_heads=config.n_heads,
                    mlp_ratio=config.mlp_ratio,
                    layer_idx=idx,
                    prompt_conditioning=config.prompt_conditioning,
                    prompt_conditioning_period=config.prompt_conditioning_period,
                    prompt_embedding_dim=config.prompt_embedding_dim,
                    ctrl_conditioning_period=config.ctrl_conditioning_period,
                    noise_conditioning=config.noise_conditioning,
                    config=config,
                )
                for idx in range(config.n_layers)
            ]
        )
        self.rope_angles = OrthoRoPEAngles(config)

    def forward(self, x, pos_ids, cond, ctx, kv_cache=None):
        # RoPE angles are shared by all blocks; compute once.
        rope_angles = self.rope_angles(pos_ids)
        v = None  # value-residual tensor, filled by the first block
        # was: `for i, block in enumerate(self.blocks)` — index was unused
        for block in self.blocks:
            x, v = block(x, pos_ids, rope_angles, cond, ctx, v, kv_cache=kv_cache)
        return x
929
+
930
+
931
+ # ---------------------------------------------------------------------------
932
+ # Top-level model
933
+ # ---------------------------------------------------------------------------
934
+
935
class WorldModel(ModelMixin, ConfigMixin):
    """
    WORLD: Wayfarer Operator-driven Rectified-flow Long-context Diffuser.

    Denoises a frame given:
    - All previous frames (via KV cache)
    - The prompt embedding
    - The controller input embedding
    - The current noise level
    """

    _supports_gradient_checkpointing = False
    # Noise embedding and RoPE angle tables stay fp32 for numerical accuracy.
    _keep_in_fp32_modules = ["denoise_step_emb", "rope_angles"]

    @register_to_config
    def __init__(
        self,
        # --- transformer geometry ---
        d_model: int = 2048,
        n_heads: int = 32,
        n_kv_heads: int | None = None,
        n_layers: int = 24,
        mlp_ratio: int = 4,
        # --- latent frame geometry ---
        channels: int = 32,
        height: int = 16,
        width: int = 16,
        patch: tuple = (2, 2),
        tokens_per_frame: int = 256,
        n_frames: int = 4096,
        # --- attention windowing ---
        local_window: int = 16,
        global_window: int = 128,
        global_attn_period: int = 4,
        global_pinned_dilation: int = 8,
        global_attn_offset: int = 0,
        value_residual: bool = True,
        gated_attn: bool = False,
        # --- controller conditioning ---
        n_buttons: int = 256,
        ctrl_conditioning: str | None = "mlp_fusion",
        ctrl_conditioning_period: int | None = 3,
        ctrl_cond_dropout: float = 0.0,
        # --- prompt conditioning ---
        prompt_conditioning: str | None = None,
        prompt_conditioning_period: int = 3,
        prompt_embedding_dim: int = 2048,
        prompt_cond_dropout: float = 0.0,
        # --- noise schedule ---
        noise_conditioning: str = "wan",
        scheduler_sigmas: list[float] | None = [
            1.0, 0.8609585762023926, 0.729332447052002, 0.3205108940601349, 0.0,
        ],
        base_fps: int = 60,
        causal: bool = True,
        mlp_gradient_checkpointing: bool = True,
        block_gradient_checkpointing: bool = True,
        rope_impl: str = "ortho",
        # --- mixture of experts ---
        moe: bool = False,
        moe_top_k: int = 2,
        moe_n_experts: int = 8,
        moe_mlp_ratio: float | None = None,
        gated_linear: bool = False,
        # --- misc ---
        temporal_compression: int = 1,
        inference_fps: int | None = None,
        taehv_ae: bool = False,
        rope_nyquist_frac: float = 0.8,
        rope_theta: float = 10000.0,
    ):
        super().__init__()

        self.denoise_step_emb = NoiseConditioner(d_model)
        self.ctrl_emb = ControllerInputEmbedding(n_buttons, d_model, mlp_ratio)

        # CFG dropout modules exist only when the matching conditioning is on.
        if self.config.ctrl_conditioning is not None:
            self.ctrl_cfg = CFG(self.config.d_model, self.config.ctrl_cond_dropout)
        if self.config.prompt_conditioning is not None:
            self.prompt_cfg = CFG(
                self.config.prompt_embedding_dim, self.config.prompt_cond_dropout
            )

        self.transformer = WorldDiT(self.config)
        self.patch = tuple(patch)

        # Patchify latent frames into tokens and back.
        C, D = channels, d_model
        self.patchify = nn.Conv2d(
            C, D, kernel_size=self.patch, stride=self.patch, bias=False
        )
        self.unpatchify = nn.ConvTranspose2d(
            D, C, kernel_size=self.patch, stride=self.patch, bias=True
        )
        self.out_norm = AdaLN(d_model)

        # Precomputed single-frame position ids. `_t_pos_1f` is filled with
        # the current frame timestamp on each forward; y/x are the fixed
        # spatial grid coordinates for one frame's tokens.
        T = tokens_per_frame
        idx = torch.arange(T, dtype=torch.long)
        self.register_buffer(
            "_t_pos_1f", torch.empty(T, dtype=torch.long), persistent=False
        )
        self.register_buffer(
            "_y_pos_1f", idx.div(width, rounding_mode="floor"), persistent=False
        )
        self.register_buffer("_x_pos_1f", idx.remainder(width), persistent=False)

    def forward(
        self,
        x: Tensor,
        sigma: Tensor,
        frame_timestamp: Tensor,
        frame_idx: Tensor | None = None,
        prompt_emb: Tensor | None = None,
        prompt_pad_mask: Tensor | None = None,
        mouse: Tensor | None = None,
        button: Tensor | None = None,
        scroll: Tensor | None = None,
        kv_cache=None,
    ):
        """Denoise one latent frame ``x`` of shape (B, N, C, H, W).

        Currently restricted to a single frame of a single sample
        (B == 1, N == 1); earlier frames are supplied via ``kv_cache``.
        """
        B, N, C, H, W = x.shape
        ph, pw = self.patch
        assert (H % ph == 0) and (W % pw == 0), "H, W must be divisible by patch"
        Hp, Wp = H // ph, W // pw
        # torch._assert stays active under torch.compile.
        torch._assert(
            Hp * Wp == self.config.tokens_per_frame,
            f"{Hp} * {Wp} != {self.config.tokens_per_frame}",
        )

        torch._assert(
            B == 1 and N == 1, "WorldModel.forward currently supports B==1, N==1"
        )
        # Broadcast the scalar frame timestamp over all tokens of the frame.
        self._t_pos_1f.copy_(frame_timestamp[0, 0].expand_as(self._t_pos_1f))
        pos_ids = TensorDict(
            {
                # Frame position uses the explicit index when given, else the timestamp.
                "f_pos": (frame_timestamp if frame_idx is None else frame_idx)[0, 0].expand_as(self._t_pos_1f)[None],
                "t_pos": self._t_pos_1f[None],
                "y_pos": self._y_pos_1f[None],
                "x_pos": self._x_pos_1f[None],
            },
            batch_size=[1, self._t_pos_1f.numel()],
        )
        # Sigma -> conditioning embedding (possibly the cached LUT version).
        cond = self.denoise_step_emb(sigma)

        assert button is not None
        ctx = {
            "ctrl_emb": self.ctrl_emb(mouse, button, scroll),
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
        }

        # Patchify -> token sequence -> transformer -> unpatchify.
        D = self.config.d_model
        x = self.patchify(x.reshape(B * N, C, H, W))
        x = eo.rearrange(x.view(B, N, D, Hp, Wp), "b n d hp wp -> b (n hp wp) d")
        x = self.transformer(x, pos_ids, cond, ctx, kv_cache)
        x = F.silu(self.out_norm(x, cond))
        x = eo.rearrange(x, "b (n hp wp) d -> (b n) d hp wp", n=N, hp=Hp, wp=Wp)
        x = self.unpatchify(x)
        x = x.view(B, N, C, H, W)

        return x

    def get_active_parameters(self) -> int:
        """Total parameters minus the inactive (unselected) MoE expert weights."""
        total = sum(p.numel() for p in self.parameters())
        c = self.config
        if getattr(c, "moe", False):
            moe_mlp_ratio = getattr(c, "moe_mlp_ratio", None) or c.mlp_ratio / c.moe_top_k
            hidden, top_k = int(c.d_model * moe_mlp_ratio), min(c.moe_top_k, c.moe_n_experts)
            # Each unused expert contributes in+out (and gate, if gated) matrices per layer.
            total -= (c.moe_n_experts - top_k) * c.n_layers * c.d_model * hidden * (3 if c.gated_linear else 2)
        return total

    def quantize(self, quant_type: str):
        # In-place quantization of eligible Linear layers (see quantize_model).
        quantize_model(self, quant_type)

    def apply_inference_patches(self):
        # In-place inference optimizations: cached conditioning + merged QKV.
        _apply_inference_patches(self)
vae/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from .ae_model import ChunkedStreamingTAEHV
17
+
18
+ __all__ = ["ChunkedStreamingTAEHV"]
vae/ae_model.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (C) 2025 Hugging Face Team and Overworld
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ """Streaming TAEHV autoencoder for WorldEngine wp-1.5 temporal-compressed latent decoding."""
17
+
18
+ from collections import namedtuple
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ from torch import Tensor
23
+
24
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
25
+ from diffusers.models.modeling_utils import ModelMixin
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Building blocks (mirror the taehv library)
30
+ # ---------------------------------------------------------------------------
31
+
32
# A unit of streaming work: a tensor waiting to enter pipeline stage `block_index`.
TWorkItem = namedtuple("TWorkItem", ["input_tensor", "block_index"])
33
+
34
+
35
+ def _conv(n_in, n_out, **kwargs):
36
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
37
+
38
+
39
class Clamp(nn.Module):
    """Soft clamp to roughly [-3, 3] using a scaled tanh (smooth, differentiable)."""

    def forward(self, x):
        scaled = x / 3
        return 3 * torch.tanh(scaled)
42
+
43
+
44
class MemBlock(nn.Module):
    """Residual conv block conditioned on the previous frame's features.

    Current features `x` and memory features `past` are concatenated on the
    channel axis, passed through a three-conv stack, added to a (possibly
    1x1-projected) skip of `x`, and ReLU-activated.
    """

    def __init__(self, n_in, n_out):
        super().__init__()
        stack = [
            _conv(n_in * 2, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
            nn.ReLU(inplace=True),
            _conv(n_out, n_out),
        ]
        self.conv = nn.Sequential(*stack)
        # Project the skip path only when the channel count changes.
        if n_in == n_out:
            self.skip = nn.Identity()
        else:
            self.skip = nn.Conv2d(n_in, n_out, 1, bias=False)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        fused = self.conv(torch.cat([x, past], 1))
        return self.act(fused + self.skip(x))
61
+
62
+
63
class TPool(nn.Module):
    """Temporal pooling: merge `stride` consecutive frames into one.

    Input is (N*T, C, H, W); each group of `stride` frames is stacked on the
    channel axis and mixed down to C channels by a 1x1 convolution.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        channels = x.shape[1]
        grouped = x.reshape(-1, self.stride * channels, *x.shape[-2:])
        return self.conv(grouped)
72
+
73
+
74
class TGrow(nn.Module):
    """Temporal growth: expand each frame into `stride` consecutive frames.

    A 1x1 convolution produces `stride * C` channels, which are then split
    back out along the batch/time axis.
    """

    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        channels = x.shape[1]
        expanded = self.conv(x)
        return expanded.reshape(-1, channels, *x.shape[-2:])
84
+
85
+
86
+ # ---------------------------------------------------------------------------
87
+ # Sequential streaming helpers
88
+ # ---------------------------------------------------------------------------
89
+
90
def _sequential_single_step(model, memory, work_queue):
    """Process the work queue until an output frame is produced or the queue is empty.

    Args:
        model: Sequential pipeline of blocks (MemBlock / TPool / TGrow / plain modules).
        memory: Per-block state list (mutated in place); entry `i` holds the
            previous input for a MemBlock or the pending-frame buffer for a TPool.
        work_queue: FIFO of TWorkItem(tensor, next_block_index); mutated in place.

    Returns:
        The next fully processed frame as an N1CHW tensor, or None if more
        input is needed before another frame can be emitted.
    """
    while work_queue:
        xt, i = work_queue.pop(0)
        if i == len(model):
            # Tensor has passed every block: emit it with a singleton time axis.
            return xt.unsqueeze(1)
        b = model[i]
        if isinstance(b, MemBlock):
            # First frame sees a zero memory; afterwards the *previous input*
            # (pre-activation) is used as the memory for the next frame.
            if memory[i] is None:
                xt_new = b(xt, xt * 0)
            else:
                xt_new = b(xt, memory[i])
            memory[i] = xt
            work_queue.insert(0, TWorkItem(xt_new, i + 1))
        elif isinstance(b, TPool):
            # Buffer frames until `stride` of them are available, then pool.
            if memory[i] is None:
                memory[i] = []
            memory[i].append(xt)
            if len(memory[i]) == b.stride:
                N, C, H, W = xt.shape
                # Channel-concat then view produces the (N*stride, C, H, W)
                # frame-major layout that TPool.forward expects.
                xt = b(torch.cat(memory[i], 1).view(N * b.stride, C, H, W))
                memory[i] = []
                work_queue.insert(0, TWorkItem(xt, i + 1))
        elif isinstance(b, TGrow):
            # One frame in, `stride` frames out; re-queue them front-first so
            # temporal order is preserved when inserting at queue position 0.
            xt = b(xt)
            NT, C, H, W = xt.shape
            for xt_next in reversed(
                xt.view(NT // b.stride, b.stride * C, H, W).chunk(b.stride, 1)
            ):
                work_queue.insert(0, TWorkItem(xt_next, i + 1))
        else:
            # Stateless block (conv / activation / upsample): pass straight through.
            xt = b(xt)
            work_queue.insert(0, TWorkItem(xt, i + 1))
    return None
124
+
125
+
126
def _apply_parallel(model, x):
    """Apply model with parallelization over the time axis. x: NTCHW.

    All frames are folded into the batch dimension so stateless blocks run
    in one shot; MemBlock layers receive the previous frame's features
    (zeros for the first frame) as their memory input.
    """
    N, T, C, H, W = x.shape
    x = x.reshape(N * T, C, H, W)
    for block in model:
        if isinstance(block, MemBlock):
            NT, C, H, W = x.shape
            frames = x.reshape(N, NT // N, C, H, W)
            # Shift by one along time: memory for frame t is frame t-1,
            # with a zero frame prepended for t == 0.
            shifted = F.pad(frames, (0, 0, 0, 0, 0, 0, 1, 0), value=0)
            past = shifted[:, : NT // N].reshape(x.shape)
            x = block(x, past)
        else:
            x = block(x)
    NT, C, H, W = x.shape
    return x.view(N, NT // N, C, H, W)
144
+
145
+
146
+ # ---------------------------------------------------------------------------
147
+ # ChunkedStreamingTAEHV
148
+ # ---------------------------------------------------------------------------
149
+
150
class ChunkedStreamingTAEHV(ModelMixin, ConfigMixin):
    """Streaming TAEHV autoencoder for temporal-compressed latent decoding.

    Owns the encoder/decoder weights directly so diffusers can load them
    from safetensors. Provides a streaming interface that processes one
    temporal chunk at a time, maintaining internal state across calls.
    """

    # Inference-oriented model; streaming state is incompatible with
    # activation checkpointing.
    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        latent_channels: int = 32,
        patch_size: int = 2,
        image_channels: int = 3,
        encoder_time_downscale: tuple[bool, ...] = (True, True, False),
        decoder_time_upscale: tuple[bool, ...] = (False, True, True),
        decoder_space_upscale: tuple[bool, ...] = (True, True, True),
    ):
        """Build the encoder/decoder stacks and initialise streaming state.

        Args:
            latent_channels: Channel count of the latent space.
            patch_size: Pixel-(un)shuffle factor applied before the encoder
                and after the decoder.
            image_channels: Channel count of the RGB frames.
            encoder_time_downscale: Per-stage flags; True gives that TPool a
                temporal stride of 2 (1 otherwise).
            decoder_time_upscale: Per-stage flags; True gives that TGrow a
                temporal stride of 2 (1 otherwise).
            decoder_space_upscale: Per-stage flags; True gives that Upsample
                a spatial factor of 2 (1 otherwise).
        """
        super().__init__()

        # Pixel-unshuffle packs a patch_size x patch_size neighbourhood into channels.
        in_ch = image_channels * patch_size ** 2

        self.encoder = nn.Sequential(
            _conv(in_ch, 64), nn.ReLU(inplace=True),
            TPool(64, 2 if encoder_time_downscale[0] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[1] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2 if encoder_time_downscale[2] else 1),
            _conv(64, 64, stride=2, bias=False),
            MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            _conv(64, latent_channels),
        )

        # Decoder channel widths per stage (mirrors the taehv reference).
        n_f = [256, 128, 64, 64]
        self.decoder = nn.Sequential(
            Clamp(),
            _conv(latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
            TGrow(n_f[0], 2 if decoder_time_upscale[0] else 1),
            _conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
            TGrow(n_f[1], 2 if decoder_time_upscale[1] else 1),
            _conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
            TGrow(n_f[2], 2 if decoder_time_upscale[2] else 1),
            _conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True),
            _conv(n_f[3], image_channels * patch_size ** 2),
        )

        # Computed properties: total temporal compression/expansion factors,
        # derived from how many TPool/TGrow stages actually use stride 2.
        self.t_downscale = 2 ** sum(
            t.stride == 2 for t in self.encoder if isinstance(t, TPool)
        )
        self.t_upscale = 2 ** sum(
            t.stride == 2 for t in self.decoder if isinstance(t, TGrow)
        )
        # Leading decoder outputs discarded while the temporal memory warms up.
        self.frames_to_trim = self.t_upscale - 1
        self.patch_size = patch_size

        # Streaming state (initialised on first use / reset)
        self._encoder_work_queue: list[TWorkItem] = []
        self._encoder_memory: list = [None] * len(self.encoder)
        self._decoder_work_queue: list[TWorkItem] = []
        self._decoder_memory: list = [None] * len(self.decoder)
        self._n_frames_encoded: int = 0
        self._n_frames_decoded: int = 0
        self._last_encoder_input_frame: Tensor | None = None

    # ------------------------------------------------------------------
    # Streaming state management
    # ------------------------------------------------------------------

    def reset(self):
        """Reset streaming state for a new sequence."""
        self._encoder_work_queue = []
        self._encoder_memory = [None] * len(self.encoder)
        self._decoder_work_queue = []
        self._decoder_memory = [None] * len(self.decoder)
        self._n_frames_encoded = 0
        self._n_frames_decoded = 0
        self._last_encoder_input_frame = None

    # ------------------------------------------------------------------
    # Pre/post processing
    # ------------------------------------------------------------------

    def _preprocess_input_frames(self, x: Tensor) -> Tensor:
        """Pixel-unshuffle frames so patches become channels (no-op for patch_size 1)."""
        if self.patch_size > 1:
            x = F.pixel_unshuffle(x, self.patch_size)
        return x

    def _postprocess_output_frames(self, x: Tensor) -> Tensor:
        """Pixel-shuffle channels back to pixels and clamp in place to [0, 1]."""
        if self.patch_size > 1:
            x = F.pixel_shuffle(x, self.patch_size)
        return x.clamp_(0, 1)

    # ------------------------------------------------------------------
    # Streaming encode / decode (one chunk at a time)
    # ------------------------------------------------------------------

    def _streaming_encode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed an input frame and try to produce an encoder output.

        Args:
            x: N1CHW RGB frame tensor with values in [0, 1], or None.
                (A multi-frame NTCHW tensor is also accepted; frames are
                enqueued individually via unbind.)
        Returns:
            N1CHW latent tensor, or None if not enough input accumulated.
        """
        if x is not None:
            # Keep the most recent raw frame around for callers that need it.
            self._last_encoder_input_frame = x[:, -1:]
            x = self._preprocess_input_frames(x)
            self._encoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
            self._n_frames_encoded += x.shape[1]
        return _sequential_single_step(
            self.encoder, self._encoder_memory, self._encoder_work_queue
        )

    def _streaming_decode_step(self, x: Tensor | None = None) -> Tensor | None:
        """Feed a latent and try to produce a decoded frame.

        Args:
            x: N1CHW latent tensor, or None to retrieve the next pending frame.
        Returns:
            N1CHW decoded RGB frame tensor, or None.
        """
        if x is not None:
            self._decoder_work_queue.extend(
                TWorkItem(xt, 0) for xt in x.unbind(1)
            )
        while True:
            xt = _sequential_single_step(
                self.decoder, self._decoder_memory, self._decoder_work_queue
            )
            if xt is None:
                return None
            self._n_frames_decoded += 1
            # Silently drop the first `frames_to_trim` warm-up frames.
            if self._n_frames_decoded <= self.frames_to_trim:
                continue
            return self._postprocess_output_frames(xt)

    def _flush_decoder(self) -> list[Tensor]:
        """Drain all remaining decoded frames from the decoder."""
        frames = []
        while (frame := self._streaming_decode_step()) is not None:
            frames.append(frame)
        return frames

    # ------------------------------------------------------------------
    # Pipeline-facing encode / decode
    # ------------------------------------------------------------------

    @torch.inference_mode()
    def encode(self, img: Tensor) -> Tensor:
        """Encode a chunk of frames to a single latent.

        Args:
            img: [T, H, W, C] uint8 where T == t_downscale

        Returns:
            latent: [B, C, h, w]

        Raises:
            ValueError: If the chunk does not contain exactly t_downscale frames.
            RuntimeError: If the encoder fails to emit a latent for a full chunk.
        """
        assert img.dim() == 4 and img.shape[-1] == 3, "Expected [T, H, W, C] RGB uint8"

        if img.shape[0] != self.t_downscale:
            raise ValueError(
                f"Expected {self.t_downscale} frames, got {img.shape[0]}"
            )

        # uint8 THWC -> float NTCHW in [0, 1], on this module's device/dtype.
        rgb = (
            img.unsqueeze(0)
            .to(device=self.device, dtype=self.dtype)
            .permute(0, 1, 4, 2, 3)
            .contiguous()
            .div(255)
        )

        latent = self._streaming_encode_step(rgb)
        if latent is None:
            raise RuntimeError("Expected a latent after a full chunk")

        return latent.squeeze(1)

    @torch.inference_mode()
    def decode(self, latent: Tensor) -> Tensor:
        """Decode a latent to frames.

        Args:
            latent: [B, C, h, w]

        Returns:
            frames: [T, H, W, C] uint8
        """
        assert latent.dim() == 4, "Expected [B, C, h, w] latent tensor"

        z = latent.unsqueeze(1).to(device=self.device, dtype=self.dtype)

        if self._n_frames_decoded == 0:
            # NOTE(review): on the first call the same latent is fed
            # frames_to_trim extra times to prime the MemBlock memories; the
            # resulting warm-up frames are discarded (trimmed by the step
            # counter / flushed here). Confirm against the runtime client.
            for _ in range(self.frames_to_trim):
                self._streaming_decode_step(z)
            self._flush_decoder()

        first = self._streaming_decode_step(z)
        assert first is not None, "Expected decoded output after a latent"
        frames = [first, *self._flush_decoder()]

        # NTCHW float in [0, 1] -> THWC uint8 (batch dim of 1 squeezed away).
        decoded = torch.cat(frames, dim=1)
        decoded = (decoded.clamp(0, 1) * 255).round().to(torch.uint8)
        return decoded.squeeze(0).permute(0, 2, 3, 1)[..., :3]
vae/config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "ChunkedStreamingTAEHV",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "auto_map": {
5
+ "AutoModel": "ae_model.ChunkedStreamingTAEHV"
6
+ },
7
+ "latent_channels": 32,
8
+ "patch_size": 2,
9
+ "image_channels": 3,
10
+ "encoder_time_downscale": [true, true, false],
11
+ "decoder_time_upscale": [false, true, true],
12
+ "decoder_space_upscale": [true, true, true]
13
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b52e245bb86c62e159f50338e2e8f422d4b6f98b467164939c1c031c7d61352e
3
+ size 22755856