kaihuac committed on
Commit
428c072
·
verified ·
1 Parent(s): 73c451c

Delete unet/models/diffusion_vas/unet_diffusion_vas.py

Browse files
unet/models/diffusion_vas/unet_diffusion_vas.py DELETED
@@ -1,496 +0,0 @@
1
- from dataclasses import dataclass
2
- from typing import Dict, Optional, Tuple, Union
3
-
4
- import torch
5
- import torch.nn as nn
6
-
7
- import diffusers
8
- from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from diffusers.loaders import UNet2DConditionLoadersMixin
10
- from diffusers.utils import BaseOutput, logging
11
- from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
12
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
13
- from diffusers.models.modeling_utils import ModelMixin
14
- from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
15
-
16
-
17
# Module-level logger, following the diffusers convention of one logger per file.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
18
-
19
-
20
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of the
            last layer of the model.
    """

    # Predicted sample produced by the UNet's final conv layer.
    sample: torch.FloatTensor = None
31
-
32
-
33
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional spatio-temporal UNet model that takes noisy video frames, a conditional state,
    and a timestep, and returns a sample-shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic
    methods implemented for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension used to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`], [`~models.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
            The number of attention heads.
        num_frames (`int`, *optional*, defaults to 25):
            The number of video frames per sample.
    """

    # Allows diffusers' checkpointing machinery to toggle gradient checkpointing on sub-blocks.
    _supports_gradient_checkpointing = True
69
-
70
- @register_to_config
71
- def __init__(
72
- self,
73
- sample_size: Optional[int] = None,
74
- in_channels: int = 8,
75
- out_channels: int = 4,
76
- down_block_types: Tuple[str] = (
77
- "CrossAttnDownBlockSpatioTemporal",
78
- "CrossAttnDownBlockSpatioTemporal",
79
- "CrossAttnDownBlockSpatioTemporal",
80
- "DownBlockSpatioTemporal",
81
- ),
82
- up_block_types: Tuple[str] = (
83
- "UpBlockSpatioTemporal",
84
- "CrossAttnUpBlockSpatioTemporal",
85
- "CrossAttnUpBlockSpatioTemporal",
86
- "CrossAttnUpBlockSpatioTemporal",
87
- ),
88
- block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
89
- addition_time_embed_dim: int = 256,
90
- projection_class_embeddings_input_dim: int = 768,
91
- layers_per_block: Union[int, Tuple[int]] = 2,
92
- cross_attention_dim: Union[int, Tuple[int]] = 1024,
93
- transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
94
- num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
95
- num_frames: int = 25,
96
- ):
97
- super().__init__()
98
-
99
- self.sample_size = sample_size
100
-
101
- # Check inputs
102
- if len(down_block_types) != len(up_block_types):
103
- raise ValueError(
104
- f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
105
- )
106
-
107
- if len(block_out_channels) != len(down_block_types):
108
- raise ValueError(
109
- f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
110
- )
111
-
112
- if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
113
- raise ValueError(
114
- f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
115
- )
116
-
117
- if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
118
- raise ValueError(
119
- f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
120
- )
121
-
122
- if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
123
- raise ValueError(
124
- f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
125
- )
126
-
127
-
128
- # input
129
- self.conv_in = nn.Conv2d(
130
- in_channels,
131
- block_out_channels[0],
132
- kernel_size=3,
133
- padding=1,
134
- )
135
-
136
- self.conv_in2 = nn.Conv2d(
137
- 12,
138
- block_out_channels[0],
139
- kernel_size=3,
140
- padding=1,
141
- )
142
-
143
- # time
144
- time_embed_dim = block_out_channels[0] * 4
145
-
146
- self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
147
- timestep_input_dim = block_out_channels[0]
148
-
149
- self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
150
-
151
- self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
152
- self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
153
-
154
- self.down_blocks = nn.ModuleList([])
155
- self.up_blocks = nn.ModuleList([])
156
-
157
- if isinstance(num_attention_heads, int):
158
- num_attention_heads = (num_attention_heads,) * len(down_block_types)
159
-
160
- if isinstance(cross_attention_dim, int):
161
- cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
162
-
163
- if isinstance(layers_per_block, int):
164
- layers_per_block = [layers_per_block] * len(down_block_types)
165
-
166
- if isinstance(transformer_layers_per_block, int):
167
- transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
168
-
169
- blocks_time_embed_dim = time_embed_dim
170
-
171
- # down
172
- output_channel = block_out_channels[0]
173
- for i, down_block_type in enumerate(down_block_types):
174
- input_channel = output_channel
175
- output_channel = block_out_channels[i]
176
- is_final_block = i == len(block_out_channels) - 1
177
-
178
- down_block = get_down_block(
179
- down_block_type,
180
- num_layers=layers_per_block[i],
181
- transformer_layers_per_block=transformer_layers_per_block[i],
182
- in_channels=input_channel,
183
- out_channels=output_channel,
184
- temb_channels=blocks_time_embed_dim,
185
- add_downsample=not is_final_block,
186
- resnet_eps=1e-5,
187
- cross_attention_dim=cross_attention_dim[i],
188
- num_attention_heads=num_attention_heads[i],
189
- resnet_act_fn="silu",
190
- )
191
- self.down_blocks.append(down_block)
192
-
193
- # mid
194
- self.mid_block = UNetMidBlockSpatioTemporal(
195
- block_out_channels[-1],
196
- temb_channels=blocks_time_embed_dim,
197
- transformer_layers_per_block=transformer_layers_per_block[-1],
198
- cross_attention_dim=cross_attention_dim[-1],
199
- num_attention_heads=num_attention_heads[-1],
200
- )
201
-
202
- # count how many layers upsample the images
203
- self.num_upsamplers = 0
204
-
205
- # up
206
- reversed_block_out_channels = list(reversed(block_out_channels))
207
- reversed_num_attention_heads = list(reversed(num_attention_heads))
208
- reversed_layers_per_block = list(reversed(layers_per_block))
209
- reversed_cross_attention_dim = list(reversed(cross_attention_dim))
210
- reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
211
-
212
- output_channel = reversed_block_out_channels[0]
213
- for i, up_block_type in enumerate(up_block_types):
214
- is_final_block = i == len(block_out_channels) - 1
215
-
216
- prev_output_channel = output_channel
217
- output_channel = reversed_block_out_channels[i]
218
- input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
219
-
220
- # add upsample block for all BUT final layer
221
- if not is_final_block:
222
- add_upsample = True
223
- self.num_upsamplers += 1
224
- else:
225
- add_upsample = False
226
-
227
- up_block = get_up_block(
228
- up_block_type,
229
- num_layers=reversed_layers_per_block[i] + 1,
230
- transformer_layers_per_block=reversed_transformer_layers_per_block[i],
231
- in_channels=input_channel,
232
- out_channels=output_channel,
233
- prev_output_channel=prev_output_channel,
234
- temb_channels=blocks_time_embed_dim,
235
- add_upsample=add_upsample,
236
- resnet_eps=1e-5,
237
- resolution_idx=i,
238
- cross_attention_dim=reversed_cross_attention_dim[i],
239
- num_attention_heads=reversed_num_attention_heads[i],
240
- resnet_act_fn="silu",
241
- )
242
- self.up_blocks.append(up_block)
243
- prev_output_channel = output_channel
244
-
245
- # out
246
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
247
- self.conv_act = nn.SiLU()
248
-
249
- self.conv_out = nn.Conv2d(
250
- block_out_channels[0],
251
- out_channels,
252
- kernel_size=3,
253
- padding=1,
254
- )
255
-
256
- @property
257
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
258
- r"""
259
- Returns:
260
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
261
- indexed by its weight name.
262
- """
263
- # set recursively
264
- processors = {}
265
-
266
- def fn_recursive_add_processors(
267
- name: str,
268
- module: torch.nn.Module,
269
- processors: Dict[str, AttentionProcessor],
270
- ):
271
- if hasattr(module, "get_processor"):
272
- processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
273
-
274
- for sub_name, child in module.named_children():
275
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
276
-
277
- return processors
278
-
279
- for name, module in self.named_children():
280
- fn_recursive_add_processors(name, module, processors)
281
-
282
- return processors
283
-
284
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
285
- r"""
286
- Sets the attention processor to use to compute attention.
287
-
288
- Parameters:
289
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
290
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
291
- for **all** `Attention` layers.
292
-
293
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
294
- processor. This is strongly recommended when setting trainable attention processors.
295
-
296
- """
297
- count = len(self.attn_processors.keys())
298
-
299
- if isinstance(processor, dict) and len(processor) != count:
300
- raise ValueError(
301
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
302
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
303
- )
304
-
305
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
306
- if hasattr(module, "set_processor"):
307
- if not isinstance(processor, dict):
308
- module.set_processor(processor)
309
- else:
310
- module.set_processor(processor.pop(f"{name}.processor"))
311
-
312
- for sub_name, child in module.named_children():
313
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
314
-
315
- for name, module in self.named_children():
316
- fn_recursive_attn_processor(name, module, processor)
317
-
318
- def set_default_attn_processor(self):
319
- """
320
- Disables custom attention processors and sets the default attention implementation.
321
- """
322
- if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
323
- processor = AttnProcessor()
324
- else:
325
- raise ValueError(
326
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
327
- )
328
-
329
- self.set_attn_processor(processor)
330
-
331
- def _set_gradient_checkpointing(self, module, value=False):
332
- if hasattr(module, "gradient_checkpointing"):
333
- module.gradient_checkpointing = value
334
-
335
- # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
336
- def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
337
- """
338
- Sets the attention processor to use [feed forward
339
- chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
340
-
341
- Parameters:
342
- chunk_size (`int`, *optional*):
343
- The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
344
- over each tensor of dim=`dim`.
345
- dim (`int`, *optional*, defaults to `0`):
346
- The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
347
- or dim=1 (sequence length).
348
- """
349
- if dim not in [0, 1]:
350
- raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
351
-
352
- # By default chunk size is 1
353
- chunk_size = chunk_size or 1
354
-
355
- def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
356
- if hasattr(module, "set_chunk_feed_forward"):
357
- module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
358
-
359
- for child in module.children():
360
- fn_recursive_feed_forward(child, chunk_size, dim)
361
-
362
- for module in self.children():
363
- fn_recursive_feed_forward(module, chunk_size, dim)
364
-
365
- def forward(
366
- self,
367
- sample: torch.FloatTensor,
368
- timestep: Union[torch.Tensor, float, int],
369
- encoder_hidden_states: torch.Tensor,
370
- added_time_ids: torch.Tensor,
371
- return_dict: bool = True,
372
- ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
373
- r"""
374
- The [`UNetSpatioTemporalConditionModel`] forward method.
375
-
376
- Args:
377
- sample (`torch.FloatTensor`):
378
- The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
379
- timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
380
- encoder_hidden_states (`torch.FloatTensor`):
381
- The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
382
- added_time_ids: (`torch.FloatTensor`):
383
- The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
384
- embeddings and added to the time embeddings.
385
- return_dict (`bool`, *optional*, defaults to `True`):
386
- Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead of a plain
387
- tuple.
388
- Returns:
389
- [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
390
- If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is returned, otherwise
391
- a `tuple` is returned where the first element is the sample tensor.
392
- """
393
- # 1. time
394
- timesteps = timestep
395
- if not torch.is_tensor(timesteps):
396
- # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
397
- # This would be a good case for the `match` statement (Python 3.10+)
398
- is_mps = sample.device.type == "mps"
399
- if isinstance(timestep, float):
400
- dtype = torch.float32 if is_mps else torch.float64
401
- else:
402
- dtype = torch.int32 if is_mps else torch.int64
403
- timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
404
- elif len(timesteps.shape) == 0:
405
- timesteps = timesteps[None].to(sample.device)
406
-
407
- # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
408
- batch_size, num_frames = sample.shape[:2]
409
- timesteps = timesteps.expand(batch_size)
410
-
411
- t_emb = self.time_proj(timesteps)
412
-
413
- # `Timesteps` does not contain any weights and will always return f32 tensors
414
- # but time_embedding might actually be running in fp16. so we need to cast here.
415
- # there might be better ways to encapsulate this.
416
- t_emb = t_emb.to(dtype=sample.dtype)
417
-
418
- emb = self.time_embedding(t_emb)
419
-
420
- time_embeds = self.add_time_proj(added_time_ids.flatten())
421
- time_embeds = time_embeds.reshape((batch_size, -1))
422
- time_embeds = time_embeds.to(emb.dtype)
423
- aug_emb = self.add_embedding(time_embeds)
424
- emb = emb + aug_emb
425
-
426
- # Flatten the batch and frames dimensions
427
- # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
428
- sample = sample.flatten(0, 1)
429
- # Repeat the embeddings num_video_frames times
430
- # emb: [batch, channels] -> [batch * frames, channels]
431
- emb = emb.repeat_interleave(num_frames, dim=0)
432
-
433
- # 2. pre-process
434
- sample = self.conv_in2(sample)
435
-
436
- image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
437
-
438
- down_block_res_samples = (sample,)
439
- for downsample_block in self.down_blocks:
440
- if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
441
- sample, res_samples = downsample_block(
442
- hidden_states=sample,
443
- temb=emb,
444
- encoder_hidden_states=encoder_hidden_states,
445
- image_only_indicator=image_only_indicator,
446
- )
447
- else:
448
- sample, res_samples = downsample_block(
449
- hidden_states=sample,
450
- temb=emb,
451
- image_only_indicator=image_only_indicator,
452
- )
453
-
454
- down_block_res_samples += res_samples
455
-
456
- # 4. mid
457
- sample = self.mid_block(
458
- hidden_states=sample,
459
- temb=emb,
460
- encoder_hidden_states=encoder_hidden_states,
461
- image_only_indicator=image_only_indicator,
462
- )
463
-
464
- # 5. up
465
- for i, upsample_block in enumerate(self.up_blocks):
466
- res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
467
- down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
468
-
469
- if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
470
- sample = upsample_block(
471
- hidden_states=sample,
472
- temb=emb,
473
- res_hidden_states_tuple=res_samples,
474
- encoder_hidden_states=encoder_hidden_states,
475
- image_only_indicator=image_only_indicator,
476
- )
477
- else:
478
- sample = upsample_block(
479
- hidden_states=sample,
480
- temb=emb,
481
- res_hidden_states_tuple=res_samples,
482
- image_only_indicator=image_only_indicator,
483
- )
484
-
485
- # 6. post-process
486
- sample = self.conv_norm_out(sample)
487
- sample = self.conv_act(sample)
488
- sample = self.conv_out(sample)
489
-
490
- # 7. Reshape back to original shape
491
- sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
492
-
493
- if not return_dict:
494
- return (sample,)
495
-
496
- return UNetSpatioTemporalConditionOutput(sample=sample)