Upload 42 files
- README.md +1 -1
- animatelcm/models/attention.py +296 -0
- animatelcm/models/embeddings.py +213 -0
- animatelcm/models/motion_module.py +337 -0
- animatelcm/models/resnet.py +313 -0
- animatelcm/models/unet.py +568 -0
- animatelcm/models/unet_blocks.py +904 -0
- animatelcm/pipelines/pipeline_animation.py +456 -0
- animatelcm/scheduler/lcm_scheduler.py +722 -0
- animatelcm/utils/convert_from_ckpt.py +951 -0
- animatelcm/utils/convert_lora_safetensor_to_diffusers.py +152 -0
- animatelcm/utils/lcm_utils.py +237 -0
- animatelcm/utils/util.py +153 -0
- app.py +392 -0
- models/.DS_Store +0 -0
- models/DreamBooth_LoRA/cartoon2d.safetensors +3 -0
- models/DreamBooth_LoRA/cartoon3d.safetensors +3 -0
- models/DreamBooth_LoRA/realistic1.safetensors +3 -0
- models/DreamBooth_LoRA/realistic2.safetensors +3 -0
- models/LCM_LoRA/Put LCMLoRA checkpoints here.txt +0 -0
- models/LCM_LoRA/sd15_t2v_beta_lora.safetensors +3 -0
- models/Motion_Module/Put motion module checkpoints here.txt +0 -0
- models/Motion_Module/sd15_t2v_beta_motion.ckpt +3 -0
- models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/.gitattributes +35 -0
- models/StableDiffusion/stable-diffusion-v1-5/README.md +207 -0
- models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json +20 -0
- models/StableDiffusion/stable-diffusion-v1-5/model_index.json +32 -0
- models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json +175 -0
- models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json +13 -0
- models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json +25 -0
- models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors +3 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json +24 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json +34 -0
- models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json +0 -0
- models/StableDiffusion/stable-diffusion-v1-5/unet/config.json +36 -0
- models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin +3 -0
- models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml +70 -0
- models/StableDiffusion/stable-diffusion-v1-5/vae/config.json +29 -0
- models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin +3 -0
- requirements.txt +15 -0
    	
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🦀
 colorFrom: red
 colorTo: blue
 sdk: gradio
-sdk_version:
+sdk_version: 3.48.0
 app_file: app.py
 pinned: false
 ---
    	
animatelcm/models/attention.py
ADDED
@@ -0,0 +1,296 @@
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm

from einops import rearrange, repeat


@dataclass
class Transformer3DModelOutput(BaseOutput):
    sample: torch.FloatTensor


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class Transformer3DModel(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 16,
        attention_head_dim: int = 88,
        in_channels: Optional[int] = None,
        num_layers: int = 1,
        dropout: float = 0.0,
        norm_num_groups: int = 32,
        cross_attention_dim: Optional[int] = None,
        attention_bias: bool = False,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        use_linear_projection: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.use_linear_projection = use_linear_projection
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

        # Define input layers
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        if use_linear_projection:
            self.proj_in = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        # Define transformers blocks
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout=dropout,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    num_embeds_ada_norm=num_embeds_ada_norm,
                    attention_bias=attention_bias,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
                for d in range(num_layers)
            ]
        )

        # 4. Define output layers
        if use_linear_projection:
            self.proj_out = nn.Linear(in_channels, inner_dim)
        else:
            self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        # Input
        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
        encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        if not self.use_linear_projection:
            hidden_states = self.proj_in(hidden_states)
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        else:
            inner_dim = hidden_states.shape[1]
            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
            hidden_states = self.proj_in(hidden_states)

        # Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                timestep=timestep,
                video_length=video_length
            )

        # Output
        if not self.use_linear_projection:
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )
            hidden_states = self.proj_out(hidden_states)
        else:
            hidden_states = self.proj_out(hidden_states)
            hidden_states = (
                hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
            )

        output = hidden_states + residual

        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
        if not return_dict:
            return (output,)

        return Transformer3DModelOutput(sample=output)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        upcast_attention: bool = False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_ada_layer_norm = num_embeds_ada_norm is not None
        self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
        self.unet_use_temporal_attention = unet_use_temporal_attention

        # SC-Attn
        assert unet_use_cross_frame_attention is not None
        if unet_use_cross_frame_attention:
            self.attn1 = SparseCausalAttention2D(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn1 = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

        # Cross-Attn
        if cross_attention_dim is not None:
            self.attn2 = CrossAttention(
                query_dim=dim,
                cross_attention_dim=cross_attention_dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
        else:
            self.attn2 = None

        if cross_attention_dim is not None:
            self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
        else:
            self.norm2 = None

        # Feed-forward
        self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
        self.norm3 = nn.LayerNorm(dim)

        # Temp-Attn
        assert unet_use_temporal_attention is not None
        if unet_use_temporal_attention:
            self.attn_temp = CrossAttention(
                query_dim=dim,
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
            )
            nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
            self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)

    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            raise ModuleNotFoundError(
                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                " xformers",
                name="xformers",
            )
        elif not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
                " available for GPU "
            )
        else:
            try:
                # Make sure we can run the memory efficient attention
                _ = xformers.ops.memory_efficient_attention(
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                    torch.randn((1, 2, 40), device="cuda"),
                )
            except Exception as e:
                raise e
            self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            if self.attn2 is not None:
                self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
            # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers

    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
        # SparseCausal-Attention
        norm_hidden_states = (
            self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
        )

        # if self.only_cross_attention:
        #     hidden_states = (
        #         self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
        #     )
        # else:
        #     hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states

        # pdb.set_trace()
        if self.unet_use_cross_frame_attention:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
        else:
            hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states

        if self.attn2 is not None:
            # Cross-Attention
            norm_hidden_states = (
                self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
            )
            hidden_states = (
                self.attn2(
                    norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
                )
                + hidden_states
            )

        # Feed-forward
        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states

        # Temporal-Attention
        if self.unet_use_temporal_attention:
            d = hidden_states.shape[1]
            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
            norm_hidden_states = (
                self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
            )
            hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
            hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
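For orientation only (this is not part of the commit), the sketch below shows how the Transformer3DModel above can be exercised on a 5-D video latent. The channel count, head configuration, and tensor shapes are illustrative, and both the cross-frame and temporal flags are set to False so only the standard spatial self-attention and cross-attention paths run; it also assumes the older diffusers API that this file imports (diffusers.modeling_utils).

# Minimal usage sketch (assumed shapes and hyper-parameters, not from the commit)
import torch
from animatelcm.models.attention import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,                 # inner_dim = 8 * 40 = 320
    in_channels=320,                       # must match the latent channel dim
    cross_attention_dim=768,               # CLIP text-embedding width for SD 1.5
    unet_use_cross_frame_attention=False,  # the asserts require a non-None value
    unet_use_temporal_attention=False,
)

video_latent = torch.randn(1, 320, 16, 32, 32)  # (b, c, f, h, w)
text_emb = torch.randn(1, 77, 768)              # (b, tokens, cross_attention_dim)
out = model(video_latent, encoder_hidden_states=text_emb).sample
print(out.shape)  # torch.Size([1, 320, 16, 32, 32])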
    	
animatelcm/models/embeddings.py
ADDED
@@ -0,0 +1,213 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import numpy as np
import torch
from torch import nn


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
    embeddings. :return: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module


class TimestepEmbedding(nn.Module):
    def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None, time_cond_proj_dim=None):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = None
        if act_fn == "silu":
            self.act = nn.SiLU()
        elif act_fn == "mish":
            self.act = nn.Mish()

        if time_cond_proj_dim is not None:
            self.cond_proj = zero_module(nn.Linear(time_cond_proj_dim, in_channels, bias=False))
        else:
            self.cond_proj = None

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)
        return sample


class Timesteps(nn.Module):
    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
        )
        return t_emb


class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)

            self.weight = self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L x D
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb
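For context (again not part of the commit), here is a small sketch of how the timestep modules above fit together: Timesteps produces sinusoidal features, TimestepEmbedding projects them, and when time_cond_proj_dim is set, the zero-initialized cond_proj branch can add an extra conditioning vector. All widths below are illustrative assumptions.

# Minimal usage sketch (assumed widths, not from the commit)
import torch
from animatelcm.models.embeddings import Timesteps, TimestepEmbedding

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(
    in_channels=320,
    time_embed_dim=1280,
    time_cond_proj_dim=256,   # illustrative width for the extra condition
)

timesteps = torch.tensor([999, 500])   # one diffusion timestep per batch element
t_emb = time_proj(timesteps)           # (2, 320) sinusoidal features
condition = torch.randn(2, 256)        # extra conditioning vector (illustrative)
emb = time_embedding(t_emb, condition=condition)  # cond_proj starts as a no-op (zero-initialized)
print(emb.shape)  # torch.Size([2, 1280])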
    	
animatelcm/models/motion_module.py
ADDED
@@ -0,0 +1,337 @@
| 1 | 
            +
            from dataclasses import dataclass
         | 
| 2 | 
            +
            from typing import List, Optional, Tuple, Union
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import torch.nn.functional as F
         | 
| 6 | 
            +
            from torch import nn
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         | 
| 9 | 
            +
            from diffusers.modeling_utils import ModelMixin
         | 
| 10 | 
            +
            from diffusers.utils import BaseOutput
         | 
| 11 | 
            +
            from diffusers.utils.import_utils import is_xformers_available
         | 
| 12 | 
            +
            from diffusers.models.attention import CrossAttention, FeedForward
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from einops import rearrange, repeat
         | 
| 15 | 
            +
            import math
         | 
| 16 | 
            +
             | 
| 17 | 
            +
             | 
| 18 | 
            +
            def zero_module(module):
         | 
| 19 | 
            +
                for p in module.parameters():
         | 
| 20 | 
            +
                    p.detach().zero_()
         | 
| 21 | 
            +
                return module
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
            +
            @dataclass
         | 
| 25 | 
            +
            class TemporalTransformer3DModelOutput(BaseOutput):
         | 
| 26 | 
            +
                sample: torch.FloatTensor
         | 
| 27 | 
            +
             | 
| 28 | 
            +
             | 
| 29 | 
            +
            if is_xformers_available():
         | 
| 30 | 
            +
                import xformers
         | 
| 31 | 
            +
                import xformers.ops
         | 
| 32 | 
            +
            else:
         | 
| 33 | 
            +
                xformers = None
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            def get_motion_module(
         | 
| 37 | 
            +
                in_channels,
         | 
| 38 | 
            +
                motion_module_type: str,
         | 
| 39 | 
            +
                motion_module_kwargs: dict
         | 
| 40 | 
            +
            ):
         | 
| 41 | 
            +
                if motion_module_type == "Vanilla":
         | 
| 42 | 
            +
                    return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
         | 
| 43 | 
            +
                else:
         | 
| 44 | 
            +
                    raise ValueError
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             | 
class VanillaTemporalModule(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads=8,
        num_transformer_block=2,
        attention_block_types=("Temporal_Self", "Temporal_Self"),
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        temporal_attention_dim_div=1,
        zero_initialize=True,
    ):
        super().__init__()

        self.temporal_transformer = TemporalTransformer3DModel(
            in_channels=in_channels,
            num_attention_heads=num_attention_heads,
            attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
            num_layers=num_transformer_block,
            attention_block_types=attention_block_types,
            cross_frame_attention_mode=cross_frame_attention_mode,
            temporal_position_encoding=temporal_position_encoding,
        )

        if zero_initialize:
            self.temporal_transformer.proj_out = zero_module(
                self.temporal_transformer.proj_out)

    def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
        hidden_states = input_tensor
        hidden_states = self.temporal_transformer(
            hidden_states, encoder_hidden_states, attention_mask)

        output = hidden_states
        return output


class TemporalTransformer3DModel(nn.Module):
    def __init__(
        self,
        in_channels,
        num_attention_heads,
        attention_head_dim,

        num_layers,
        attention_block_types=("Temporal_Self", "Temporal_Self", ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,

        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
    ):
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim

        self.norm = torch.nn.GroupNorm(
            num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
        self.proj_in = nn.Linear(in_channels, inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                TemporalTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    attention_block_types=attention_block_types,
                    dropout=dropout,
                    norm_num_groups=norm_num_groups,
                    cross_attention_dim=cross_attention_dim,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    upcast_attention=upcast_attention,
                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                )
                for d in range(num_layers)
            ]
        )
        self.proj_out = nn.Linear(inner_dim, in_channels)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
        assert hidden_states.dim(
        ) == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
        video_length = hidden_states.shape[2]
        hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")

        batch, channel, height, weight = hidden_states.shape
        residual = hidden_states

        hidden_states = self.norm(hidden_states)
        inner_dim = hidden_states.shape[1]
        hidden_states = hidden_states.permute(
            0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
        hidden_states = self.proj_in(hidden_states)

        # Transformer Blocks
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)

        # output
        hidden_states = self.proj_out(hidden_states)
        hidden_states = hidden_states.reshape(
            batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()

        output = hidden_states + residual
        output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)

        return output
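# --- Illustrative shape trace (not part of the uploaded file) ---
# Inside forward() above, frames are folded into the batch and space is
# flattened into tokens; the two-step rearrange/permute is equivalent to the
# single permute below. Sizes are examples only.
def _example_temporal_transformer_token_layout():
    import torch
    from einops import rearrange
    b, c, f, h, w = 2, 320, 16, 8, 8
    x = torch.randn(b, c, f, h, w)
    two_step = rearrange(x, "b c f h w -> (b f) c h w")
    two_step = two_step.permute(0, 2, 3, 1).reshape(b * f, h * w, c)
    one_step = x.permute(0, 2, 3, 4, 1).reshape(b * f, h * w, c)
    assert torch.equal(two_step, one_step)  # (batch*frames, tokens, channels)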
class TemporalTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_attention_heads,
        attention_head_dim,
        attention_block_types=("Temporal_Self", "Temporal_Self", ),
        dropout=0.0,
        norm_num_groups=32,
        cross_attention_dim=768,
        activation_fn="geglu",
        attention_bias=False,
        upcast_attention=False,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
    ):
        super().__init__()

        attention_blocks = []
        norms = []

        for block_name in attention_block_types:
            attention_blocks.append(
                VersatileAttention(
                    attention_mode=block_name.split("_")[0],
                    cross_attention_dim=cross_attention_dim if block_name.endswith(
                        "_Cross") else None,

                    query_dim=dim,
                    heads=num_attention_heads,
                    dim_head=attention_head_dim,
                    dropout=dropout,
                    bias=attention_bias,
                    upcast_attention=upcast_attention,

                    cross_frame_attention_mode=cross_frame_attention_mode,
                    temporal_position_encoding=temporal_position_encoding,
                )
            )
            norms.append(nn.LayerNorm(dim))

        self.attention_blocks = nn.ModuleList(attention_blocks)
        self.norms = nn.ModuleList(norms)

        self.ff = FeedForward(dim, dropout=dropout,
                              activation_fn=activation_fn)
        self.ff_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        for attention_block, norm in zip(self.attention_blocks, self.norms):
            norm_hidden_states = norm(hidden_states)
            hidden_states = attention_block(
                norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
                video_length=video_length,
            ) + hidden_states

        hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states

        output = hidden_states
        return output


class PositionalEncoding(nn.Module):
    def __init__(
        self,
        d_model,
        dropout=0.,
    ):
        super().__init__()

        max_length = 64
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2)
                             * (-math.log(10000.0) / d_model))
        pe = torch.zeros(1, max_length, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pos_encoding', pe)

    def forward(self, x):
        x = x + self.pos_encoding[:, :x.size(1)]
        return self.dropout(x)
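# --- Illustrative check (not part of the uploaded file) ---
# The sinusoidal table above is precomputed for at most max_length = 64 frames;
# clips longer than that would index past the buffer. d_model and the frame
# count below are example values.
def _example_positional_encoding_check():
    import torch
    pe = PositionalEncoding(d_model=320)
    x = torch.zeros(4, 16, 320)  # (batch*spatial, frames, channels), 16 <= 64
    assert pe(x).shape == x.shape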
class VersatileAttention(CrossAttention):
    def __init__(
        self,
        attention_mode=None,
        cross_frame_attention_mode=None,
        temporal_position_encoding=False,
        *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        assert attention_mode == "Temporal"

        self.attention_mode = attention_mode
        self.is_cross_attention = kwargs["cross_attention_dim"] is not None

        self.pos_encoder = PositionalEncoding(
            kwargs["query_dim"],
            dropout=0.,
        ) if (temporal_position_encoding and attention_mode == "Temporal") else None

    def extra_repr(self):
        return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
        batch_size, sequence_length, _ = hidden_states.shape

        if self.attention_mode == "Temporal":
            d = hidden_states.shape[1]
            hidden_states = rearrange(
                hidden_states, "(b f) d c -> (b d) f c", f=video_length)

            if self.pos_encoder is not None:
                hidden_states = self.pos_encoder(hidden_states)

            encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c",
                                           d=d) if encoder_hidden_states is not None else encoder_hidden_states
        else:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states

        if self.group_norm is not None:
            hidden_states = self.group_norm(
                hidden_states.transpose(1, 2)).transpose(1, 2)

        query = self.to_q(hidden_states)
        dim = query.shape[-1]
        query = self.reshape_heads_to_batch_dim(query)

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = self.to_k(encoder_hidden_states)
        value = self.to_v(encoder_hidden_states)

        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        if attention_mask is not None:
            if attention_mask.shape[-1] != query.shape[1]:
                target_length = query.shape[1]
                attention_mask = F.pad(
                    attention_mask, (0, target_length), value=0.0)
                attention_mask = attention_mask.repeat_interleave(
                    self.heads, dim=0)

        if self._use_memory_efficient_attention_xformers:
            hidden_states = self._memory_efficient_attention_xformers(
                query, key, value, attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                hidden_states = self._attention(
                    query, key, value, attention_mask)
            else:
                hidden_states = self._sliced_attention(
                    query, key, value, sequence_length, dim, attention_mask)

        # linear proj
        hidden_states = self.to_out[0](hidden_states)

        # dropout
        hidden_states = self.to_out[1](hidden_states)

        if self.attention_mode == "Temporal":
            hidden_states = rearrange(
                hidden_states, "(b d) f c -> (b f) d c", d=d)

        return hidden_states
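# --- Illustrative round trip (not part of the uploaded file) ---
# The temporal trick above: spatial positions are folded into the batch so that
# self-attention mixes only the f frames at each location, then the layout is
# restored. Sizes are examples.
def _example_temporal_attention_layout():
    import torch
    from einops import rearrange
    b, f, d, c = 2, 16, 64, 320  # batch, frames, spatial tokens, channels
    x = torch.randn(b * f, d, c)  # layout entering VersatileAttention.forward
    x_t = rearrange(x, "(b f) d c -> (b d) f c", f=f)
    x_back = rearrange(x_t, "(b d) f c -> (b f) d c", d=d)
    assert torch.equal(x, x_back)  # lossless round trip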
    	
        animatelcm/models/resnet.py
    ADDED
    
@@ -0,0 +1,313 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional

from einops import rearrange


class InflatedConv3d(nn.Conv2d):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x


class InflatedGroupNorm(nn.GroupNorm):
    def forward(self, x):
        video_length = x.shape[2]

        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)

        return x
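# --- Illustrative shape check (not part of the uploaded file) ---
# Both Inflated* wrappers reuse pretrained 2-D weights on video by folding the
# frame axis into the batch, applying the 2-D op per frame, and unfolding again.
# Channel / frame counts below are examples.
def _example_inflated_conv_shapes():
    import torch
    conv = InflatedConv3d(4, 320, kernel_size=3, padding=1)  # weights are plain nn.Conv2d weights
    video = torch.randn(1, 4, 16, 64, 64)  # (b, c, f, h, w)
    assert conv(video).shape == (1, 320, 16, 64, 64)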
class Upsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name

        conv = None
        if use_conv_transpose:
            raise NotImplementedError
        elif use_conv:
            self.conv = InflatedConv3d(
                self.channels, self.out_channels, 3, padding=1)

    def forward(self, hidden_states, output_size=None):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv_transpose:
            raise NotImplementedError

        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
        dtype = hidden_states.dtype
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(torch.float32)

        # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
        if hidden_states.shape[0] >= 64:
            hidden_states = hidden_states.contiguous()

        # if `output_size` is passed we force the interpolation output
        # size and do not make use of `scale_factor=2`
        if output_size is None:
            hidden_states = F.interpolate(hidden_states, scale_factor=[
                                          1.0, 2.0, 2.0], mode="nearest")
        else:
            hidden_states = F.interpolate(
                hidden_states, size=output_size, mode="nearest")

        # If the input is bfloat16, we cast back to bfloat16
        if dtype == torch.bfloat16:
            hidden_states = hidden_states.to(dtype)

        # if self.use_conv:
        #     if self.name == "conv":
        #         hidden_states = self.conv(hidden_states)
        #     else:
        #         hidden_states = self.Conv2d_0(hidden_states)
        hidden_states = self.conv(hidden_states)

        return hidden_states
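# --- Illustrative note (not part of the uploaded file) ---
# scale_factor=[1.0, 2.0, 2.0] above doubles only the spatial dimensions of the
# 5-D latent; the frame axis is left untouched. Sizes are examples.
def _example_spatial_only_upsample():
    import torch
    import torch.nn.functional as F
    x = torch.randn(1, 320, 16, 32, 32)  # (b, c, f, h, w)
    y = F.interpolate(x, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
    assert y.shape == (1, 320, 16, 64, 64)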
class Downsample3D(nn.Module):
    def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = InflatedConv3d(
                self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            raise NotImplementedError

    def forward(self, hidden_states):
        assert hidden_states.shape[1] == self.channels
        if self.use_conv and self.padding == 0:
            raise NotImplementedError

        assert hidden_states.shape[1] == self.channels
        hidden_states = self.conv(hidden_states)

        return hidden_states


class ResnetBlock3D(nn.Module):
    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout=0.0,
        temb_channels=512,
        groups=32,
        groups_out=None,
        pre_norm=True,
        eps=1e-6,
        non_linearity="swish",
        time_embedding_norm="default",
        output_scale_factor=1.0,
        use_in_shortcut=None,
        use_inflated_groupnorm=None,
        use_temporal_conv=False,
        use_temporal_mixer=False,
    ):
        super().__init__()
        self.pre_norm = pre_norm
        self.pre_norm = True
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.time_embedding_norm = time_embedding_norm
        self.output_scale_factor = output_scale_factor
        self.use_temporal_mixer = use_temporal_mixer
        if use_temporal_mixer:
            self.temporal_mixer = AlphaBlender(0.3, "learned", None)

        if groups_out is None:
            groups_out = groups

        assert use_inflated_groupnorm != None
        if use_inflated_groupnorm:
            self.norm1 = InflatedGroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        else:
            self.norm1 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

        if use_temporal_conv:
            self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=(
                3, 1, 1), stride=1, padding=(1, 0, 0))
        else:
            self.conv1 = InflatedConv3d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
                time_emb_proj_out_channels = out_channels
            elif self.time_embedding_norm == "scale_shift":
                time_emb_proj_out_channels = out_channels * 2
            else:
                raise ValueError(
                    f"unknown time_embedding_norm : {self.time_embedding_norm} ")

            self.time_emb_proj = torch.nn.Linear(
                temb_channels, time_emb_proj_out_channels)
        else:
            self.time_emb_proj = None

        if use_inflated_groupnorm:
            self.norm2 = InflatedGroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        else:
            self.norm2 = torch.nn.GroupNorm(
                num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)

        self.dropout = torch.nn.Dropout(dropout)
        if use_temporal_conv:
            self.conv2 = nn.Conv3d(in_channels, out_channels, kernel_size=(
                3, 1, 1), stride=1, padding=(1, 0, 0))
        else:
            self.conv2 = InflatedConv3d(
                out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if non_linearity == "swish":
            self.nonlinearity = lambda x: F.silu(x)
        elif non_linearity == "mish":
            self.nonlinearity = Mish()
        elif non_linearity == "silu":
            self.nonlinearity = nn.SiLU()

        self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedConv3d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, input_tensor, temb):
        if self.use_temporal_mixer:
            residual = input_tensor

        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if temb is not None:
            temb = self.time_emb_proj(self.nonlinearity(temb))[
                :, :, None, None, None]

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb

        hidden_states = self.norm2(hidden_states)

        if temb is not None and self.time_embedding_norm == "scale_shift":
            scale, shift = torch.chunk(temb, 2, dim=1)
            hidden_states = hidden_states * (1 + scale) + shift

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / \
            self.output_scale_factor

        if self.use_temporal_mixer:
            output_tensor = self.temporal_mixer(residual, output_tensor, None)
            # return residual + 0.0 * self.temporal_mixer(residual, output_tensor, None)
        return output_tensor
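# --- Illustrative restatement (not part of the uploaded file) ---
# The two time-embedding paths above: "default" adds the projected embedding
# before norm2, while "scale_shift" splits the (2 * out_channels)-wide
# projection and applies h_norm * (1 + scale) + shift after norm2.
# Shapes below are examples.
def _example_scale_shift_path():
    import torch
    h_norm = torch.randn(1, 320, 16, 32, 32)   # normalized hidden states
    temb = torch.randn(1, 640, 1, 1, 1)        # time_emb_proj output, 2 * 320 channels
    scale, shift = torch.chunk(temb, 2, dim=1)
    out = h_norm * (1 + scale) + shift
    assert out.shape == h_norm.shape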
class Mish(torch.nn.Module):
    def forward(self, hidden_states):
        return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))


class AlphaBlender(nn.Module):
    strategies = ["learned", "fixed", "learned_with_images"]

    def __init__(
        self,
        alpha: float,
        merge_strategy: str = "learned_with_images",
        rearrange_pattern: str = "b t -> (b t) 1 1",
    ):
        super().__init__()
        self.merge_strategy = merge_strategy
        self.rearrange_pattern = rearrange_pattern
        self.scaler = 10.

        assert (
            merge_strategy in self.strategies
        ), f"merge_strategy needs to be in {self.strategies}"

        if self.merge_strategy == "fixed":
            self.register_buffer("mix_factor", torch.Tensor([alpha]))
        elif (
            self.merge_strategy == "learned"
            or self.merge_strategy == "learned_with_images"
        ):
            self.register_parameter(
                "mix_factor", torch.nn.Parameter(torch.Tensor([alpha]))
            )
        else:
            raise ValueError(f"unknown merge strategy {self.merge_strategy}")

    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
        if self.merge_strategy == "fixed":
            alpha = self.mix_factor
        elif self.merge_strategy == "learned":
            alpha = torch.sigmoid(self.mix_factor*self.scaler)
        elif self.merge_strategy == "learned_with_images":
            assert image_only_indicator is not None, "need image_only_indicator ..."
            alpha = torch.where(
                image_only_indicator.bool(),
                torch.ones(1, 1, device=image_only_indicator.device),
                rearrange(torch.sigmoid(self.mix_factor), "... -> ... 1"),
            )
            alpha = rearrange(alpha, self.rearrange_pattern)
        else:
            raise NotImplementedError
        return alpha

    def forward(
        self,
        x_spatial: torch.Tensor,
        x_temporal: torch.Tensor,
        image_only_indicator: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        alpha = self.get_alpha(image_only_indicator)
        x = (
            alpha.to(x_spatial.dtype) * x_spatial
            + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
        )
        return x
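# --- Illustrative usage sketch (not part of the uploaded file) ---
# AlphaBlender mixes two branches with a single weight:
#     alpha * x_spatial + (1 - alpha) * x_temporal
# where the "learned" strategy squashes mix_factor * 10 through a sigmoid.
# Constructed here exactly as in ResnetBlock3D; tensor sizes are examples.
def _example_alpha_blender():
    import torch
    blender = AlphaBlender(0.3, "learned", None)
    x_spatial = torch.randn(1, 320, 16, 32, 32)   # skip branch
    x_temporal = torch.randn(1, 320, 16, 32, 32)  # temporally processed branch
    mixed = blender(x_spatial, x_temporal, image_only_indicator=None)
    assert mixed.shape == x_spatial.shape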
    	
        animatelcm/models/unet.py
    ADDED
    
@@ -0,0 +1,568 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import os
import json

import torch
import torch.nn as nn
import torch.utils.checkpoint

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput, logging
from animatelcm.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
    CrossAttnDownBlock3D,
    CrossAttnUpBlock3D,
    DownBlock3D,
    UNetMidBlock3DCrossAttn,
    UpBlock3D,
    get_down_block,
    get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm
# from .adapter import Adapter, PixelAdapter # Not ready
from einops import repeat


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class UNet3DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet3DConditionModel(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 4,
        out_channels: int = 4,
        center_input_sample: bool = False,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        mid_block_type: str = "UNetMidBlock3DCrossAttn",
        up_block_types: Tuple[str] = (
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: int = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        attention_head_dim: Union[int, Tuple[int]] = 8,
        dual_cross_attention: bool = False,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",

        use_inflated_groupnorm=False,

        # Additional
        use_motion_module=False,
        use_motion_resnet=False,
        motion_module_resolutions=(1, 2, 4, 8),
        motion_module_mid_block=False,
        motion_module_decoder_only=False,
        motion_module_type=None,
        motion_module_kwargs={},
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        time_cond_proj_dim=None, # not ready
        use_img_encoder=False,
        use_pixel_encoder=False,
    ):
        super().__init__()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.sample_size = sample_size
         | 
| 100 | 
            +
                    time_embed_dim = block_out_channels[0] * 4
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    self.img_encoder = None if use_img_encoder else None # not ready
         | 
| 103 | 
            +
                    self.pixel_encoder = None if use_pixel_encoder else None # not ready
         | 
| 104 | 
            +
                    
         | 
| 105 | 
            +
                    
         | 
| 106 | 
            +
                    # input
         | 
| 107 | 
            +
                    self.conv_in = InflatedConv3d(
         | 
| 108 | 
            +
                        in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    # time
         | 
| 111 | 
            +
                    self.time_proj = Timesteps(
         | 
| 112 | 
            +
                        block_out_channels[0], flip_sin_to_cos, freq_shift)
         | 
| 113 | 
            +
                    timestep_input_dim = block_out_channels[0]
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    self.time_embedding = TimestepEmbedding(
         | 
| 116 | 
            +
                        timestep_input_dim, time_embed_dim, time_cond_proj_dim=time_cond_proj_dim)
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    # class embedding
         | 
| 119 | 
            +
                    if class_embed_type is None and num_class_embeds is not None:
         | 
| 120 | 
            +
                        self.class_embedding = nn.Embedding(
         | 
| 121 | 
            +
                            num_class_embeds, time_embed_dim)
         | 
| 122 | 
            +
                    elif class_embed_type == "timestep":
         | 
| 123 | 
            +
                        self.class_embedding = TimestepEmbedding(
         | 
| 124 | 
            +
                            timestep_input_dim, time_embed_dim)
         | 
| 125 | 
            +
                    elif class_embed_type == "identity":
         | 
| 126 | 
            +
                        self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
         | 
| 127 | 
            +
                    else:
         | 
| 128 | 
            +
                        self.class_embedding = None
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    self.down_blocks = nn.ModuleList([])
         | 
| 131 | 
            +
                    self.mid_block = None
         | 
| 132 | 
            +
                    self.up_blocks = nn.ModuleList([])
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    if isinstance(only_cross_attention, bool):
         | 
| 135 | 
            +
                        only_cross_attention = [
         | 
| 136 | 
            +
                            only_cross_attention] * len(down_block_types)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    if isinstance(attention_head_dim, int):
         | 
| 139 | 
            +
                        attention_head_dim = (attention_head_dim,) * len(down_block_types)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    # down
         | 
| 142 | 
            +
                    output_channel = block_out_channels[0]
         | 
| 143 | 
            +
                    for i, down_block_type in enumerate(down_block_types):
         | 
| 144 | 
            +
                        res = 2 ** i
         | 
| 145 | 
            +
                        input_channel = output_channel
         | 
| 146 | 
            +
                        output_channel = block_out_channels[i]
         | 
| 147 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                        down_block = get_down_block(
         | 
| 150 | 
            +
                            down_block_type,
         | 
| 151 | 
            +
                            num_layers=layers_per_block,
         | 
| 152 | 
            +
                            in_channels=input_channel,
         | 
| 153 | 
            +
                            out_channels=output_channel,
         | 
| 154 | 
            +
                            temb_channels=time_embed_dim,
         | 
| 155 | 
            +
                            add_downsample=not is_final_block,
         | 
| 156 | 
            +
                            resnet_eps=norm_eps,
         | 
| 157 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 158 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 159 | 
            +
                            cross_attention_dim=cross_attention_dim,
         | 
| 160 | 
            +
                            attn_num_head_channels=attention_head_dim[i],
         | 
| 161 | 
            +
                            downsample_padding=downsample_padding,
         | 
| 162 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 163 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 164 | 
            +
                            only_cross_attention=only_cross_attention[i],
         | 
| 165 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 166 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
         | 
| 169 | 
            +
                            unet_use_temporal_attention=unet_use_temporal_attention,
         | 
| 170 | 
            +
                            use_inflated_groupnorm=use_inflated_groupnorm,
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                            use_motion_module=use_motion_module and (
         | 
| 173 | 
            +
                                res in motion_module_resolutions) and (not motion_module_decoder_only),
         | 
| 174 | 
            +
                            use_motion_resnet=use_motion_resnet and (
         | 
| 175 | 
            +
                                res in motion_module_resolutions) and (not motion_module_decoder_only),
         | 
| 176 | 
            +
                            motion_module_type=motion_module_type,
         | 
| 177 | 
            +
                            motion_module_kwargs=motion_module_kwargs,
         | 
| 178 | 
            +
                        )
         | 
| 179 | 
            +
                        self.down_blocks.append(down_block)
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    # mid
         | 
| 182 | 
            +
                    if mid_block_type == "UNetMidBlock3DCrossAttn":
         | 
| 183 | 
            +
                        self.mid_block = UNetMidBlock3DCrossAttn(
         | 
| 184 | 
            +
                            in_channels=block_out_channels[-1],
         | 
| 185 | 
            +
                            temb_channels=time_embed_dim,
         | 
| 186 | 
            +
                            resnet_eps=norm_eps,
         | 
| 187 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 188 | 
            +
                            output_scale_factor=mid_block_scale_factor,
         | 
| 189 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 190 | 
            +
                            cross_attention_dim=cross_attention_dim,
         | 
| 191 | 
            +
                            attn_num_head_channels=attention_head_dim[-1],
         | 
| 192 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 193 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 194 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 195 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
         | 
| 198 | 
            +
                            unet_use_temporal_attention=unet_use_temporal_attention,
         | 
| 199 | 
            +
                            use_inflated_groupnorm=use_inflated_groupnorm,
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                            use_motion_module=use_motion_module and motion_module_mid_block,
         | 
| 202 | 
            +
                            use_motion_resnet=use_motion_resnet and motion_module_mid_block,
         | 
| 203 | 
            +
             | 
| 204 | 
            +
                            motion_module_type=motion_module_type,
         | 
| 205 | 
            +
                            motion_module_kwargs=motion_module_kwargs,
         | 
| 206 | 
            +
                        )
         | 
| 207 | 
            +
                    else:
         | 
| 208 | 
            +
                        raise ValueError(f"unknown mid_block_type : {mid_block_type}")
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                    # count how many layers upsample the videos
         | 
| 211 | 
            +
                    self.num_upsamplers = 0
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    # up
         | 
| 214 | 
            +
                    reversed_block_out_channels = list(reversed(block_out_channels))
         | 
| 215 | 
            +
                    reversed_attention_head_dim = list(reversed(attention_head_dim))
         | 
| 216 | 
            +
                    only_cross_attention = list(reversed(only_cross_attention))
         | 
| 217 | 
            +
                    output_channel = reversed_block_out_channels[0]
         | 
| 218 | 
            +
                    for i, up_block_type in enumerate(up_block_types):
         | 
| 219 | 
            +
                        res = 2 ** (3 - i)
         | 
| 220 | 
            +
                        is_final_block = i == len(block_out_channels) - 1
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                        prev_output_channel = output_channel
         | 
| 223 | 
            +
                        output_channel = reversed_block_out_channels[i]
         | 
| 224 | 
            +
                        input_channel = reversed_block_out_channels[min(
         | 
| 225 | 
            +
                            i + 1, len(block_out_channels) - 1)]
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                        # add upsample block for all BUT final layer
         | 
| 228 | 
            +
                        if not is_final_block:
         | 
| 229 | 
            +
                            add_upsample = True
         | 
| 230 | 
            +
                            self.num_upsamplers += 1
         | 
| 231 | 
            +
                        else:
         | 
| 232 | 
            +
                            add_upsample = False
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                        up_block = get_up_block(
         | 
| 235 | 
            +
                            up_block_type,
         | 
| 236 | 
            +
                            num_layers=layers_per_block + 1,
         | 
| 237 | 
            +
                            in_channels=input_channel,
         | 
| 238 | 
            +
                            out_channels=output_channel,
         | 
| 239 | 
            +
                            prev_output_channel=prev_output_channel,
         | 
| 240 | 
            +
                            temb_channels=time_embed_dim,
         | 
| 241 | 
            +
                            add_upsample=add_upsample,
         | 
| 242 | 
            +
                            resnet_eps=norm_eps,
         | 
| 243 | 
            +
                            resnet_act_fn=act_fn,
         | 
| 244 | 
            +
                            resnet_groups=norm_num_groups,
         | 
| 245 | 
            +
                            cross_attention_dim=cross_attention_dim,
         | 
| 246 | 
            +
                            attn_num_head_channels=reversed_attention_head_dim[i],
         | 
| 247 | 
            +
                            dual_cross_attention=dual_cross_attention,
         | 
| 248 | 
            +
                            use_linear_projection=use_linear_projection,
         | 
| 249 | 
            +
                            only_cross_attention=only_cross_attention[i],
         | 
| 250 | 
            +
                            upcast_attention=upcast_attention,
         | 
| 251 | 
            +
                            resnet_time_scale_shift=resnet_time_scale_shift,
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
         | 
| 254 | 
            +
                            unet_use_temporal_attention=unet_use_temporal_attention,
         | 
| 255 | 
            +
                            use_inflated_groupnorm=use_inflated_groupnorm,
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                            use_motion_module=use_motion_module and (
         | 
| 258 | 
            +
                                res in motion_module_resolutions),
         | 
| 259 | 
            +
                            use_motion_resnet=use_motion_resnet and (
         | 
| 260 | 
            +
                                res in motion_module_resolutions),
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                            motion_module_type=motion_module_type,
         | 
| 263 | 
            +
                            motion_module_kwargs=motion_module_kwargs,
         | 
| 264 | 
            +
                        )
         | 
| 265 | 
            +
                        self.up_blocks.append(up_block)
         | 
| 266 | 
            +
                        prev_output_channel = output_channel
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                    # out
         | 
| 269 | 
            +
                    if use_inflated_groupnorm:
         | 
| 270 | 
            +
                        self.conv_norm_out = InflatedGroupNorm(
         | 
| 271 | 
            +
                            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
         | 
| 272 | 
            +
                    else:
         | 
| 273 | 
            +
                        self.conv_norm_out = nn.GroupNorm(
         | 
| 274 | 
            +
                            num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
         | 
| 275 | 
            +
                    self.conv_act = nn.SiLU()
         | 
| 276 | 
            +
                    self.conv_out = InflatedConv3d(
         | 
| 277 | 
            +
                        block_out_channels[0], out_channels, kernel_size=3, padding=1)
         | 
| 278 | 
            +
             | 
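    # Resolution bookkeeping for the temporal layers (from the loops above): down
    # block i runs at a nominal downsampling factor of 2**i and up block i at
    # 2**(3 - i); a motion module is attached only when that factor appears in
    # `motion_module_resolutions`, and encoder-side motion modules are skipped
    # entirely when `motion_module_decoder_only` is set.
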
    def set_attention_slice(self, slice_size):
        r"""
        Enable sliced attention computation.

        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
        in several steps. This is useful to save some memory in exchange for a small speed decrease.

        Args:
            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
                must be a multiple of `slice_size`.
        """
        sliceable_head_dims = []

        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                sliceable_head_dims.append(module.sliceable_head_dim)

            for child in module.children():
                fn_recursive_retrieve_slicable_dims(child)

        # retrieve number of attention layers
        for module in self.children():
            fn_recursive_retrieve_slicable_dims(module)

        num_slicable_layers = len(sliceable_head_dims)

        if slice_size == "auto":
            # half the attention head size is usually a good trade-off between
            # speed and memory
            slice_size = [dim // 2 for dim in sliceable_head_dims]
        elif slice_size == "max":
            # make smallest slice possible
            slice_size = num_slicable_layers * [1]

        slice_size = num_slicable_layers * \
            [slice_size] if not isinstance(slice_size, list) else slice_size

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

        for i in range(len(slice_size)):
            size = slice_size[i]
            dim = sliceable_head_dims[i]
            if size is not None and size > dim:
                raise ValueError(
                    f"size {size} has to be smaller or equal to {dim}.")

        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())

            for child in module.children():
                fn_recursive_set_attention_slice(child, slice_size)

        reversed_slice_size = list(reversed(slice_size))
        for module in self.children():
            fn_recursive_set_attention_slice(module, reversed_slice_size)

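    # Usage sketch (hypothetical `unet` instance of this class):
    #   unet.set_attention_slice("auto")  # halve every sliceable head dim
    #   unet.set_attention_slice("max")   # one slice at a time, lowest memory
    #   unet.set_attention_slice(2)       # same fixed slice size for every layer
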
    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
            module.gradient_checkpointing = value

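    # Gradient checkpointing is toggled per block through this hook; with the
    # diffusers ModelMixin base class it is typically reached via
    # `unet.enable_gradient_checkpointing()` (a sketch, assuming the installed
    # diffusers version exposes that helper).
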
    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        img_latent: torch.FloatTensor = None,
        control: torch.FloatTensor = None,
        time_cond: torch.FloatTensor = None,  # not ready
        class_labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[UNet3DConditionOutput, Tuple]:
        r"""
        Args:
            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.

        Returns:
            [`UNet3DConditionOutput`] or `tuple`:
            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.
        """

        if img_latent is not None and self.img_encoder is not None:
            f = sample.shape[2]
            img_latent = repeat(img_latent, "b c  h w  -> b c f h w",
                                f=f) if img_latent.ndim == 4 else img_latent
            img_features = self.img_encoder(img_latent)
        else:
            img_features = None

        if control is not None and self.pixel_encoder is not None:
            ctrl_features = self.pixel_encoder(control)
        else:
            # assert 0
            ctrl_features = None

        # By default samples have to be at least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info(
                "Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor(
                [timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=self.dtype)

        emb = self.time_embedding(t_emb)

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        # pre-process
        sample = self.conv_in(sample)

        # down

        down_block_res_samples = (sample,)

        img_feature_idx = 0

        for downsample_block in self.down_blocks:

            added_feature = img_features[img_feature_idx] if img_features is not None else torch.tensor(
                0.).to(sample.device, sample.dtype)
            added_feature = added_feature + \
                ctrl_features[img_feature_idx] if ctrl_features is not None else added_feature
            added_feature = None if added_feature.abs().mean() == 0 else added_feature
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    img_feature=added_feature
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, img_feature=added_feature)

            down_block_res_samples += res_samples
            img_feature_idx += 1
        # mid
        sample = self.mid_block(
            sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
        )

        # up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(
                upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states,
                )

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet3DConditionOutput(sample=sample)

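    # Shape sketch for a typical call (values are illustrative): `sample` is a 5-D
    # latent batch (batch, in_channels, num_frames, height // 8, width // 8),
    # `timestep` is a scalar or (batch,) tensor, and `encoder_hidden_states` is
    # (batch, sequence_length, cross_attention_dim); with the default config the
    # returned `.sample` has the same shape as the input latents.
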
    @classmethod
    def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None):
        if subfolder is not None:
            pretrained_model_path = os.path.join(
                pretrained_model_path, subfolder)
        print(
            f"Loading temporal unet's pretrained weights from {pretrained_model_path} ...")

        config_file = os.path.join(pretrained_model_path, 'config.json')
        if not os.path.isfile(config_file):
            raise RuntimeError(f"{config_file} does not exist")
        with open(config_file, "r") as f:
            config = json.load(f)
        config["_class_name"] = cls.__name__
        config["down_block_types"] = [
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D"
        ]
        config["up_block_types"] = [
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D"
        ]

        from diffusers.utils import WEIGHTS_NAME
        # tolerate unet_additional_kwargs=None (the declared default)
        model = cls.from_config(config, **(unet_additional_kwargs or {}))
        model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
        if not os.path.isfile(model_file):
            raise RuntimeError(f"{model_file} does not exist")
        state_dict = torch.load(model_file, map_location="cpu")
        if "state_dict" in state_dict.keys():
            state_dict = state_dict["state_dict"]
            state_dict = {k.replace("module.", ""): v for k,
                          v in state_dict.items()}
        m, u = model.load_state_dict(state_dict, strict=False)
        print("### load unet weights")
        print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")

        params = [p.numel() if "motion" in n else 0 for n,
                  p in model.named_parameters()]
        print(f"### Temporal Module Parameters: {sum(params) / 1e6} M")

        return model
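
A minimal loading sketch for the module above (not part of the uploaded file): it assumes the class is importable as `animatelcm.models.unet.UNet3DConditionModel`, that a diffusers-format stable-diffusion-v1-5 checkpoint sits at the hypothetical local path below, and that the `unet_additional_kwargs` values shown are representative; the kwargs actually used by the released motion-module checkpoints may differ.

import torch
from animatelcm.models.unet import UNet3DConditionModel

# Hypothetical local path to a diffusers-format SD-1.5 checkpoint.
pretrained_path = "models/StableDiffusion/stable-diffusion-v1-5"

unet = UNet3DConditionModel.from_pretrained_2d(
    pretrained_path,
    subfolder="unet",
    unet_additional_kwargs={  # assumed values, for illustration only
        "use_motion_module": True,
        "motion_module_resolutions": (1, 2, 4, 8),
        "motion_module_type": "Vanilla",
        "motion_module_kwargs": {},
    },
)

# One denoising call on random latents: (batch, channels, frames, H/8, W/8).
sample = torch.randn(1, 4, 16, 64, 64)
timestep = torch.tensor([999])
encoder_hidden_states = torch.randn(1, 77, 768)  # SD-1.5 text-encoder width
with torch.no_grad():
    noise_pred = unet(sample, timestep, encoder_hidden_states).sample
print(noise_pred.shape)  # torch.Size([1, 4, 16, 64, 64])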
    	
animatelcm/models/unet_blocks.py  ADDED
@@ -0,0 +1,904 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py

import torch
from torch import nn

from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D, AlphaBlender
from .motion_module import get_motion_module


def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",

    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,

    use_motion_module=None,
    use_motion_resnet=None,  # not used for current weight

    motion_module_type=None,
    motion_module_kwargs=None,
):
    down_block_type = down_block_type[7:] if down_block_type.startswith(
        "UNetRes") else down_block_type
    if down_block_type == "DownBlock3D":
        return DownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            resnet_time_scale_shift=resnet_time_scale_shift,

            use_inflated_groupnorm=use_inflated_groupnorm,

            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif down_block_type == "CrossAttnDownBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnDownBlock3D")
        return CrossAttnDownBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            temb_channels=temb_channels,
            add_downsample=add_downsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            downsample_padding=downsample_padding,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,

            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,

            use_motion_module=use_motion_module,
            use_motion_resnet=use_motion_resnet,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{down_block_type} does not exist.")

| 94 | 
            +
             | 
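
# Minimal usage sketch for the factory above. The helper name and every value
# below are illustrative assumptions (SD1.5-style channel sizes, 768-dim text
# embeddings), not taken from a shipped AnimateLCM config, and it presumes the
# companion ResnetBlock3D / Transformer3DModel / Downsample3D modules accept
# these settings.
def _example_cross_attn_down_block():
    return get_down_block(
        down_block_type="CrossAttnDownBlock3D",
        num_layers=2,
        in_channels=320,
        out_channels=320,
        temb_channels=1280,
        add_downsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        downsample_padding=1,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=True,
        use_motion_module=False,  # skip the temporal layers in this sketch
    )
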
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",

    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,

    use_motion_module=None,
    use_motion_resnet=None,
    motion_module_type=None,
    motion_module_kwargs=None,
):
    up_block_type = up_block_type[7:] if up_block_type.startswith(
        "UNetRes") else up_block_type
    if up_block_type == "UpBlock3D":
        return UpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            resnet_time_scale_shift=resnet_time_scale_shift,

            use_inflated_groupnorm=use_inflated_groupnorm,

            use_motion_module=use_motion_module,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    elif up_block_type == "CrossAttnUpBlock3D":
        if cross_attention_dim is None:
            raise ValueError(
                "cross_attention_dim must be specified for CrossAttnUpBlock3D")
        return CrossAttnUpBlock3D(
            num_layers=num_layers,
            in_channels=in_channels,
            out_channels=out_channels,
            prev_output_channel=prev_output_channel,
            temb_channels=temb_channels,
            add_upsample=add_upsample,
            resnet_eps=resnet_eps,
            resnet_act_fn=resnet_act_fn,
            resnet_groups=resnet_groups,
            cross_attention_dim=cross_attention_dim,
            attn_num_head_channels=attn_num_head_channels,
            dual_cross_attention=dual_cross_attention,
            use_linear_projection=use_linear_projection,
            only_cross_attention=only_cross_attention,
            upcast_attention=upcast_attention,
            resnet_time_scale_shift=resnet_time_scale_shift,

            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
            unet_use_temporal_attention=unet_use_temporal_attention,
            use_inflated_groupnorm=use_inflated_groupnorm,

            use_motion_module=use_motion_module,
            use_motion_resnet=use_motion_resnet,
            motion_module_type=motion_module_type,
            motion_module_kwargs=motion_module_kwargs,
        )
    raise ValueError(f"{up_block_type} does not exist.")

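
# Decoder-side counterpart of the sketch above. get_up_block additionally
# needs prev_output_channel, the channel width produced by the block below it
# in the UNet. All sizes remain illustrative assumptions, and the helper name
# is not part of the module's API.
def _example_cross_attn_up_block():
    return get_up_block(
        up_block_type="CrossAttnUpBlock3D",
        num_layers=3,
        in_channels=320,
        out_channels=320,
        prev_output_channel=640,
        temb_channels=1280,
        add_upsample=True,
        resnet_eps=1e-5,
        resnet_act_fn="silu",
        attn_num_head_channels=8,
        resnet_groups=32,
        cross_attention_dim=768,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=True,
        use_motion_module=False,
    )
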
class UNetMidBlock3DCrossAttn(nn.Module):
    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,

        use_motion_module=None,
        use_motion_resnet=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        resnet_groups = resnet_groups if resnet_groups is not None else min(
            in_channels // 4, 32)

        # there is always at least one resnet
        resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )
        ]
        motion_resnets = [
            ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
                use_temporal_conv=True,
                use_temporal_mixer=True,
            ) if use_motion_resnet else None
        ]

        attentions = []
        motion_modules = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=in_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,

                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                    use_temporal_conv=True,
                    use_temporal_mixer=True,
                ) if use_motion_resnet else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)
        self.motion_resnets = nn.ModuleList(motion_resnets)

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        hidden_states = self.motion_resnets[0](
            hidden_states, temb) if self.motion_resnets[0] is not None else hidden_states

        for attn, resnet, motion_module, motion_resnet in zip(self.attentions, self.resnets[1:], self.motion_modules, self.motion_resnets[1:]):
            hidden_states = attn(
                hidden_states, encoder_hidden_states=encoder_hidden_states).sample
            hidden_states = motion_module(
                hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
            hidden_states = resnet(hidden_states, temb)
            hidden_states = motion_resnet(
                hidden_states, temb) if motion_resnet is not None else hidden_states

        return hidden_states

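
# Sketch of how the bottleneck block might be instantiated. The values are
# assumptions in the spirit of SD1.5 (1280-channel bottleneck, 768-dim text
# embeddings); the motion-specific options are disabled so the optional
# entries in motion_modules / motion_resnets stay None. The helper name is
# illustrative only.
def _example_mid_block():
    block = UNetMidBlock3DCrossAttn(
        in_channels=1280,
        temb_channels=1280,
        attn_num_head_channels=8,
        cross_attention_dim=768,
        resnet_groups=32,
        unet_use_cross_frame_attention=False,
        unet_use_temporal_attention=False,
        use_inflated_groupnorm=True,
        use_motion_module=False,
        use_motion_resnet=False,
    )
    # forward() runs resnets[0] (plus the optional temporal resnet), then for
    # each layer: cross-attention transformer -> optional motion module ->
    # resnet -> optional temporal resnet, and returns the hidden states.
    return block
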
class CrossAttnDownBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,

        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,

        use_motion_module=None,
        use_motion_resnet=None,

        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_resnets = []
        attentions = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,

                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_resnets.append(
                ResnetBlock3D(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                    use_temporal_conv=True,
                    use_temporal_mixer=True,
                ) if use_motion_resnet else None
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,

                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)
        self.motion_resnets = nn.ModuleList(motion_resnets)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, img_feature=None):
        output_states = ()
        idx = 1
        for resnet, attn, motion_module, motion_resnet in zip(self.resnets, self.attentions, self.motion_modules, self.motion_resnets):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                    resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_resnet is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                        motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states.requires_grad_(),
                    encoder_hidden_states,
                    use_reentrant=False
                )[0]

                hidden_states = hidden_states + \
                    img_feature if (
                        img_feature is not None and idx == 2) else hidden_states

                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                        motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)

            else:
                hidden_states = resnet(hidden_states, temb)

                hidden_states = motion_resnet(
                    hidden_states, temb) if motion_resnet is not None else hidden_states

                hidden_states = attn(
                    hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                hidden_states = hidden_states + \
                    img_feature if (
                        img_feature is not None and idx == 2) else hidden_states

                # add motion module
                hidden_states = motion_module(
                    hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

            idx += 1
            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)

            output_states += (hidden_states,)

        return hidden_states, output_states

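
# Note on the block above and the plain variant below: forward() returns
# (hidden_states, output_states), where output_states collects one skip
# tensor per resnet layer plus one more after the optional downsampler, and
# an externally supplied img_feature (when given) is added after the second
# layer (idx == 2). DownBlock3D keeps exactly this contract but has no
# Transformer3DModel cross-attention layers; each layer is just a resnet
# followed by an optional temporal motion module.
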
            +
            class DownBlock3D(nn.Module):
         | 
| 519 | 
            +
                def __init__(
         | 
| 520 | 
            +
                    self,
         | 
| 521 | 
            +
                    in_channels: int,
         | 
| 522 | 
            +
                    out_channels: int,
         | 
| 523 | 
            +
                    temb_channels: int,
         | 
| 524 | 
            +
                    dropout: float = 0.0,
         | 
| 525 | 
            +
                    num_layers: int = 1,
         | 
| 526 | 
            +
                    resnet_eps: float = 1e-6,
         | 
| 527 | 
            +
                    resnet_time_scale_shift: str = "default",
         | 
| 528 | 
            +
                    resnet_act_fn: str = "swish",
         | 
| 529 | 
            +
                    resnet_groups: int = 32,
         | 
| 530 | 
            +
                    resnet_pre_norm: bool = True,
         | 
| 531 | 
            +
                    output_scale_factor=1.0,
         | 
| 532 | 
            +
                    add_downsample=True,
         | 
| 533 | 
            +
                    downsample_padding=1,
         | 
| 534 | 
            +
             | 
| 535 | 
            +
                    use_inflated_groupnorm=None,
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    use_motion_module=None,
         | 
| 538 | 
            +
                    motion_module_type=None,
         | 
| 539 | 
            +
                    motion_module_kwargs=None,
         | 
| 540 | 
            +
                ):
         | 
| 541 | 
            +
                    super().__init__()
         | 
| 542 | 
            +
                    resnets = []
         | 
| 543 | 
            +
                    motion_modules = []
         | 
| 544 | 
            +
             | 
| 545 | 
            +
                    for i in range(num_layers):
         | 
| 546 | 
            +
                        in_channels = in_channels if i == 0 else out_channels
         | 
| 547 | 
            +
                        resnets.append(
         | 
| 548 | 
            +
                            ResnetBlock3D(
         | 
| 549 | 
            +
                                in_channels=in_channels,
         | 
| 550 | 
            +
                                out_channels=out_channels,
         | 
| 551 | 
            +
                                temb_channels=temb_channels,
         | 
| 552 | 
            +
                                eps=resnet_eps,
         | 
| 553 | 
            +
                                groups=resnet_groups,
         | 
| 554 | 
            +
                                dropout=dropout,
         | 
| 555 | 
            +
                                time_embedding_norm=resnet_time_scale_shift,
         | 
| 556 | 
            +
                                non_linearity=resnet_act_fn,
         | 
| 557 | 
            +
                                output_scale_factor=output_scale_factor,
         | 
| 558 | 
            +
                                pre_norm=resnet_pre_norm,
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                                use_inflated_groupnorm=use_inflated_groupnorm,
         | 
| 561 | 
            +
                            )
         | 
| 562 | 
            +
                        )
         | 
| 563 | 
            +
                        motion_modules.append(
         | 
| 564 | 
            +
                            get_motion_module(
         | 
| 565 | 
            +
                                in_channels=out_channels,
         | 
| 566 | 
            +
                                motion_module_type=motion_module_type,
         | 
| 567 | 
            +
                                motion_module_kwargs=motion_module_kwargs,
         | 
| 568 | 
            +
                            ) if use_motion_module else None
         | 
| 569 | 
            +
                        )
         | 
| 570 | 
            +
             | 
| 571 | 
            +
                    self.resnets = nn.ModuleList(resnets)
         | 
| 572 | 
            +
                    self.motion_modules = nn.ModuleList(motion_modules)
         | 
| 573 | 
            +
             | 
| 574 | 
            +
                    if add_downsample:
         | 
| 575 | 
            +
                        self.downsamplers = nn.ModuleList(
         | 
| 576 | 
            +
                            [
         | 
| 577 | 
            +
                                Downsample3D(
         | 
| 578 | 
            +
                                    out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
         | 
| 579 | 
            +
                                )
         | 
| 580 | 
            +
                            ]
         | 
| 581 | 
            +
                        )
         | 
| 582 | 
            +
                    else:
         | 
| 583 | 
            +
                        self.downsamplers = None
         | 
| 584 | 
            +
             | 
| 585 | 
            +
                    self.gradient_checkpointing = False
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                def forward(self, hidden_states, temb=None, encoder_hidden_states=None, img_feature=None):
         | 
| 588 | 
            +
                    output_states = ()
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    idx = 1
         | 
| 591 | 
            +
                    for resnet, motion_module in zip(self.resnets, self.motion_modules):
         | 
| 592 | 
            +
                        if self.training and self.gradient_checkpointing:
         | 
| 593 | 
            +
                            def create_custom_forward(module):
         | 
| 594 | 
            +
                                def custom_forward(*inputs):
         | 
| 595 | 
            +
                                    return module(*inputs)
         | 
| 596 | 
            +
             | 
| 597 | 
            +
                                return custom_forward
         | 
| 598 | 
            +
             | 
| 599 | 
            +
                            hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
         | 
| 600 | 
            +
                                resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
         | 
| 601 | 
            +
                            hidden_states = hidden_states + \
         | 
| 602 | 
            +
                                img_feature if (
         | 
| 603 | 
            +
                                    img_feature is not None and idx == 2) else hidden_states
         | 
| 604 | 
            +
                            if motion_module is not None:
         | 
| 605 | 
            +
                                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
         | 
| 606 | 
            +
                                    motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
         | 
| 607 | 
            +
                        else:
         | 
| 608 | 
            +
                            hidden_states = resnet(hidden_states, temb)
         | 
| 609 | 
            +
                            hidden_states = hidden_states + \
         | 
| 610 | 
            +
                                img_feature if (
         | 
| 611 | 
            +
                                    img_feature is not None and idx == 2) else hidden_states
         | 
| 612 | 
            +
                            # add motion module
         | 
| 613 | 
            +
                            hidden_states = motion_module(
         | 
| 614 | 
            +
                                hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states
         | 
| 615 | 
            +
             | 
| 616 | 
            +
                        output_states += (hidden_states,)
         | 
| 617 | 
            +
                        idx += 1
         | 
| 618 | 
            +
             | 
| 619 | 
            +
                    if self.downsamplers is not None:
         | 
| 620 | 
            +
                        for downsampler in self.downsamplers:
         | 
| 621 | 
            +
                            hidden_states = downsampler(hidden_states)
         | 
| 622 | 
            +
             | 
| 623 | 
            +
                        output_states += (hidden_states,)
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                    return hidden_states, output_states
         | 
| 626 | 
            +
             | 
| 627 | 
            +
             | 
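
# CrossAttnUpBlock3D below is the decoder-side counterpart: at every layer it
# pops the most recent skip tensor from res_hidden_states_tuple, concatenates
# it with the running hidden states along the channel dimension, and then runs
# its resnet / optional temporal resnet / attention / optional motion-module
# stack, reusing the same non-reentrant gradient checkpointing during training.
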
| 628 | 
            +
            class CrossAttnUpBlock3D(nn.Module):
         | 
| 629 | 
            +
                def __init__(
         | 
| 630 | 
            +
                    self,
         | 
| 631 | 
            +
                    in_channels: int,
         | 
| 632 | 
            +
                    out_channels: int,
         | 
| 633 | 
            +
                    prev_output_channel: int,
         | 
| 634 | 
            +
                    temb_channels: int,
         | 
| 635 | 
            +
                    dropout: float = 0.0,
         | 
| 636 | 
            +
                    num_layers: int = 1,
         | 
| 637 | 
            +
                    resnet_eps: float = 1e-6,
         | 
| 638 | 
            +
                    resnet_time_scale_shift: str = "default",
         | 
| 639 | 
            +
                    resnet_act_fn: str = "swish",
         | 
| 640 | 
            +
                    resnet_groups: int = 32,
         | 
| 641 | 
            +
                    resnet_pre_norm: bool = True,
         | 
| 642 | 
            +
                    attn_num_head_channels=1,
         | 
| 643 | 
            +
                    cross_attention_dim=1280,
         | 
| 644 | 
            +
                    output_scale_factor=1.0,
         | 
| 645 | 
            +
                    add_upsample=True,
         | 
| 646 | 
            +
                    dual_cross_attention=False,
         | 
| 647 | 
            +
                    use_linear_projection=False,
         | 
| 648 | 
            +
                    only_cross_attention=False,
         | 
| 649 | 
            +
                    upcast_attention=False,
         | 
| 650 | 
            +
             | 
| 651 | 
            +
                    unet_use_cross_frame_attention=None,
         | 
| 652 | 
            +
                    unet_use_temporal_attention=None,
         | 
| 653 | 
            +
                    use_inflated_groupnorm=None,
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    use_motion_module=None,
         | 
| 656 | 
            +
                    use_motion_resnet=None,
         | 
| 657 | 
            +
             | 
| 658 | 
            +
                    motion_module_type=None,
         | 
| 659 | 
            +
                    motion_module_kwargs=None,
         | 
| 660 | 
            +
                ):
         | 
| 661 | 
            +
                    super().__init__()
         | 
| 662 | 
            +
                    resnets = []
         | 
| 663 | 
            +
                    attentions = []
         | 
| 664 | 
            +
                    motion_modules = []
         | 
| 665 | 
            +
                    motion_resnets = []
         | 
| 666 | 
            +
                    self.has_cross_attention = True
         | 
| 667 | 
            +
                    self.attn_num_head_channels = attn_num_head_channels
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    for i in range(num_layers):
         | 
| 670 | 
            +
                        res_skip_channels = in_channels if (
         | 
| 671 | 
            +
                            i == num_layers - 1) else out_channels
         | 
| 672 | 
            +
                        resnet_in_channels = prev_output_channel if i == 0 else out_channels
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                        resnets.append(
         | 
| 675 | 
            +
                            ResnetBlock3D(
         | 
| 676 | 
            +
                                in_channels=resnet_in_channels + res_skip_channels,
         | 
| 677 | 
            +
                                out_channels=out_channels,
         | 
| 678 | 
            +
                                temb_channels=temb_channels,
         | 
| 679 | 
            +
                                eps=resnet_eps,
         | 
| 680 | 
            +
                                groups=resnet_groups,
         | 
| 681 | 
            +
                                dropout=dropout,
         | 
| 682 | 
            +
                                time_embedding_norm=resnet_time_scale_shift,
         | 
| 683 | 
            +
                                non_linearity=resnet_act_fn,
         | 
| 684 | 
            +
                                output_scale_factor=output_scale_factor,
         | 
| 685 | 
            +
                                pre_norm=resnet_pre_norm,
         | 
| 686 | 
            +
             | 
| 687 | 
            +
                                use_inflated_groupnorm=use_inflated_groupnorm,
         | 
| 688 | 
            +
                            )
         | 
| 689 | 
            +
                        )
         | 
| 690 | 
            +
                        motion_resnets.append(
         | 
| 691 | 
            +
                            ResnetBlock3D(
         | 
| 692 | 
            +
                                in_channels=out_channels,
         | 
| 693 | 
            +
                                out_channels=out_channels,
         | 
| 694 | 
            +
                                temb_channels=temb_channels,
         | 
| 695 | 
            +
                                eps=resnet_eps,
         | 
| 696 | 
            +
                                groups=resnet_groups,
         | 
| 697 | 
            +
                                dropout=dropout,
         | 
| 698 | 
            +
                                time_embedding_norm=resnet_time_scale_shift,
         | 
| 699 | 
            +
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                    use_temporal_conv=True,
                    use_temporal_mixer=True
                ) if use_motion_resnet else None
            )

            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)
        self.motion_resnets = nn.ModuleList(motion_resnets)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
    ):
        for resnet, attn, motion_module, motion_resnet in zip(self.resnets, self.attentions, self.motion_modules, self.motion_resnets):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat(
                [hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        else:
                            return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                    resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_resnet is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                        motion_resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states.requires_grad_(),
                    encoder_hidden_states,
                    use_reentrant=False,
                )[0]
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                        motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)

            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = motion_resnet(
                    hidden_states, temb) if motion_resnet is not None else hidden_states
                hidden_states = attn(
                    hidden_states, encoder_hidden_states=encoder_hidden_states).sample

                # add motion module
                hidden_states = motion_module(
                    hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states


class UpBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,

        use_inflated_groupnorm=None,

        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnets = []
        motion_modules = []

        for i in range(num_layers):
            res_skip_channels = in_channels if (
                i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,

                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                ) if use_motion_module else None
            )

        self.resnets = nn.ModuleList(resnets)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
        else:
            self.upsamplers = None

        self.gradient_checkpointing = False

    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None):
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat(
                [hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                    resnet), hidden_states.requires_grad_(), temb, use_reentrant=False)
                if motion_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(
                        motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, use_reentrant=False)
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = motion_module(
                    hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)

        return hidden_states
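Note (not part of the upload): the decoder blocks above consume one skip tensor per ResNet layer, popped from the end of res_hidden_states_tuple, concatenate it on the channel axis, and optionally upsample at the end. Below is a minimal sketch of driving UpBlock3D in isolation, assuming the repo is importable as animatelcm; the constructor keywords mirror the definition above, while every tensor shape (batch 1, 16 frames, 32x32 latents, 320/640 channels) is a made-up assumption for illustration only.

import torch
from animatelcm.models.unet_blocks import UpBlock3D

# Channel bookkeeping follows the constructor above: layer 0 takes
# prev_output_channel + out_channels, later layers take 2 * out_channels.
block = UpBlock3D(
    in_channels=320, prev_output_channel=640, out_channels=320,
    temb_channels=1280, num_layers=3, add_upsample=True,
    use_inflated_groupnorm=True, use_motion_module=False,
)

hidden = torch.randn(1, 640, 16, 32, 32)   # output of the previous decoder block (assumed shape)
skips = tuple(torch.randn(1, 320, 16, 32, 32) for _ in range(3))  # one skip per layer, popped from the end
temb = torch.randn(1, 1280)                # time embedding (assumed shape)

out = block(hidden, skips, temb=temb)      # spatially upsampled, e.g. roughly (1, 320, 16, 64, 64)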
    	
animatelcm/pipelines/pipeline_animation.py
ADDED
@@ -0,0 +1,456 @@
| 1 | 
            +
            # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import inspect
         | 
| 4 | 
            +
            from typing import Callable, List, Optional, Union
         | 
| 5 | 
            +
            from dataclasses import dataclass
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            from tqdm import tqdm
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from diffusers.utils import is_accelerate_available
         | 
| 12 | 
            +
            from packaging import version
         | 
| 13 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from diffusers.configuration_utils import FrozenDict
         | 
| 16 | 
            +
            from diffusers.models import AutoencoderKL
         | 
| 17 | 
            +
            from diffusers.pipeline_utils import DiffusionPipeline
         | 
| 18 | 
            +
            from diffusers.schedulers import (
         | 
| 19 | 
            +
                DDIMScheduler,
         | 
| 20 | 
            +
                DPMSolverMultistepScheduler,
         | 
| 21 | 
            +
                EulerAncestralDiscreteScheduler,
         | 
| 22 | 
            +
                EulerDiscreteScheduler,
         | 
| 23 | 
            +
                LMSDiscreteScheduler,
         | 
| 24 | 
            +
                PNDMScheduler,
         | 
| 25 | 
            +
            )
         | 
| 26 | 
            +
            from diffusers.utils import deprecate, logging, BaseOutput
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            from einops import rearrange
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            from ..models.unet import UNet3DConditionModel
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            @dataclass
         | 
| 36 | 
            +
            class AnimationPipelineOutput(BaseOutput):
         | 
| 37 | 
            +
                videos: Union[torch.Tensor, np.ndarray]
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             | 
| 40 | 
            +
            class AnimationPipeline(DiffusionPipeline):
         | 
| 41 | 
            +
                _optional_components = []
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def __init__(
         | 
| 44 | 
            +
                    self,
         | 
| 45 | 
            +
                    vae: AutoencoderKL,
         | 
| 46 | 
            +
                    text_encoder: CLIPTextModel,
         | 
| 47 | 
            +
                    tokenizer: CLIPTokenizer,
         | 
| 48 | 
            +
                    unet: UNet3DConditionModel,
         | 
| 49 | 
            +
                    scheduler: Union[
         | 
| 50 | 
            +
                        DDIMScheduler,
         | 
| 51 | 
            +
                        PNDMScheduler,
         | 
| 52 | 
            +
                        LMSDiscreteScheduler,
         | 
| 53 | 
            +
                        EulerDiscreteScheduler,
         | 
| 54 | 
            +
                        EulerAncestralDiscreteScheduler,
         | 
| 55 | 
            +
                        DPMSolverMultistepScheduler,
         | 
| 56 | 
            +
                    ],
         | 
| 57 | 
            +
                ):
         | 
| 58 | 
            +
                    super().__init__()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
         | 
| 61 | 
            +
                        deprecation_message = (
         | 
| 62 | 
            +
                            f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
         | 
| 63 | 
            +
                            f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
         | 
| 64 | 
            +
                            "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
         | 
| 65 | 
            +
                            " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
         | 
| 66 | 
            +
                            " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
         | 
| 67 | 
            +
                            " file"
         | 
| 68 | 
            +
                        )
         | 
| 69 | 
            +
                        deprecate("steps_offset!=1", "1.0.0",
         | 
| 70 | 
            +
                                  deprecation_message, standard_warn=False)
         | 
| 71 | 
            +
                        new_config = dict(scheduler.config)
         | 
| 72 | 
            +
                        new_config["steps_offset"] = 1
         | 
| 73 | 
            +
                        scheduler._internal_dict = FrozenDict(new_config)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
         | 
| 76 | 
            +
                        deprecation_message = (
         | 
| 77 | 
            +
                            f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
         | 
| 78 | 
            +
                            " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
         | 
| 79 | 
            +
                            " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
         | 
| 80 | 
            +
                            " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
         | 
| 81 | 
            +
                            " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
         | 
| 82 | 
            +
                        )
         | 
| 83 | 
            +
                        deprecate("clip_sample not set", "1.0.0",
         | 
| 84 | 
            +
                                  deprecation_message, standard_warn=False)
         | 
| 85 | 
            +
                        new_config = dict(scheduler.config)
         | 
| 86 | 
            +
                        new_config["clip_sample"] = False
         | 
| 87 | 
            +
                        scheduler._internal_dict = FrozenDict(new_config)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
         | 
| 90 | 
            +
                        version.parse(unet.config._diffusers_version).base_version
         | 
| 91 | 
            +
                    ) < version.parse("0.9.0.dev0")
         | 
| 92 | 
            +
                    is_unet_sample_size_less_64 = hasattr(
         | 
| 93 | 
            +
                        unet.config, "sample_size") and unet.config.sample_size < 64
         | 
| 94 | 
            +
                    if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
         | 
| 95 | 
            +
                        deprecation_message = (
         | 
| 96 | 
            +
                            "The configuration file of the unet has set the default `sample_size` to smaller than"
         | 
| 97 | 
            +
                            " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
         | 
| 98 | 
            +
                            " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
         | 
| 99 | 
            +
                            " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
         | 
| 100 | 
            +
                            " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
         | 
| 101 | 
            +
                            " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
         | 
| 102 | 
            +
                            " in the config might lead to incorrect results in future versions. If you have downloaded this"
         | 
| 103 | 
            +
                            " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
         | 
| 104 | 
            +
                            " the `unet/config.json` file"
         | 
| 105 | 
            +
                        )
         | 
| 106 | 
            +
                        deprecate("sample_size<64", "1.0.0",
         | 
| 107 | 
            +
                                  deprecation_message, standard_warn=False)
         | 
| 108 | 
            +
                        new_config = dict(unet.config)
         | 
| 109 | 
            +
                        new_config["sample_size"] = 64
         | 
| 110 | 
            +
                        unet._internal_dict = FrozenDict(new_config)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    self.register_modules(
         | 
| 113 | 
            +
                        vae=vae,
         | 
| 114 | 
            +
                        text_encoder=text_encoder,
         | 
| 115 | 
            +
                        tokenizer=tokenizer,
         | 
| 116 | 
            +
                        unet=unet,
         | 
| 117 | 
            +
                        scheduler=scheduler,
         | 
| 118 | 
            +
                    )
         | 
| 119 | 
            +
                    self.vae_scale_factor = 2 ** (
         | 
| 120 | 
            +
                        len(self.vae.config.block_out_channels) - 1)
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def enable_vae_slicing(self):
         | 
| 123 | 
            +
                    self.vae.enable_slicing()
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def disable_vae_slicing(self):
         | 
| 126 | 
            +
                    self.vae.disable_slicing()
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def enable_sequential_cpu_offload(self, gpu_id=0):
         | 
| 129 | 
            +
                    if is_accelerate_available():
         | 
| 130 | 
            +
                        from accelerate import cpu_offload
         | 
| 131 | 
            +
                    else:
         | 
| 132 | 
            +
                        raise ImportError(
         | 
| 133 | 
            +
                            "Please install accelerate via `pip install accelerate`")
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    device = torch.device(f"cuda:{gpu_id}")
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
         | 
| 138 | 
            +
                        if cpu_offloaded_model is not None:
         | 
| 139 | 
            +
                            cpu_offload(cpu_offloaded_model, device)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                @property
         | 
| 142 | 
            +
                def _execution_device(self):
         | 
| 143 | 
            +
                    if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
         | 
| 144 | 
            +
                        return self.device
         | 
| 145 | 
            +
                    for module in self.unet.modules():
         | 
| 146 | 
            +
                        if (
         | 
| 147 | 
            +
                            hasattr(module, "_hf_hook")
         | 
| 148 | 
            +
                            and hasattr(module._hf_hook, "execution_device")
         | 
| 149 | 
            +
                            and module._hf_hook.execution_device is not None
         | 
| 150 | 
            +
                        ):
         | 
| 151 | 
            +
                            return torch.device(module._hf_hook.execution_device)
         | 
| 152 | 
            +
                    return self.device
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
         | 
| 155 | 
            +
                    batch_size = len(prompt) if isinstance(prompt, list) else 1
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    text_inputs = self.tokenizer(
         | 
| 158 | 
            +
                        prompt,
         | 
| 159 | 
            +
                        padding="max_length",
         | 
| 160 | 
            +
                        max_length=self.tokenizer.model_max_length,
         | 
| 161 | 
            +
                        truncation=True,
         | 
| 162 | 
            +
                        return_tensors="pt",
         | 
| 163 | 
            +
                    )
         | 
| 164 | 
            +
                    text_input_ids = text_inputs.input_ids
         | 
| 165 | 
            +
                    untruncated_ids = self.tokenizer(
         | 
| 166 | 
            +
                        prompt, padding="longest", return_tensors="pt").input_ids
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
         | 
| 169 | 
            +
                        removed_text = self.tokenizer.batch_decode(
         | 
| 170 | 
            +
                            untruncated_ids[:, self.tokenizer.model_max_length - 1: -1])
         | 
| 171 | 
            +
                        logger.warning(
         | 
| 172 | 
            +
                            "The following part of your input was truncated because CLIP can only handle sequences up to"
         | 
| 173 | 
            +
                            f" {self.tokenizer.model_max_length} tokens: {removed_text}"
         | 
| 174 | 
            +
                        )
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
         | 
| 177 | 
            +
                        attention_mask = text_inputs.attention_mask.to(device)
         | 
| 178 | 
            +
                    else:
         | 
| 179 | 
            +
                        attention_mask = None
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                    text_embeddings = self.text_encoder(
         | 
| 182 | 
            +
                        text_input_ids.to(device),
         | 
| 183 | 
            +
                        attention_mask=attention_mask,
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
                    text_embeddings = text_embeddings[0]
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                    # duplicate text embeddings for each generation per prompt, using mps friendly method
         | 
| 188 | 
            +
                    bs_embed, seq_len, _ = text_embeddings.shape
         | 
| 189 | 
            +
                    text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
         | 
| 190 | 
            +
                    text_embeddings = text_embeddings.view(
         | 
| 191 | 
            +
                        bs_embed * num_videos_per_prompt, seq_len, -1)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    # get unconditional embeddings for classifier free guidance
         | 
| 194 | 
            +
                    if do_classifier_free_guidance:
         | 
| 195 | 
            +
                        uncond_tokens: List[str]
         | 
| 196 | 
            +
                        if negative_prompt is None:
         | 
| 197 | 
            +
                            uncond_tokens = [""] * batch_size
         | 
| 198 | 
            +
                        elif type(prompt) is not type(negative_prompt):
         | 
| 199 | 
            +
                            raise TypeError(
         | 
| 200 | 
            +
                                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
         | 
| 201 | 
            +
                                f" {type(prompt)}."
         | 
| 202 | 
            +
                            )
         | 
| 203 | 
            +
                        elif isinstance(negative_prompt, str):
         | 
| 204 | 
            +
                            uncond_tokens = [negative_prompt]
         | 
| 205 | 
            +
                        elif batch_size != len(negative_prompt):
         | 
| 206 | 
            +
                            raise ValueError(
         | 
| 207 | 
            +
                                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
         | 
| 208 | 
            +
                                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
         | 
| 209 | 
            +
                                " the batch size of `prompt`."
         | 
| 210 | 
            +
                            )
         | 
| 211 | 
            +
                        else:
         | 
| 212 | 
            +
                            uncond_tokens = negative_prompt
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                        max_length = text_input_ids.shape[-1]
         | 
| 215 | 
            +
                        uncond_input = self.tokenizer(
         | 
| 216 | 
            +
                            uncond_tokens,
         | 
| 217 | 
            +
                            padding="max_length",
         | 
| 218 | 
            +
                            max_length=max_length,
         | 
| 219 | 
            +
                            truncation=True,
         | 
| 220 | 
            +
                            return_tensors="pt",
         | 
| 221 | 
            +
                        )
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
         | 
| 224 | 
            +
                            attention_mask = uncond_input.attention_mask.to(device)
         | 
| 225 | 
            +
                        else:
         | 
| 226 | 
            +
                            attention_mask = None
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                        uncond_embeddings = self.text_encoder(
         | 
| 229 | 
            +
                            uncond_input.input_ids.to(device),
         | 
| 230 | 
            +
                            attention_mask=attention_mask,
         | 
| 231 | 
            +
                        )
         | 
| 232 | 
            +
                        uncond_embeddings = uncond_embeddings[0]
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
         | 
| 235 | 
            +
                        seq_len = uncond_embeddings.shape[1]
         | 
| 236 | 
            +
                        uncond_embeddings = uncond_embeddings.repeat(
         | 
| 237 | 
            +
                            1, num_videos_per_prompt, 1)
         | 
| 238 | 
            +
                        uncond_embeddings = uncond_embeddings.view(
         | 
| 239 | 
            +
                            batch_size * num_videos_per_prompt, seq_len, -1)
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                        # For classifier free guidance, we need to do two forward passes.
         | 
| 242 | 
            +
                        # Here we concatenate the unconditional and text embeddings into a single batch
         | 
| 243 | 
            +
                        # to avoid doing two forward passes
         | 
| 244 | 
            +
                        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    return text_embeddings
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                def decode_latents(self, latents):
         | 
| 249 | 
            +
                    video_length = latents.shape[2]
         | 
| 250 | 
            +
                    latents = 1 / 0.18215 * latents
         | 
| 251 | 
            +
                    latents = rearrange(latents, "b c f h w -> (b f) c h w")
         | 
| 252 | 
            +
                    # video = self.vae.decode(latents).sample
         | 
| 253 | 
            +
                    video = []
         | 
| 254 | 
            +
                    for frame_idx in tqdm(range(latents.shape[0])):
         | 
| 255 | 
            +
                        video.append(self.vae.decode(
         | 
| 256 | 
            +
                            latents[frame_idx:frame_idx+1]).sample)
         | 
| 257 | 
            +
                    video = torch.cat(video)
         | 
| 258 | 
            +
                    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
         | 
| 259 | 
            +
                    video = (video / 2 + 0.5).clamp(0, 1)
         | 
| 260 | 
            +
                    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
         | 
| 261 | 
            +
                    video = video.cpu().float().numpy()
         | 
| 262 | 
            +
                    return video
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                def prepare_extra_step_kwargs(self, generator, eta):
         | 
| 265 | 
            +
                    # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         | 
| 266 | 
            +
                    # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
         | 
| 267 | 
            +
                    # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         | 
| 268 | 
            +
                    # and should be between [0, 1]
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                    accepts_eta = "eta" in set(inspect.signature(
         | 
| 271 | 
            +
                        self.scheduler.step).parameters.keys())
         | 
| 272 | 
            +
                    extra_step_kwargs = {}
         | 
| 273 | 
            +
                    if accepts_eta:
         | 
| 274 | 
            +
                        extra_step_kwargs["eta"] = eta
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                    # check if the scheduler accepts generator
         | 
| 277 | 
            +
                    accepts_generator = "generator" in set(
         | 
| 278 | 
            +
                        inspect.signature(self.scheduler.step).parameters.keys())
         | 
| 279 | 
            +
                    if accepts_generator:
         | 
| 280 | 
            +
                        extra_step_kwargs["generator"] = generator
         | 
| 281 | 
            +
                    return extra_step_kwargs
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                def check_inputs(self, prompt, height, width, callback_steps):
         | 
| 284 | 
            +
                    if not isinstance(prompt, str) and not isinstance(prompt, list):
         | 
| 285 | 
            +
                        raise ValueError(
         | 
| 286 | 
            +
                            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                    if height % 8 != 0 or width % 8 != 0:
         | 
| 289 | 
            +
                        raise ValueError(
         | 
| 290 | 
            +
                            f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                    if (callback_steps is None) or (
         | 
| 293 | 
            +
                        callback_steps is not None and (not isinstance(
         | 
| 294 | 
            +
                            callback_steps, int) or callback_steps <= 0)
         | 
| 295 | 
            +
                    ):
         | 
| 296 | 
            +
                        raise ValueError(
         | 
| 297 | 
            +
                            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
         | 
| 298 | 
            +
                            f" {type(callback_steps)}."
         | 
| 299 | 
            +
                        )
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
         | 
| 302 | 
            +
                    shape = (batch_size, num_channels_latents, video_length, height //
         | 
| 303 | 
            +
                             self.vae_scale_factor, width // self.vae_scale_factor)
         | 
| 304 | 
            +
                    if isinstance(generator, list) and len(generator) != batch_size:
         | 
| 305 | 
            +
                        raise ValueError(
         | 
| 306 | 
            +
                            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
         | 
| 307 | 
            +
                            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         | 
| 308 | 
            +
                        )
         | 
| 309 | 
            +
                    if latents is None:
         | 
| 310 | 
            +
                        rand_device = "cpu" if device.type == "mps" else device
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                        if isinstance(generator, list):
         | 
| 313 | 
            +
                            shape = shape
         | 
| 314 | 
            +
                            # shape = (1,) + shape[1:]
         | 
| 315 | 
            +
                            latents = [
         | 
| 316 | 
            +
                                torch.randn(
         | 
| 317 | 
            +
                                    shape, generator=generator[i], device=rand_device, dtype=dtype)
         | 
| 318 | 
            +
                                for i in range(batch_size)
         | 
| 319 | 
            +
                            ]
         | 
| 320 | 
            +
                            latents = torch.cat(latents, dim=0).to(device)
         | 
| 321 | 
            +
                        else:
         | 
| 322 | 
            +
                            latents = torch.randn(
         | 
| 323 | 
            +
                                shape, generator=generator, device=rand_device, dtype=dtype).to(device)
         | 
| 324 | 
            +
                    else:
         | 
| 325 | 
            +
                        if latents.shape != shape:
         | 
| 326 | 
            +
                            raise ValueError(
         | 
| 327 | 
            +
                                f"Unexpected latents shape, got {latents.shape}, expected {shape}")
         | 
| 328 | 
            +
                        latents = latents.to(device)
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    # scale the initial noise by the standard deviation required by the scheduler
         | 
| 331 | 
            +
                    latents = latents * self.scheduler.init_noise_sigma
         | 
| 332 | 
            +
                    return latents
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                @torch.no_grad()
         | 
| 335 | 
            +
                def __call__(
         | 
| 336 | 
            +
                    self,
         | 
| 337 | 
            +
                    prompt: Union[str, List[str]],
         | 
| 338 | 
            +
                    video_length: Optional[int],
         | 
| 339 | 
            +
                    height: Optional[int] = None,
         | 
| 340 | 
            +
                    width: Optional[int] = None,
         | 
| 341 | 
            +
                    num_inference_steps: int = 50,
         | 
| 342 | 
            +
                    guidance_scale: float = 7.5,
         | 
| 343 | 
            +
                    negative_prompt: Optional[Union[str, List[str]]] = None,
         | 
| 344 | 
            +
                    num_videos_per_prompt: Optional[int] = 1,
         | 
| 345 | 
            +
                    eta: float = 0.0,
         | 
| 346 | 
            +
                    generator: Optional[Union[torch.Generator,
         | 
| 347 | 
            +
                                              List[torch.Generator]]] = None,
         | 
| 348 | 
            +
                    latents: Optional[torch.FloatTensor] = None,
         | 
| 349 | 
            +
                    output_type: Optional[str] = "tensor",
         | 
| 350 | 
            +
                    return_dict: bool = True,
         | 
| 351 | 
            +
                    callback: Optional[Callable[[
         | 
| 352 | 
            +
                        int, int, torch.FloatTensor], None]] = None,
         | 
| 353 | 
            +
                    callback_steps: Optional[int] = 1,
         | 
| 354 | 
            +
                    do_classifier_free_guidance: bool = True,
         | 
| 355 | 
            +
                    image_path: str = None,  # not ready
         | 
| 356 | 
            +
                    control_path: str = None,  # not ready
         | 
| 357 | 
            +
                    sparse_control: str = False,  # not ready
         | 
| 358 | 
            +
                    **kwargs,
         | 
| 359 | 
            +
                ):
         | 
| 360 | 
            +
             | 
| 361 | 
            +
                    # Default height and width to unet
         | 
| 362 | 
            +
                    height = height or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 363 | 
            +
                    width = width or self.unet.config.sample_size * self.vae_scale_factor
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                    # Check inputs. Raise error if not correct
         | 
| 366 | 
            +
                    self.check_inputs(prompt, height, width, callback_steps)
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                    # Define call parameters
         | 
| 369 | 
            +
                    # batch_size = 1 if isinstance(prompt, str) else len(prompt)
         | 
| 370 | 
            +
                    batch_size = 1
         | 
| 371 | 
            +
                    if latents is not None:
         | 
| 372 | 
            +
                        batch_size = latents.shape[0]
         | 
| 373 | 
            +
                    if isinstance(prompt, list):
         | 
| 374 | 
            +
                        batch_size = len(prompt)
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                    device = self._execution_device
         | 
| 377 | 
            +
                    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         | 
| 378 | 
            +
                    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         | 
| 379 | 
            +
                    # corresponds to doing no classifier free guidance.
         | 
| 380 | 
            +
                    do_classifier_free_guidance = (
         | 
| 381 | 
            +
                        guidance_scale > 1.0) & do_classifier_free_guidance
         | 
| 382 | 
            +
             | 
| 383 | 
            +
                    prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
         | 
| 384 | 
            +
                    if negative_prompt is not None:
         | 
| 385 | 
            +
                        negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [
         | 
| 386 | 
            +
                            negative_prompt] * batch_size
         | 
| 387 | 
            +
                    text_embeddings = self._encode_prompt(
         | 
| 388 | 
            +
                        prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
         | 
| 389 | 
            +
                    )
         | 
| 390 | 
            +
             | 
| 391 | 
            +
                    # Prepare timesteps
         | 
| 392 | 
            +
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         | 
| 393 | 
            +
                    timesteps = self.scheduler.timesteps
         | 
| 394 | 
            +
             | 
| 395 | 
            +
                    # Prepare latent variables
         | 
| 396 | 
            +
                    num_channels_latents = self.unet.in_channels
         | 
| 397 | 
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            video_length,
            height,
            width,
            text_embeddings.dtype,
            device,
            generator,
            latents,
        )
        latents_dtype = latents.dtype

        w_embedding = None  # not ready

        # Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Denoising loop
        num_warmup_steps = len(timesteps) - \
            num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
                                       time_cond=w_embedding).sample.to(dtype=latents_dtype)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # Post-processing
        video = self.decode_latents(latents)

        # Convert to tensor
        if output_type == "tensor":
            video = torch.from_numpy(video)

        if not return_dict:
            return video

        return AnimationPipelineOutput(videos=video)
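To make the guidance arithmetic in the loop above concrete, here is a minimal, self-contained sketch of the split-and-recombine step on dummy tensors; the tensor shape and the `guidance_scale` value are illustrative assumptions, not values taken from the pipeline.

import torch

# Duplicated latent batch -> UNet output of shape (2 * batch, C, F, H, W):
# index 0 is the unconditional prediction, index 1 the text-conditioned one.
guidance_scale = 7.5
noise_pred = torch.randn(2, 4, 16, 8, 8)  # purely illustrative sizes

# Same recombination used in the denoising loop above.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)  # torch.Size([1, 4, 16, 8, 8])
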
    	
        animatelcm/scheduler/lcm_scheduler.py
    ADDED
    
@@ -0,0 +1,722 @@
# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
# and https://github.com/hojonathanho/diffusion

import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import numpy as np
import torch

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.schedulers.scheduling_utils import SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class LCMSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        denoised (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `denoised` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    denoised: Optional[torch.FloatTensor] = None


# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
def betas_for_alpha_bar(
    num_diffusion_timesteps,
    max_beta=0.999,
    alpha_transform_type="cosine",
):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.

    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.
        alpha_transform_type (`str`, *optional*, defaults to `cosine`): the type of noise schedule for alpha_bar.
                     Choose from `cosine` or `exp`

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """
    if alpha_transform_type == "cosine":

        def alpha_bar_fn(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

    elif alpha_transform_type == "exp":

        def alpha_bar_fn(t):
            return math.exp(t * -12.0)

    else:
        raise ValueError(
            f"Unsupported alpha_transform_type: {alpha_transform_type}")

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)

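As a quick sanity check of the cosine schedule above, the following standalone sketch mirrors `alpha_bar_fn` and shows that the cumulative product of `1 - beta` tracks it at the discretization points; the small `T = 10` is an arbitrary choice for illustration.

import math
import torch

T = 10  # illustrative number of diffusion timesteps

def alpha_bar(t):
    # same cosine transform as alpha_bar_fn above
    return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

betas = torch.tensor([
    min(1 - alpha_bar((i + 1) / T) / alpha_bar(i / T), 0.999) for i in range(T)
])
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

# cumprod of (1 - beta_i) reproduces alpha_bar at the grid points,
# up to the max_beta=0.999 clamp that kicks in near t = 1.
print(betas)
print(alphas_cumprod[-1].item(), alpha_bar(1.0))
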
# Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
def rescale_zero_terminal_snr(betas: torch.FloatTensor) -> torch.FloatTensor:
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / \
        (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt**2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas

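A minimal check of the rescaling above, assuming this module is importable under its path in the repository (`animatelcm.scheduler.lcm_scheduler`): it builds the same `scaled_linear` schedule used by `LCMScheduler.__init__` below and confirms the terminal cumulative alpha is driven to zero while the first step is left unchanged.

import torch
from animatelcm.scheduler.lcm_scheduler import rescale_zero_terminal_snr

# Same scaled_linear schedule that LCMScheduler.__init__ builds by default.
betas = torch.linspace(0.00085**0.5, 0.012**0.5, 1000, dtype=torch.float32) ** 2
rescaled = rescale_zero_terminal_snr(betas)

alphas_cumprod = torch.cumprod(1.0 - rescaled, dim=0)
print(alphas_cumprod[-1].item())  # ~0.0 -> zero terminal SNR
print(alphas_cumprod[0].item())   # first step matches the original schedule
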
def randn_tensor(
    shape: Union[Tuple, List],
    generator: Optional[Union[List["torch.Generator"],
                              "torch.Generator"]] = None,
    device: Optional["torch.device"] = None,
    dtype: Optional["torch.dtype"] = None,
    layout: Optional["torch.layout"] = None,
):
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
    is always created on the CPU.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(
            generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                logger.info(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device."
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(
                f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]

    if isinstance(generator, list):
        shape = (1,) + shape[1:]
        latents = [
            torch.randn(
                shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator,
                              device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents

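A short usage sketch for the helper above: passing a list of CPU generators seeds each batch item independently, so per-sample noise is reproducible. The import path assumes the module location used in this repository, and the shape is arbitrary.

import torch
from animatelcm.scheduler.lcm_scheduler import randn_tensor

shape = (2, 4, 16, 8, 8)  # (batch, channels, frames, height, width) -- illustrative
noise = randn_tensor(
    shape,
    generator=[torch.Generator("cpu").manual_seed(s) for s in (0, 1)],
    dtype=torch.float32,
)
noise_again = randn_tensor(
    shape,
    generator=[torch.Generator("cpu").manual_seed(s) for s in (0, 1)],
    dtype=torch.float32,
)
print(torch.equal(noise, noise_again))  # True: each batch item is seeded individually
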
class LCMScheduler(SchedulerMixin, ConfigMixin):
    """
    `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with
    non-Markovian guidance.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. [`~ConfigMixin`] takes care of storing all config
    attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be
    accessed via `scheduler.config.num_train_timesteps`. [`SchedulerMixin`] provides general loading and saving
    functionality via the [`SchedulerMixin.save_pretrained`] and [`~SchedulerMixin.from_pretrained`] functions.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        beta_start (`float`, defaults to 0.00085):
            The starting `beta` value of inference.
        beta_end (`float`, defaults to 0.012):
            The final `beta` value.
        beta_schedule (`str`, defaults to `"scaled_linear"`):
            The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
            `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
        trained_betas (`np.ndarray`, *optional*):
            Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
        original_inference_steps (`int`, *optional*, defaults to 50):
            The default number of inference steps used to generate a linearly-spaced timestep schedule, from which we
            will ultimately take `num_inference_steps` evenly spaced timesteps to form the final timestep schedule.
        clip_sample (`bool`, defaults to `False`):
            Clip the predicted sample for numerical stability.
        clip_sample_range (`float`, defaults to 1.0):
            The maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
        set_alpha_to_one (`bool`, defaults to `True`):
            Each diffusion step uses the alphas product value at that step and at the previous one. For the final step
            there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
            otherwise it uses the alpha value at step 0.
        steps_offset (`int`, defaults to 0):
            An offset added to the inference steps. You can use a combination of `offset=1` and
            `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
            Diffusion.
        prediction_type (`str`, defaults to `epsilon`, *optional*):
            Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
            `sample` (directly predicts the noisy sample) or `v_prediction` (see section 2.4 of [Imagen
            Video](https://imagen.research.google/video/paper.pdf) paper).
        thresholding (`bool`, defaults to `False`):
            Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such
            as Stable Diffusion.
        dynamic_thresholding_ratio (`float`, defaults to 0.995):
            The ratio for the dynamic thresholding method. Valid only when `thresholding=True`.
        sample_max_value (`float`, defaults to 1.0):
            The threshold value for dynamic thresholding. Valid only when `thresholding=True`.
        timestep_spacing (`str`, defaults to `"leading"`):
            The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
            Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
        timestep_scaling (`float`, defaults to 10.0):
            The factor the timesteps will be multiplied by when calculating the consistency model boundary conditions
            `c_skip` and `c_out`. Increasing this will decrease the approximation error (although the approximation
            error at the default of `10.0` is already pretty small).
        rescale_betas_zero_snr (`bool`, defaults to `False`):
            Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
            dark samples instead of limiting it to samples with medium brightness. Loosely related to
            [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
    """

    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        beta_schedule: str = "scaled_linear",
        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
        original_inference_steps: int = 50,
        clip_sample: bool = False,
        clip_sample_range: float = 1.0,
        set_alpha_to_one: bool = True,
        steps_offset: int = 0,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        timestep_spacing: str = "leading",
        timestep_scaling: float = 10.0,
        rescale_betas_zero_snr: bool = False,
    ):
        if trained_betas is not None:
            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
        elif beta_schedule == "linear":
            self.betas = torch.linspace(
                beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model.
            self.betas = torch.linspace(
                beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        elif beta_schedule == "squaredcos_cap_v2":
            # Glide cosine schedule
            self.betas = betas_for_alpha_bar(num_train_timesteps)
        else:
            raise NotImplementedError(
                f"{beta_schedule} is not implemented for {self.__class__}")

        # Rescale for zero SNR
        if rescale_betas_zero_snr:
            self.betas = rescale_zero_terminal_snr(self.betas)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # At every step in ddim, we are looking into the previous alphas_cumprod
        # For the final step, there is no previous alphas_cumprod because we are already at 0
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(
            1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[
                                          ::-1].copy().astype(np.int64))
        self.custom_timesteps = False

        self._step_index = None

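A small sketch of constructing the scheduler with its defaults (the SD 1.5-style `scaled_linear` betas above), again assuming the module is importable from this repository; the printed values simply illustrate the state set up in `__init__`.

from animatelcm.scheduler.lcm_scheduler import LCMScheduler

scheduler = LCMScheduler()  # defaults: 1000 train timesteps, scaled_linear betas

print(scheduler.config.num_train_timesteps)  # 1000
print(scheduler.betas.shape)                 # torch.Size([1000])
print(scheduler.init_noise_sigma)            # 1.0
print(scheduler.timesteps[:3])               # tensor([999, 998, 997]) before set_timesteps()
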
    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    @property
    def step_index(self):
        return self._step_index

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`int`, *optional*):
                The current timestep in the diffusion chain.
        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        return sample

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """
        "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
        prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
        s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
        pixels from saturation at each step. We find that dynamic thresholding results in significantly better
        photorealism as well as better image-text alignment, especially when using very large guidance weights."

        https://arxiv.org/abs/2205.11487
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            # upcast for quantile calculation, and clamp not implemented for cpu half
            sample = sample.float()

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(
            abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.config.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        # (batch_size, 1) because clamp will broadcast along dim=0
        s = s.unsqueeze(1)
        # "we threshold xt0 to the range [-s, s] and then divide by s"
        sample = torch.clamp(sample, -s, s) / s

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

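To see the thresholding above in isolation, this standalone sketch applies the same percentile-clamp-divide recipe to a toy tensor; the 0.995 ratio and maximum value 1.0 mirror the config defaults, under which the operation reduces to a plain [-1, 1] clip.

import torch

sample = torch.randn(2, 4, 8, 8) * 3.0   # toy x0 prediction with saturated values
ratio, max_value = 0.995, 1.0            # dynamic_thresholding_ratio, sample_max_value defaults

batch, channels, *rest = sample.shape
flat = sample.reshape(batch, -1)
s = torch.quantile(flat.abs(), ratio, dim=1)
s = torch.clamp(s, min=1, max=max_value).unsqueeze(1)  # with max_value=1.0 this is always 1.0
thresholded = (torch.clamp(flat, -s, s) / s).reshape(batch, channels, *rest)

print(thresholded.abs().max().item())  # <= 1.0
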
    def set_timesteps(
        self,
        num_inference_steps: Optional[int] = None,
        device: Union[str, torch.device] = None,
        original_inference_steps: Optional[int] = None,
        timesteps: Optional[List[int]] = None,
        strength: float = 1.0,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`, *optional*):
                The number of diffusion steps used when generating samples with a pre-trained model. If used,
                `timesteps` must be `None`.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
            original_inference_steps (`int`, *optional*):
                The original number of inference steps, which will be used to generate a linearly-spaced timestep
                schedule (which is different from the standard `diffusers` implementation). We will then take
                `num_inference_steps` timesteps from this schedule, evenly spaced in terms of indices, and use that as
                our final timestep schedule. If not set, this will default to the `original_inference_steps` attribute.
            timesteps (`List[int]`, *optional*):
                Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
                timestep spacing strategy of equal spacing between timesteps on the training/distillation timestep
                schedule is used. If `timesteps` is passed, `num_inference_steps` must be `None`.
        """
        # 0. Check inputs
        if num_inference_steps is None and timesteps is None:
            raise ValueError(
                "Must pass exactly one of `num_inference_steps` or `timesteps`.")

        if num_inference_steps is not None and timesteps is not None:
            raise ValueError(
                "Can only pass one of `num_inference_steps` or `timesteps`.")

        # 1. Calculate the LCM original training/distillation timestep schedule.
        original_steps = (
            original_inference_steps if original_inference_steps is not None else self.config.original_inference_steps
        )

        if original_steps > self.config.num_train_timesteps:
            raise ValueError(
                f"`original_steps`: {original_steps} cannot be larger than `self.config.train_timesteps`:"
                f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                f" maximal {self.config.num_train_timesteps} timesteps."
            )

        # LCM Timesteps Setting
        # The skipping step parameter k from the paper.
        k = self.config.num_train_timesteps // original_steps
        # LCM Training/Distillation Steps Schedule
        # Currently, only a linearly-spaced schedule is supported (same as in the LCM distillation scripts).
        lcm_origin_timesteps = np.asarray(
            list(range(1, int(original_steps * strength) + 1))) * k - 1

        # 2. Calculate the LCM inference timestep schedule.
        if timesteps is not None:
            # 2.1 Handle custom timestep schedules.
            train_timesteps = set(lcm_origin_timesteps)
            non_train_timesteps = []
            for i in range(1, len(timesteps)):
                if timesteps[i] >= timesteps[i - 1]:
                    raise ValueError(
                        "`custom_timesteps` must be in descending order.")

                if timesteps[i] not in train_timesteps:
                    non_train_timesteps.append(timesteps[i])

            if timesteps[0] >= self.config.num_train_timesteps:
                raise ValueError(
                    f"`timesteps` must start before `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps}."
                )

            # Raise warning if timestep schedule does not start with self.config.num_train_timesteps - 1
            if strength == 1.0 and timesteps[0] != self.config.num_train_timesteps - 1:
                logger.warning(
                    f"The first timestep on the custom timestep schedule is {timesteps[0]}, not"
                    f" `self.config.num_train_timesteps - 1`: {self.config.num_train_timesteps - 1}. You may get"
                    f" unexpected results when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule contains timesteps not on original timestep schedule
            if non_train_timesteps:
                logger.warning(
                    f"The custom timestep schedule contains the following timesteps which are not on the original"
                    f" training/distillation timestep schedule: {non_train_timesteps}. You may get unexpected results"
                    f" when using this timestep schedule."
                )

            # Raise warning if custom timestep schedule is longer than original_steps
            if len(timesteps) > original_steps:
                logger.warning(
                    f"The number of timesteps in the custom timestep schedule is {len(timesteps)}, which exceeds"
                    f" the length of the timestep schedule used for training: {original_steps}. You may get some"
                    f" unexpected results when using this timestep schedule."
                )

            timesteps = np.array(timesteps, dtype=np.int64)
            self.num_inference_steps = len(timesteps)
            self.custom_timesteps = True

            # Apply strength (e.g. for img2img pipelines) (see StableDiffusionImg2ImgPipeline.get_timesteps)
            init_timestep = min(
                int(self.num_inference_steps * strength), self.num_inference_steps)
            t_start = max(self.num_inference_steps - init_timestep, 0)
            timesteps = timesteps[t_start * self.order:]
            # TODO: also reset self.num_inference_steps?
        else:
            # 2.2 Create the "standard" LCM inference timestep schedule.
            if num_inference_steps > self.config.num_train_timesteps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
                    f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
                    f" maximal {self.config.num_train_timesteps} timesteps."
                )

            skipping_step = len(lcm_origin_timesteps) // num_inference_steps

            if skipping_step < 1:
                raise ValueError(
                    f"The combination of `original_steps x strength`: {original_steps} x {strength} is smaller than `num_inference_steps`: {num_inference_steps}. Make sure to either reduce `num_inference_steps` to a value smaller than {int(original_steps * strength)} or increase `strength` to a value higher than {float(num_inference_steps / original_steps)}."
                )

            self.num_inference_steps = num_inference_steps

            if num_inference_steps > original_steps:
                raise ValueError(
                    f"`num_inference_steps`: {num_inference_steps} cannot be larger than `original_inference_steps`:"
                    f" {original_steps} because the final timestep schedule will be a subset of the"
                    f" `original_inference_steps`-sized initial timestep schedule."
                )

            # LCM Inference Steps Schedule
            lcm_origin_timesteps = lcm_origin_timesteps[::-1].copy()
            # Select (approximately) evenly spaced indices from lcm_origin_timesteps.
            inference_indices = np.linspace(
                0, len(lcm_origin_timesteps), num=num_inference_steps, endpoint=False)
            # With the default schedule (1000 train timesteps, original_inference_steps=50):
            #   1 step:  denoising runs directly from timestep 999
            #   2 steps: 999, 499
            #   4 steps: 999, 759, 499, 259
            inference_indices = np.floor(inference_indices).astype(np.int64)
            timesteps = lcm_origin_timesteps[inference_indices]

        self.timesteps = torch.from_numpy(timesteps).to(
            device=device, dtype=torch.long)

        self._step_index = None

| 545 | 
            +
                def get_scalings_for_boundary_condition_discrete(self, timestep):
         | 
| 546 | 
            +
                    self.sigma_data = 0.5  # Default: 0.5
         | 
| 547 | 
            +
                    scaled_timestep = timestep * self.config.timestep_scaling
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
         | 
| 550 | 
            +
                    c_out = scaled_timestep / \
         | 
| 551 | 
            +
                        (scaled_timestep**2 + self.sigma_data**2) ** 0.5
         | 
| 552 | 
            +
                    return c_skip, c_out
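                # Quick check of the boundary condition implemented above:
                #   scaled_timestep = 0           -> c_skip = 1.0, c_out = 0.0 (identity mapping)
                #   scaled_timestep >> sigma_data -> c_skip ~ 0.0, c_out ~ 1.0 (output dominated
                #                                    by the predicted original sample)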
         | 
| 553 | 
            +
             | 
| 554 | 
            +
                def step(
         | 
| 555 | 
            +
                    self,
         | 
| 556 | 
            +
                    model_output: torch.FloatTensor,
         | 
| 557 | 
            +
                    timestep: int,
         | 
| 558 | 
            +
                    sample: torch.FloatTensor,
         | 
| 559 | 
            +
                    generator: Optional[torch.Generator] = None,
         | 
| 560 | 
            +
                    return_dict: bool = True,
         | 
| 561 | 
            +
                    use_ddim: bool = False,
         | 
| 562 | 
            +
                ) -> Union[LCMSchedulerOutput, Tuple]:
         | 
| 563 | 
            +
                    """
         | 
| 564 | 
            +
                    Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
         | 
| 565 | 
            +
                    process from the learned model outputs (most often the predicted noise).
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                    Args:
         | 
| 568 | 
            +
                        model_output (`torch.FloatTensor`):
         | 
| 569 | 
            +
                            The direct output from the learned diffusion model.
         | 
| 570 | 
            +
                        timestep (`int`):
         | 
| 571 | 
            +
                            The current discrete timestep in the diffusion chain.
         | 
| 572 | 
            +
                        sample (`torch.FloatTensor`):
         | 
| 573 | 
            +
                            A current instance of a sample created by the diffusion process.
         | 
| 574 | 
            +
                        generator (`torch.Generator`, *optional*):
         | 
| 575 | 
            +
                            A random number generator.
         | 
| 576 | 
            +
                        return_dict (`bool`, *optional*, defaults to `True`):
         | 
| 577 | 
            +
                            Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`.
         | 
| 578 | 
            +
                    Returns:
         | 
| 579 | 
            +
                    [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`:
         | 
| 580 | 
            +
                            If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a
         | 
| 581 | 
            +
                            tuple is returned where the first element is the sample tensor.
         | 
| 582 | 
            +
                    """
         | 
| 583 | 
            +
                    if self.num_inference_steps is None:
         | 
| 584 | 
            +
                        raise ValueError(
         | 
| 585 | 
            +
                            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
         | 
| 586 | 
            +
                        )
         | 
| 587 | 
            +
             | 
| 588 | 
            +
                    if self.step_index is None:
         | 
| 589 | 
            +
                        self._init_step_index(timestep)
         | 
| 590 | 
            +
             | 
| 591 | 
            +
                    # 1. get previous step value
         | 
| 592 | 
            +
                    prev_step_index = self.step_index + 1
         | 
| 593 | 
            +
                    if prev_step_index < len(self.timesteps):
         | 
| 594 | 
            +
                        prev_timestep = self.timesteps[prev_step_index]
         | 
| 595 | 
            +
                    else:
         | 
| 596 | 
            +
                        prev_timestep = timestep
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                    # 2. compute alphas, betas
         | 
| 599 | 
            +
                    alpha_prod_t = self.alphas_cumprod[timestep]
         | 
| 600 | 
            +
                    alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
         | 
| 601 | 
            +
             | 
| 602 | 
            +
                    beta_prod_t = 1 - alpha_prod_t
         | 
| 603 | 
            +
                    beta_prod_t_prev = 1 - alpha_prod_t_prev
         | 
| 604 | 
            +
             | 
| 605 | 
            +
                    # 3. Get scalings for boundary conditions
         | 
| 606 | 
            +
                    c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(
         | 
| 607 | 
            +
                        timestep)
         | 
| 608 | 
            +
             | 
| 609 | 
            +
                    # 4. Compute the predicted original sample x_0 based on the model parameterization
         | 
| 610 | 
            +
                    if self.config.prediction_type == "epsilon":  # noise-prediction
         | 
| 611 | 
            +
                        predicted_original_sample = (
         | 
| 612 | 
            +
                            sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
         | 
| 613 | 
            +
                    elif self.config.prediction_type == "sample":  # x-prediction
         | 
| 614 | 
            +
                        predicted_original_sample = model_output
         | 
| 615 | 
            +
                    elif self.config.prediction_type == "v_prediction":  # v-prediction
         | 
| 616 | 
            +
                        predicted_original_sample = alpha_prod_t.sqrt(
         | 
| 617 | 
            +
                        ) * sample - beta_prod_t.sqrt() * model_output
         | 
| 618 | 
            +
                    else:
         | 
| 619 | 
            +
                        raise ValueError(
         | 
| 620 | 
            +
                            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
         | 
| 621 | 
            +
                            " `v_prediction` for `LCMScheduler`."
         | 
| 622 | 
            +
                        )
         | 
| 623 | 
            +
             | 
| 624 | 
            +
                    # 5. Clip or threshold "predicted x_0"
         | 
| 625 | 
            +
                    if self.config.thresholding:
         | 
| 626 | 
            +
                        predicted_original_sample = self._threshold_sample(
         | 
| 627 | 
            +
                            predicted_original_sample)
         | 
| 628 | 
            +
                    elif self.config.clip_sample:
         | 
| 629 | 
            +
                        predicted_original_sample = predicted_original_sample.clamp(
         | 
| 630 | 
            +
                            -self.config.clip_sample_range, self.config.clip_sample_range
         | 
| 631 | 
            +
                        )
         | 
| 632 | 
            +
             | 
| 633 | 
            +
                    # 6. Denoise model output using boundary conditions
         | 
| 634 | 
            +
                    denoised = c_out * predicted_original_sample + c_skip * sample
         | 
| 635 | 
            +
                    # denoised = predicted_original_sample
         | 
| 636 | 
            +
             | 
| 637 | 
            +
                    # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
         | 
| 638 | 
            +
                    # Noise is not used on the final timestep of the timestep schedule.
         | 
| 639 | 
            +
                    # This also means that noise is not used for one-step sampling.
         | 
| 640 | 
            +
                    if self.step_index != self.num_inference_steps - 1:
         | 
| 641 | 
            +
                        if not use_ddim:
         | 
| 642 | 
            +
                            noise = randn_tensor(
         | 
| 643 | 
            +
                                model_output.shape, generator=generator, device=model_output.device, dtype=denoised.dtype
         | 
| 644 | 
            +
                            )
                        else:
                            # Deterministic (DDIM-style) update requested via `use_ddim`: reuse the
                            # epsilon implied by the current sample and the consistency prediction
                            # instead of drawing fresh Gaussian noise.
                            noise = (sample - alpha_prod_t.sqrt() * denoised) / beta_prod_t.sqrt()
         | 
| 645 | 
            +
                        prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
         | 
| 646 | 
            +
                    else:
         | 
| 647 | 
            +
                        prev_sample = denoised
         | 
| 648 | 
            +
             | 
| 649 | 
            +
                    # upon completion increase step index by one
         | 
| 650 | 
            +
                    self._step_index += 1
         | 
| 651 | 
            +
             | 
| 652 | 
            +
                    if not return_dict:
         | 
| 653 | 
            +
                        return (prev_sample, denoised)
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                    return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised)
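                # Minimal multi-step usage sketch (illustrative only; `unet`, `latents` and
                # `encoder_hidden_states` are placeholder names, not defined in this file):
                #
                #   scheduler.set_timesteps(num_inference_steps, device=device)
                #   for t in scheduler.timesteps:
                #       model_output = unet(latents, t, encoder_hidden_states).sample
                #       latents, denoised = scheduler.step(model_output, t, latents, return_dict=False)
                #
                #   The `denoised` tensor from the final step is the x_0 estimate to decode.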
         | 
| 656 | 
            +
             | 
| 657 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
         | 
| 658 | 
            +
                def add_noise(
         | 
| 659 | 
            +
                    self,
         | 
| 660 | 
            +
                    original_samples: torch.FloatTensor,
         | 
| 661 | 
            +
                    noise: torch.FloatTensor,
         | 
| 662 | 
            +
                    timesteps: torch.IntTensor,
         | 
| 663 | 
            +
                ) -> torch.FloatTensor:
         | 
| 664 | 
            +
                    # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         | 
| 665 | 
            +
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 666 | 
            +
                        device=original_samples.device, dtype=original_samples.dtype)
         | 
| 667 | 
            +
                    timesteps = timesteps.to(original_samples.device)
         | 
| 668 | 
            +
             | 
| 669 | 
            +
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 670 | 
            +
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 671 | 
            +
                    while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
         | 
| 672 | 
            +
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 673 | 
            +
             | 
| 674 | 
            +
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 675 | 
            +
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 676 | 
            +
                    while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
         | 
| 677 | 
            +
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 678 | 
            +
             | 
| 679 | 
            +
                    noisy_samples = sqrt_alpha_prod * original_samples + \
         | 
| 680 | 
            +
                        sqrt_one_minus_alpha_prod * noise
         | 
| 681 | 
            +
                    return noisy_samples
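                # add_noise implements the closed-form forward-diffusion step
                #   noisy = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
                # with alpha_bar_t = alphas_cumprod[t] broadcast to the sample's shape above.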
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
         | 
| 684 | 
            +
                def get_velocity(
         | 
| 685 | 
            +
                    self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
         | 
| 686 | 
            +
                ) -> torch.FloatTensor:
         | 
| 687 | 
            +
                    # Make sure alphas_cumprod and timestep have same device and dtype as sample
         | 
| 688 | 
            +
                    alphas_cumprod = self.alphas_cumprod.to(
         | 
| 689 | 
            +
                        device=sample.device, dtype=sample.dtype)
         | 
| 690 | 
            +
                    timesteps = timesteps.to(sample.device)
         | 
| 691 | 
            +
             | 
| 692 | 
            +
                    sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 693 | 
            +
                    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 694 | 
            +
                    while len(sqrt_alpha_prod.shape) < len(sample.shape):
         | 
| 695 | 
            +
                        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 698 | 
            +
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 699 | 
            +
                    while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
         | 
| 700 | 
            +
                        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                    velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
         | 
| 703 | 
            +
                    return velocity
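                # get_velocity returns the v-prediction training target
                #   v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0,
                # i.e. what a model configured with `prediction_type="v_prediction"` outputs.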
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                def __len__(self):
         | 
| 706 | 
            +
                    return self.config.num_train_timesteps
         | 
| 707 | 
            +
             | 
| 708 | 
            +
                # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.previous_timestep
         | 
| 709 | 
            +
                def previous_timestep(self, timestep):
         | 
| 710 | 
            +
                    if self.custom_timesteps:
         | 
| 711 | 
            +
                        index = (self.timesteps == timestep).nonzero(as_tuple=True)[0][0]
         | 
| 712 | 
            +
                        if index == self.timesteps.shape[0] - 1:
         | 
| 713 | 
            +
                            prev_t = torch.tensor(-1)
         | 
| 714 | 
            +
                        else:
         | 
| 715 | 
            +
                            prev_t = self.timesteps[index + 1]
         | 
| 716 | 
            +
                    else:
         | 
| 717 | 
            +
                        num_inference_steps = (
         | 
| 718 | 
            +
                            self.num_inference_steps if self.num_inference_steps else self.config.num_train_timesteps
         | 
| 719 | 
            +
                        )
         | 
| 720 | 
            +
                        prev_t = timestep - self.config.num_train_timesteps // num_inference_steps
         | 
| 721 | 
            +
             | 
| 722 | 
            +
                    return prev_t
         | 
    	
        animatelcm/utils/convert_from_ckpt.py
    ADDED
    
    | @@ -0,0 +1,951 @@ | |
| 1 | 
            +
            # coding=utf-8
         | 
| 2 | 
            +
            # Copyright 2023 The HuggingFace Inc. team.
         | 
| 3 | 
            +
            #
         | 
| 4 | 
            +
            # Licensed under the Apache License, Version 2.0 (the "License");
         | 
| 5 | 
            +
            # you may not use this file except in compliance with the License.
         | 
| 6 | 
            +
            # You may obtain a copy of the License at
         | 
| 7 | 
            +
            #
         | 
| 8 | 
            +
            #     http://www.apache.org/licenses/LICENSE-2.0
         | 
| 9 | 
            +
            #
         | 
| 10 | 
            +
            # Unless required by applicable law or agreed to in writing, software
         | 
| 11 | 
            +
            # distributed under the License is distributed on an "AS IS" BASIS,
         | 
| 12 | 
            +
            # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
         | 
| 13 | 
            +
            # See the License for the specific language governing permissions and
         | 
| 14 | 
            +
            # limitations under the License.
         | 
| 15 | 
            +
            """ Conversion script for the Stable Diffusion checkpoints."""
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            import re
         | 
| 18 | 
            +
            from io import BytesIO
         | 
| 19 | 
            +
            from typing import Optional
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            import requests
         | 
| 22 | 
            +
            import torch
         | 
| 23 | 
            +
            from transformers import (
         | 
| 24 | 
            +
                AutoFeatureExtractor,
         | 
| 25 | 
            +
                BertTokenizerFast,
         | 
| 26 | 
            +
                CLIPImageProcessor,
         | 
| 27 | 
            +
                CLIPTextModel,
         | 
| 28 | 
            +
                CLIPTextModelWithProjection,
         | 
| 29 | 
            +
                CLIPTokenizer,
         | 
| 30 | 
            +
                CLIPVisionConfig,
         | 
| 31 | 
            +
                CLIPVisionModelWithProjection,
         | 
| 32 | 
            +
            )
         | 
| 33 | 
            +
             | 
| 34 | 
            +
            from diffusers.models import (
         | 
| 35 | 
            +
                AutoencoderKL,
         | 
| 36 | 
            +
                PriorTransformer,
         | 
| 37 | 
            +
                UNet2DConditionModel,
         | 
| 38 | 
            +
            )
         | 
| 39 | 
            +
            from diffusers.schedulers import (
         | 
| 40 | 
            +
                DDIMScheduler,
         | 
| 41 | 
            +
                DDPMScheduler,
         | 
| 42 | 
            +
                DPMSolverMultistepScheduler,
         | 
| 43 | 
            +
                EulerAncestralDiscreteScheduler,
         | 
| 44 | 
            +
                EulerDiscreteScheduler,
         | 
| 45 | 
            +
                HeunDiscreteScheduler,
         | 
| 46 | 
            +
                LMSDiscreteScheduler,
         | 
| 47 | 
            +
                PNDMScheduler,
         | 
| 48 | 
            +
                UnCLIPScheduler,
         | 
| 49 | 
            +
            )
         | 
| 50 | 
            +
            from diffusers.utils.import_utils import BACKENDS_MAPPING
            from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import LDMBertConfig  # used by create_ldm_bert_config below
         | 
| 51 | 
            +
             | 
| 52 | 
            +
             | 
| 53 | 
            +
            def shave_segments(path, n_shave_prefix_segments=1):
         | 
| 54 | 
            +
                """
         | 
| 55 | 
            +
                Removes segments. Positive values shave the first segments, negative shave the last segments.
         | 
| 56 | 
            +
                """
         | 
| 57 | 
            +
                if n_shave_prefix_segments >= 0:
         | 
| 58 | 
            +
                    return ".".join(path.split(".")[n_shave_prefix_segments:])
         | 
| 59 | 
            +
                else:
         | 
| 60 | 
            +
                    return ".".join(path.split(".")[:n_shave_prefix_segments])
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
         | 
| 64 | 
            +
                """
         | 
| 65 | 
            +
                Updates paths inside resnets to the new naming scheme (local renaming)
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                mapping = []
         | 
| 68 | 
            +
                for old_item in old_list:
         | 
| 69 | 
            +
                    new_item = old_item.replace("in_layers.0", "norm1")
         | 
| 70 | 
            +
                    new_item = new_item.replace("in_layers.2", "conv1")
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    new_item = new_item.replace("out_layers.0", "norm2")
         | 
| 73 | 
            +
                    new_item = new_item.replace("out_layers.3", "conv2")
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                    new_item = new_item.replace("emb_layers.1", "time_emb_proj")
         | 
| 76 | 
            +
                    new_item = new_item.replace("skip_connection", "conv_shortcut")
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    mapping.append({"old": old_item, "new": new_item})
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                return mapping
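            # Example: for old_list = ["input_blocks.1.0.in_layers.0.weight"] this yields
            #   [{"old": "input_blocks.1.0.in_layers.0.weight",
            #     "new": "input_blocks.1.0.norm1.weight"}]
            # (the default n_shave_prefix_segments=0 keeps every segment, so only the rename applies).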
         | 
| 83 | 
            +
             | 
| 84 | 
            +
             | 
| 85 | 
            +
            def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
         | 
| 86 | 
            +
                """
         | 
| 87 | 
            +
                Updates paths inside resnets to the new naming scheme (local renaming)
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
                mapping = []
         | 
| 90 | 
            +
                for old_item in old_list:
         | 
| 91 | 
            +
                    new_item = old_item
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    new_item = new_item.replace("nin_shortcut", "conv_shortcut")
         | 
| 94 | 
            +
                    new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    mapping.append({"old": old_item, "new": new_item})
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                return mapping
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            def renew_attention_paths(old_list, n_shave_prefix_segments=0):
         | 
| 102 | 
            +
                """
         | 
| 103 | 
            +
                Updates paths inside attentions to the new naming scheme (local renaming)
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
                mapping = []
         | 
| 106 | 
            +
                for old_item in old_list:
         | 
| 107 | 
            +
                    new_item = old_item
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                    mapping.append({"old": old_item, "new": new_item})
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                return mapping
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         | 
| 115 | 
            +
                """
         | 
| 116 | 
            +
                Updates paths inside attentions to the new naming scheme (local renaming)
         | 
| 117 | 
            +
                """
         | 
| 118 | 
            +
                mapping = []
         | 
| 119 | 
            +
                for old_item in old_list:
         | 
| 120 | 
            +
                    new_item = old_item
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                    new_item = new_item.replace("norm.weight", "group_norm.weight")
         | 
| 123 | 
            +
                    new_item = new_item.replace("norm.bias", "group_norm.bias")
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    new_item = new_item.replace("q.weight", "query.weight")
         | 
| 126 | 
            +
                    new_item = new_item.replace("q.bias", "query.bias")
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                    new_item = new_item.replace("k.weight", "key.weight")
         | 
| 129 | 
            +
                    new_item = new_item.replace("k.bias", "key.bias")
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                    new_item = new_item.replace("v.weight", "value.weight")
         | 
| 132 | 
            +
                    new_item = new_item.replace("v.bias", "value.bias")
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
         | 
| 135 | 
            +
                    new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                    mapping.append({"old": old_item, "new": new_item})
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                return mapping
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
            def assign_to_checkpoint(
         | 
| 145 | 
            +
                paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
         | 
| 146 | 
            +
            ):
         | 
| 147 | 
            +
                """
         | 
| 148 | 
            +
                This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
         | 
| 149 | 
            +
                attention layers, and takes into account additional replacements that may arise.
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                Assigns the weights to the new checkpoint.
         | 
| 152 | 
            +
                """
         | 
| 153 | 
            +
                assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                # Splits the attention layers into three variables.
         | 
| 156 | 
            +
                if attention_paths_to_split is not None:
         | 
| 157 | 
            +
                    for path, path_map in attention_paths_to_split.items():
         | 
| 158 | 
            +
                        old_tensor = old_checkpoint[path]
         | 
| 159 | 
            +
                        channels = old_tensor.shape[0] // 3
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                        target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                        num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                        old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
         | 
| 166 | 
            +
                        query, key, value = old_tensor.split(channels // num_heads, dim=1)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                        checkpoint[path_map["query"]] = query.reshape(target_shape)
         | 
| 169 | 
            +
                        checkpoint[path_map["key"]] = key.reshape(target_shape)
         | 
| 170 | 
            +
                        checkpoint[path_map["value"]] = value.reshape(target_shape)
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                for path in paths:
         | 
| 173 | 
            +
                    new_path = path["new"]
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                    # These have already been assigned
         | 
| 176 | 
            +
                    if attention_paths_to_split is not None and new_path in attention_paths_to_split:
         | 
| 177 | 
            +
                        continue
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                    # Global renaming happens here
         | 
| 180 | 
            +
                    new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
         | 
| 181 | 
            +
                    new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
         | 
| 182 | 
            +
                    new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if additional_replacements is not None:
         | 
| 185 | 
            +
                        for replacement in additional_replacements:
         | 
| 186 | 
            +
                            new_path = new_path.replace(replacement["old"], replacement["new"])
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                    # proj_attn.weight has to be converted from conv 1D to linear
         | 
| 189 | 
            +
                    if "proj_attn.weight" in new_path:
         | 
| 190 | 
            +
                        checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
         | 
| 191 | 
            +
                    else:
         | 
| 192 | 
            +
                        checkpoint[new_path] = old_checkpoint[path["old"]]
         | 
| 193 | 
            +
             | 
| 194 | 
            +
             | 
| 195 | 
            +
            def conv_attn_to_linear(checkpoint):
         | 
| 196 | 
            +
                keys = list(checkpoint.keys())
         | 
| 197 | 
            +
                attn_keys = ["query.weight", "key.weight", "value.weight"]
         | 
| 198 | 
            +
                for key in keys:
         | 
| 199 | 
            +
                    if ".".join(key.split(".")[-2:]) in attn_keys:
         | 
| 200 | 
            +
                        if checkpoint[key].ndim > 2:
         | 
| 201 | 
            +
                            checkpoint[key] = checkpoint[key][:, :, 0, 0]
         | 
| 202 | 
            +
                    elif "proj_attn.weight" in key:
         | 
| 203 | 
            +
                        if checkpoint[key].ndim > 2:
         | 
| 204 | 
            +
                            checkpoint[key] = checkpoint[key][:, :, 0]
         | 
| 205 | 
            +
             | 
| 206 | 
            +
             | 
| 207 | 
            +
            def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
         | 
| 208 | 
            +
                """
         | 
| 209 | 
            +
                Creates a config for the diffusers based on the config of the LDM model.
         | 
| 210 | 
            +
                """
         | 
| 211 | 
            +
                if controlnet:
         | 
| 212 | 
            +
                    unet_params = original_config.model.params.control_stage_config.params
         | 
| 213 | 
            +
                else:
         | 
| 214 | 
            +
                    unet_params = original_config.model.params.unet_config.params
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                vae_params = original_config.model.params.first_stage_config.params.ddconfig
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                down_block_types = []
         | 
| 221 | 
            +
                resolution = 1
         | 
| 222 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 223 | 
            +
                    block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
         | 
| 224 | 
            +
                    down_block_types.append(block_type)
         | 
| 225 | 
            +
                    if i != len(block_out_channels) - 1:
         | 
| 226 | 
            +
                        resolution *= 2
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                up_block_types = []
         | 
| 229 | 
            +
                for i in range(len(block_out_channels)):
         | 
| 230 | 
            +
                    block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
         | 
| 231 | 
            +
                    up_block_types.append(block_type)
         | 
| 232 | 
            +
                    resolution //= 2
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                head_dim = unet_params.num_heads if "num_heads" in unet_params else None
         | 
| 237 | 
            +
                use_linear_projection = (
         | 
| 238 | 
            +
                    unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
         | 
| 239 | 
            +
                )
         | 
| 240 | 
            +
                if use_linear_projection:
         | 
| 241 | 
            +
                    # stable diffusion 2-base-512 and 2-768
         | 
| 242 | 
            +
                    if head_dim is None:
         | 
| 243 | 
            +
                        head_dim = [5, 10, 20, 20]
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                class_embed_type = None
         | 
| 246 | 
            +
                projection_class_embeddings_input_dim = None
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                if "num_classes" in unet_params:
         | 
| 249 | 
            +
                    if unet_params.num_classes == "sequential":
         | 
| 250 | 
            +
                        class_embed_type = "projection"
         | 
| 251 | 
            +
                        assert "adm_in_channels" in unet_params
         | 
| 252 | 
            +
                        projection_class_embeddings_input_dim = unet_params.adm_in_channels
         | 
| 253 | 
            +
                    else:
         | 
| 254 | 
            +
                        raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                config = {
         | 
| 257 | 
            +
                    "sample_size": image_size // vae_scale_factor,
         | 
| 258 | 
            +
                    "in_channels": unet_params.in_channels,
         | 
| 259 | 
            +
                    "down_block_types": tuple(down_block_types),
         | 
| 260 | 
            +
                    "block_out_channels": tuple(block_out_channels),
         | 
| 261 | 
            +
                    "layers_per_block": unet_params.num_res_blocks,
         | 
| 262 | 
            +
                    "cross_attention_dim": unet_params.context_dim,
         | 
| 263 | 
            +
                    "attention_head_dim": head_dim,
         | 
| 264 | 
            +
                    "use_linear_projection": use_linear_projection,
         | 
| 265 | 
            +
                    "class_embed_type": class_embed_type,
         | 
| 266 | 
            +
                    "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
         | 
| 267 | 
            +
                }
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                if not controlnet:
         | 
| 270 | 
            +
                    config["out_channels"] = unet_params.out_channels
         | 
| 271 | 
            +
                    config["up_block_types"] = tuple(up_block_types)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                return config
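            # For a standard Stable Diffusion v1.x config (model_channels=320,
            # channel_mult=[1, 2, 4, 4], attention_resolutions=[4, 2, 1]) this yields, for example:
            #   block_out_channels = (320, 640, 1280, 1280)
            #   down_block_types   = ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D",
            #                         "CrossAttnDownBlock2D", "DownBlock2D")
            #   up_block_types     = ("UpBlock2D", "CrossAttnUpBlock2D",
            #                         "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")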
         | 
| 274 | 
            +
             | 
| 275 | 
            +
             | 
| 276 | 
            +
            def create_vae_diffusers_config(original_config, image_size: int):
         | 
| 277 | 
            +
                """
         | 
| 278 | 
            +
                Creates a config for the diffusers based on the config of the LDM model.
         | 
| 279 | 
            +
                """
         | 
| 280 | 
            +
                vae_params = original_config.model.params.first_stage_config.params.ddconfig
         | 
| 281 | 
            +
                _ = original_config.model.params.first_stage_config.params.embed_dim
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
         | 
| 284 | 
            +
                down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
         | 
| 285 | 
            +
                up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
         | 
| 286 | 
            +
             | 
| 287 | 
            +
                config = {
         | 
| 288 | 
            +
                    "sample_size": image_size,
         | 
| 289 | 
            +
                    "in_channels": vae_params.in_channels,
         | 
| 290 | 
            +
                    "out_channels": vae_params.out_ch,
         | 
| 291 | 
            +
                    "down_block_types": tuple(down_block_types),
         | 
| 292 | 
            +
                    "up_block_types": tuple(up_block_types),
         | 
| 293 | 
            +
                    "block_out_channels": tuple(block_out_channels),
         | 
| 294 | 
            +
                    "latent_channels": vae_params.z_channels,
         | 
| 295 | 
            +
                    "layers_per_block": vae_params.num_res_blocks,
         | 
| 296 | 
            +
                }
         | 
| 297 | 
            +
                return config
         | 
| 298 | 
            +
             | 
| 299 | 
            +
             | 
| 300 | 
            +
            def create_diffusers_schedular(original_config):
         | 
| 301 | 
            +
                schedular = DDIMScheduler(
         | 
| 302 | 
            +
                    num_train_timesteps=original_config.model.params.timesteps,
         | 
| 303 | 
            +
                    beta_start=original_config.model.params.linear_start,
         | 
| 304 | 
            +
                    beta_end=original_config.model.params.linear_end,
         | 
| 305 | 
            +
                    beta_schedule="scaled_linear",
         | 
| 306 | 
            +
                )
         | 
| 307 | 
            +
                return schedular
         | 
| 308 | 
            +
             | 
| 309 | 
            +
             | 
| 310 | 
            +
            def create_ldm_bert_config(original_config):
         | 
| 311 | 
            +
                bert_params = original_config.model.params.cond_stage_config.params
         | 
| 312 | 
            +
                config = LDMBertConfig(
         | 
| 313 | 
            +
                    d_model=bert_params.n_embed,
         | 
| 314 | 
            +
                    encoder_layers=bert_params.n_layer,
         | 
| 315 | 
            +
                    encoder_ffn_dim=bert_params.n_embed * 4,
         | 
| 316 | 
            +
                )
         | 
| 317 | 
            +
                return config
         | 
| 318 | 
            +
             | 
| 319 | 
            +
             | 
| 320 | 
            +
            def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
         | 
| 321 | 
            +
                """
         | 
| 322 | 
            +
                Takes a state dict and a config, and returns a converted checkpoint.
         | 
| 323 | 
            +
                """
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                # extract state_dict for UNet
         | 
| 326 | 
            +
                unet_state_dict = {}
         | 
| 327 | 
            +
                keys = list(checkpoint.keys())
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                if controlnet:
         | 
| 330 | 
            +
                    unet_key = "control_model."
         | 
| 331 | 
            +
                else:
         | 
| 332 | 
            +
                    unet_key = "model.diffusion_model."
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                # at least 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
         | 
| 335 | 
            +
                if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
         | 
| 336 | 
            +
                    print(f"Checkpoint {path} has both EMA and non-EMA weights.")
         | 
| 337 | 
            +
                    print(
         | 
| 338 | 
            +
                        "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
         | 
| 339 | 
            +
                        " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
         | 
| 340 | 
            +
                    )
         | 
| 341 | 
            +
                    for key in keys:
         | 
| 342 | 
            +
                        if key.startswith("model.diffusion_model"):
         | 
| 343 | 
            +
                            flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
         | 
| 344 | 
            +
                            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
         | 
| 345 | 
            +
                else:
         | 
| 346 | 
            +
                    if sum(k.startswith("model_ema") for k in keys) > 100:
         | 
| 347 | 
            +
                        print(
         | 
| 348 | 
            +
                            "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
         | 
| 349 | 
            +
                            " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
         | 
| 350 | 
            +
                        )
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                    for key in keys:
         | 
| 353 | 
            +
                        if key.startswith(unet_key):
         | 
| 354 | 
            +
                            unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                new_checkpoint = {}
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
         | 
| 359 | 
            +
                new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
         | 
| 360 | 
            +
                new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
         | 
| 361 | 
            +
                new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                if config["class_embed_type"] is None:
         | 
| 364 | 
        # No parameters to port
        ...
    elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
        new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
        new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
        new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
        new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
    else:
        raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")

    new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
    new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]

    if not controlnet:
        new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
        new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
        new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
        new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]

    # Retrieves the keys for the input blocks only
    num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
    input_blocks = {
        layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
        for layer_id in range(num_input_blocks)
    }

    # Retrieves the keys for the middle blocks only
    num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
    middle_blocks = {
        layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
        for layer_id in range(num_middle_blocks)
    }

    # Retrieves the keys for the output blocks only
    num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
    output_blocks = {
        layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
        for layer_id in range(num_output_blocks)
    }

    for i in range(1, num_input_blocks):
        block_id = (i - 1) // (config["layers_per_block"] + 1)
        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)

        resnets = [
            key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
        ]
        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]

        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.weight"
            )
            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
                f"input_blocks.{i}.0.op.bias"
            )

        paths = renew_resnet_paths(resnets)
        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
        assign_to_checkpoint(
            paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
        )

        if len(attentions):
            paths = renew_attention_paths(attentions)
            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

    resnet_0 = middle_blocks[0]
    attentions = middle_blocks[1]
    resnet_1 = middle_blocks[2]

    resnet_0_paths = renew_resnet_paths(resnet_0)
    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)

    resnet_1_paths = renew_resnet_paths(resnet_1)
    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)

    attentions_paths = renew_attention_paths(attentions)
    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(
        attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
    )

    for i in range(num_output_blocks):
        block_id = i // (config["layers_per_block"] + 1)
        layer_in_block_id = i % (config["layers_per_block"] + 1)
        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
        output_block_list = {}

        for layer in output_block_layers:
            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
            if layer_id in output_block_list:
                output_block_list[layer_id].append(layer_name)
            else:
                output_block_list[layer_id] = [layer_name]

        if len(output_block_list) > 1:
            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]

            resnet_0_paths = renew_resnet_paths(resnets)
            paths = renew_resnet_paths(resnets)

            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
            assign_to_checkpoint(
                paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
            )

            output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
            if ["conv.bias", "conv.weight"] in output_block_list.values():
                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.weight"
                ]
                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
                    f"output_blocks.{i}.{index}.conv.bias"
                ]

                # Clear attentions as they have been attributed above.
                if len(attentions) == 2:
                    attentions = []

            if len(attentions):
                paths = renew_attention_paths(attentions)
                meta_path = {
                    "old": f"output_blocks.{i}.1",
                    "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
                }
                assign_to_checkpoint(
                    paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
                )
        else:
            resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
            for path in resnet_0_paths:
                old_path = ".".join(["output_blocks", str(i), path["old"]])
                new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])

                new_checkpoint[new_path] = unet_state_dict[old_path]

    if controlnet:
        # conditioning embedding

        orig_index = 0

        new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        orig_index += 2

        diffusers_index = 0

        while diffusers_index < 6:
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.weight"
            )
            new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
                f"input_hint_block.{orig_index}.bias"
            )
            diffusers_index += 1
            orig_index += 2

        new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.weight"
        )
        new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
            f"input_hint_block.{orig_index}.bias"
        )

        # down blocks
        for i in range(num_input_blocks):
            new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
            new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")

        # mid block
        new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
        new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")

    return new_checkpoint
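Editorial note: the block above completes `convert_ldm_unet_checkpoint`. A minimal usage sketch follows; it is not part of the uploaded file, the checkpoint/config paths are placeholders, and it assumes the `create_unet_diffusers_config` helper defined earlier in this file.

# Hedged usage sketch -- file paths and the 2D UNet target are assumptions.
import torch
from omegaconf import OmegaConf
from diffusers import UNet2DConditionModel

original_config = OmegaConf.load("v1-inference.yaml")   # placeholder path
checkpoint = torch.load("sd15.ckpt", map_location="cpu")
checkpoint = checkpoint.get("state_dict", checkpoint)   # some checkpoints nest the weights

unet_config = create_unet_diffusers_config(original_config, image_size=512)
converted_unet = convert_ldm_unet_checkpoint(checkpoint, unet_config, path="sd15.ckpt")

unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet)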

def convert_ldm_vae_checkpoint(checkpoint, config):
    # extract state dict for VAE
    vae_state_dict = {}
    vae_key = "first_stage_model."
    keys = list(checkpoint.keys())
    for key in keys:
        if key.startswith(vae_key):
            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)

    new_checkpoint = {}

    new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
    new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
    new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
    new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
    new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
    new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]

    new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
    new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
    new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
    new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
    new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
    new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]

    new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
    new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
    new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
    new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]

    # Retrieves the keys for the encoder down blocks only
    num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
    down_blocks = {
        layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
    }

    # Retrieves the keys for the decoder up blocks only
    num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
    up_blocks = {
        layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
    }

    for i in range(num_down_blocks):
        resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]

        if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.weight"
            )
            new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
                f"encoder.down.{i}.downsample.conv.bias"
            )

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)

    for i in range(num_up_blocks):
        block_id = num_up_blocks - 1 - i
        resnets = [
            key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
        ]

        if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.weight"
            ]
            new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
                f"decoder.up.{block_id}.upsample.conv.bias"
            ]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
    num_mid_res_blocks = 2
    for i in range(1, num_mid_res_blocks + 1):
        resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]

        paths = renew_vae_resnet_paths(resnets)
        meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
        assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)

    mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
    paths = renew_vae_attention_paths(mid_attentions)
    meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
    assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
    conv_attn_to_linear(new_checkpoint)
    return new_checkpoint
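Editorial note: `convert_ldm_vae_checkpoint` only remaps keys; the caller still instantiates the diffusers `AutoencoderKL`. A minimal sketch, assuming a `create_vae_diffusers_config` helper defined earlier in this file (not shown here) and the `original_config`/`checkpoint` objects from the previous sketch:

# Hedged usage sketch -- assumes `original_config` and `checkpoint` as above.
from diffusers import AutoencoderKL

vae_config = create_vae_diffusers_config(original_config, image_size=512)
converted_vae = convert_ldm_vae_checkpoint(checkpoint, vae_config)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae)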

def convert_ldm_bert_checkpoint(checkpoint, config):
    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight

        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias

    def _copy_linear(hf_linear, pt_linear):
        hf_linear.weight = pt_linear.weight
        hf_linear.bias = pt_linear.bias

    def _copy_layer(hf_layer, pt_layer):
        # copy layer norms
        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])

        # copy attn
        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])

        # copy MLP
        pt_mlp = pt_layer[1][1]
        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
        _copy_linear(hf_layer.fc2, pt_mlp.net[2])

    def _copy_layers(hf_layers, pt_layers):
        for i, hf_layer in enumerate(hf_layers):
            if i != 0:
                i += i
            pt_layer = pt_layers[i : i + 2]
            _copy_layer(hf_layer, pt_layer)

    hf_model = LDMBertModel(config).eval()

    # copy embeds
    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight

    # copy layer norm
    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)

    # copy hidden layers
    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)

    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)

    return hf_model

def convert_ldm_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model
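Editorial note: the function above simply strips the "cond_stage_model.transformer." prefix and loads the result into a stock CLIP text model. A minimal, hedged sketch of how it might be driven (the save path is a placeholder):

# Hedged usage sketch -- the output path is a placeholder.
text_encoder = convert_ldm_clip_checkpoint(checkpoint)
text_encoder.save_pretrained("converted/text_encoder")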

textenc_conversion_lst = [
    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

textenc_transformer_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
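Editorial note: `protected` and `textenc_pattern` implement a one-pass rename of OpenCLIP text-encoder keys to transformers-style keys. The illustration below (not part of the uploaded file) shows the substitution on a single key.

# Illustration of the regex-based renaming defined above.
example_key = "resblocks.0.ln_1.weight"
renamed = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], example_key)
assert renamed == "text_model.encoder.layers.0.layer_norm1.weight"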

def convert_paint_by_example_checkpoint(checkpoint):
    config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
    model = PaintByExampleImageEncoder(config)

    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    # load clip vision
    model.model.load_state_dict(text_model_dict)

    # load mapper
    keys_mapper = {
        k[len("cond_stage_model.mapper.res") :]: v
        for k, v in checkpoint.items()
        if k.startswith("cond_stage_model.mapper")
    }

    MAPPING = {
        "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
        "attn.c_proj": ["attn1.to_out.0"],
        "ln_1": ["norm1"],
        "ln_2": ["norm3"],
        "mlp.c_fc": ["ff.net.0.proj"],
        "mlp.c_proj": ["ff.net.2"],
    }

    mapped_weights = {}
    for key, value in keys_mapper.items():
        prefix = key[: len("blocks.i")]
        suffix = key.split(prefix)[-1].split(".")[-1]
        name = key.split(prefix)[-1].split(suffix)[0][1:-1]
        mapped_names = MAPPING[name]

        num_splits = len(mapped_names)
        for i, mapped_name in enumerate(mapped_names):
            new_name = ".".join([prefix, mapped_name, suffix])
            shape = value.shape[0] // num_splits
            mapped_weights[new_name] = value[i * shape : (i + 1) * shape]

    model.mapper.load_state_dict(mapped_weights)

    # load final layer norm
    model.final_layer_norm.load_state_dict(
        {
            "bias": checkpoint["cond_stage_model.final_ln.bias"],
            "weight": checkpoint["cond_stage_model.final_ln.weight"],
        }
    )

    # load final proj
    model.proj_out.load_state_dict(
        {
            "bias": checkpoint["proj_out.bias"],
            "weight": checkpoint["proj_out.weight"],
        }
    )

    # load uncond vector
    model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
    return model

def convert_open_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

    keys = list(checkpoint.keys())

    text_model_dict = {}

    if "cond_stage_model.model.text_projection" in checkpoint:
        d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
    else:
        d_model = 1024

    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")

    for key in keys:
        if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
            continue
        if key in textenc_conversion_map:
            text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
        if key.startswith("cond_stage_model.model.transformer."):
            new_key = key[len("cond_stage_model.model.transformer.") :]
            if new_key.endswith(".in_proj_weight"):
                new_key = new_key[: -len(".in_proj_weight")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
                text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
                text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
            elif new_key.endswith(".in_proj_bias"):
                new_key = new_key[: -len(".in_proj_bias")]
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
                text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
                text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
                text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
            else:
                new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)

                text_model_dict[new_key] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model
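Editorial note: OpenCLIP packs the query/key/value projections into a single in_proj tensor of shape (3 * d_model, d_model); the loop above slices it into the three separate projections diffusers expects. A toy-sized illustration (not part of the uploaded file):

# Toy illustration of the in_proj split (d_model is 1024 for the SD2 text encoder).
import torch

d_model = 4
in_proj_weight = torch.randn(3 * d_model, d_model)
q_w = in_proj_weight[:d_model, :]
k_w = in_proj_weight[d_model : d_model * 2, :]
v_w = in_proj_weight[d_model * 2 :, :]
assert q_w.shape == k_w.shape == v_w.shape == (d_model, d_model)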

def stable_unclip_image_encoder(original_config):
    """
    Returns the image processor and clip image encoder for the img2img unclip pipeline.

    We currently know of two types of stable unclip models which separately use the clip and the openclip image
    encoders.
    """

    image_embedder_config = original_config.model.params.embedder_config

    sd_clip_image_embedder_class = image_embedder_config.target
    sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]

    if sd_clip_image_embedder_class == "ClipImageEmbedder":
        clip_model_name = image_embedder_config.params.model

        if clip_model_name == "ViT-L/14":
            feature_extractor = CLIPImageProcessor()
            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
        else:
            raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")

    elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
        feature_extractor = CLIPImageProcessor()
        image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
    else:
        raise NotImplementedError(
            f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
        )

    return feature_extractor, image_encoder

def stable_unclip_image_noising_components(
    original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
    """
    Returns the noising components for the img2img and txt2img unclip pipelines.

    Converts the stability noise augmentor into
    1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
    2. a `DDPMScheduler` for holding the noise schedule

    If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
    """
    noise_aug_config = original_config.model.params.noise_aug_config
    noise_aug_class = noise_aug_config.target
    noise_aug_class = noise_aug_class.split(".")[-1]

    if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
        noise_aug_config = noise_aug_config.params
        embedding_dim = noise_aug_config.timestep_dim
        max_noise_level = noise_aug_config.noise_schedule_config.timesteps
        beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule

        image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
        image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)

        if "clip_stats_path" in noise_aug_config:
            if clip_stats_path is None:
                raise ValueError("This stable unclip config requires a `clip_stats_path`")

            clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
            clip_mean = clip_mean[None, :]
            clip_std = clip_std[None, :]

            clip_stats_state_dict = {
                "mean": clip_mean,
                "std": clip_std,
            }

            image_normalizer.load_state_dict(clip_stats_state_dict)
    else:
        raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")

    return image_normalizer, image_noising_scheduler

def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    ctrlnet_config.pop("sample_size")

    controlnet_model = ControlNetModel(**ctrlnet_config)

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )

    controlnet_model.load_state_dict(converted_ctrl_checkpoint)

    return controlnet_model
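Editorial note: `convert_controlnet_checkpoint` reuses the UNet converter with controlnet=True. A hedged sketch of a call; the .ckpt path, image size, flags, and output path are all placeholders.

# Hedged usage sketch -- every argument value here is a placeholder.
controlnet = convert_controlnet_checkpoint(
    checkpoint, original_config, "controlnet.ckpt",
    image_size=512, upcast_attention=False, extract_ema=False,
)
controlnet.save_pretrained("converted/controlnet")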
    	
animatelcm/utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,152 @@
# coding=utf-8
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""" Conversion script for the LoRA's safetensors checkpoints. """

import argparse

import torch
from safetensors.torch import load_file

from diffusers import StableDiffusionPipeline


def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
    # directly update weight in diffusers model
    for key in state_dict:
        # only process lora down key
        if "up." in key: continue

        up_key    = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = pipeline.unet
        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key]
        weight_up   = state_dict[up_key]
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

    return pipeline
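Editorial note: `convert_motion_lora_ckpt_to_diffusers` folds each lora_down/lora_up pair directly into the matching UNet weight. A minimal sketch, with a placeholder checkpoint path and an already-built pipeline assumed:

# Hedged usage sketch -- the safetensors path is a placeholder.
from safetensors.torch import load_file

motion_lora_state_dict = load_file("path/to/motion_lora.safetensors")
pipeline = convert_motion_lora_ckpt_to_diffusers(pipeline, motion_lora_state_dict, alpha=1.0)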

def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
    # load base model
    # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)

    # load LoRA weight from .safetensors
    # state_dict = load_file(checkpoint_path)

    visited = []

    # directly update weight in diffusers model
    for key in state_dict:
        # it helps to print the key; it usually looks like the example below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"

        # alpha is set beforehand, so alpha entries and already-visited keys are skipped
        if ".alpha" in key or key in visited:
            continue

        if "text" in key:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        pair_keys = []
        if "lora_down" in key:
            pair_keys.append(key.replace("lora_down", "lora_up"))
            pair_keys.append(key)
        else:
            pair_keys.append(key)
            pair_keys.append(key.replace("lora_up", "lora_down"))

        # update weight
        if len(state_dict[pair_keys[0]].shape) == 4:
            weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
            weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
        else:
            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

        # update visited list
        for item in pair_keys:
            visited.append(item)

    return pipeline
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            if __name__ == "__main__":
         | 
| 116 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                parser.add_argument(
         | 
| 119 | 
            +
                    "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
         | 
| 120 | 
            +
                )
         | 
| 121 | 
            +
                parser.add_argument(
         | 
| 122 | 
            +
                    "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
         | 
| 123 | 
            +
                )
         | 
| 124 | 
            +
                parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
         | 
| 125 | 
            +
                parser.add_argument(
         | 
| 126 | 
            +
                    "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
         | 
| 127 | 
            +
                )
         | 
| 128 | 
            +
                parser.add_argument(
         | 
| 129 | 
            +
                    "--lora_prefix_text_encoder",
         | 
| 130 | 
            +
                    default="lora_te",
         | 
| 131 | 
            +
                    type=str,
         | 
| 132 | 
            +
                    help="The prefix of text encoder weight in safetensors",
         | 
| 133 | 
            +
                )
         | 
| 134 | 
            +
                parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
         | 
| 135 | 
            +
                parser.add_argument(
         | 
| 136 | 
            +
                    "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
         | 
| 137 | 
            +
                )
         | 
| 138 | 
            +
                parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                args = parser.parse_args()
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                base_model_path = args.base_model_path
         | 
| 143 | 
            +
                checkpoint_path = args.checkpoint_path
         | 
| 144 | 
            +
                dump_path = args.dump_path
         | 
| 145 | 
            +
                lora_prefix_unet = args.lora_prefix_unet
         | 
| 146 | 
            +
                lora_prefix_text_encoder = args.lora_prefix_text_encoder
         | 
| 147 | 
            +
                alpha = args.alpha
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                pipe = pipe.to(args.device)
         | 
| 152 | 
            +
                pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
         | 
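A minimal usage sketch for convert_lora, with placeholder paths (load_weights in animatelcm/utils/util.py further below invokes it the same way):

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora

# Illustrative paths only; substitute your own checkpoints.
pipe = StableDiffusionPipeline.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5", torch_dtype=torch.float32)
lora_state_dict = load_file("path/to/your_lora.safetensors")
# Merge the LoRA deltas into the UNet / text-encoder weights in place.
pipe = convert_lora(pipe, lora_state_dict, alpha=0.6)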
    	
        animatelcm/utils/lcm_utils.py
    ADDED
    
    | @@ -0,0 +1,237 @@ | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import numpy as np
         | 
| 3 | 
            +
            from safetensors import safe_open
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def guidance_scale_embedding(w, embedding_dim=512, dtype=torch.float32):
         | 
| 7 | 
            +
                """
         | 
| 8 | 
            +
                See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                Args:
         | 
| 11 | 
            +
        w (`torch.Tensor`):
         | 
| 12 | 
            +
            guidance scale values at which to generate embedding vectors
         | 
| 13 | 
            +
                    embedding_dim (`int`, *optional*, defaults to 512):
         | 
| 14 | 
            +
                        dimension of the embeddings to generate
         | 
| 15 | 
            +
                    dtype:
         | 
| 16 | 
            +
                        data type of the generated embeddings
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Returns:
         | 
| 19 | 
            +
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
         | 
| 20 | 
            +
                """
         | 
| 21 | 
            +
                assert len(w.shape) == 1
         | 
| 22 | 
            +
                w = w * 1000.0
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                half_dim = embedding_dim // 2
         | 
| 25 | 
            +
                emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
         | 
| 26 | 
            +
                emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
         | 
| 27 | 
            +
                emb = w.to(dtype)[:, None] * emb[None, :]
         | 
| 28 | 
            +
                emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
         | 
| 29 | 
            +
                if embedding_dim % 2 == 1:  # zero pad
         | 
| 30 | 
            +
                    emb = torch.nn.functional.pad(emb, (0, 1))
         | 
| 31 | 
            +
                assert emb.shape == (w.shape[0], embedding_dim)
         | 
| 32 | 
            +
                return emb
         | 
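# Example (illustrative values): each guidance scale in w maps to one sinusoidal
# embedding row of size embedding_dim, used to condition a guidance-distilled UNet.
w = torch.tensor([1.5, 7.5])
w_embedding = guidance_scale_embedding(w, embedding_dim=256)
assert w_embedding.shape == (2, 256)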
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
            def append_dims(x, target_dims):
         | 
| 36 | 
            +
                """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
         | 
| 37 | 
            +
                dims_to_append = target_dims - x.ndim
         | 
| 38 | 
            +
                if dims_to_append < 0:
         | 
| 39 | 
            +
                    raise ValueError(
         | 
| 40 | 
            +
                        f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
         | 
| 41 | 
            +
                return x[(...,) + (None,) * dims_to_append]
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            # From LCMScheduler.get_scalings_for_boundary_condition_discrete
         | 
| 45 | 
            +
            def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0):
         | 
| 46 | 
            +
    c_skip = sigma_data**2 / ((timestep * timestep_scaling) ** 2 + sigma_data**2)
         | 
| 47 | 
            +
    c_out = (timestep * timestep_scaling) / ((timestep * timestep_scaling) ** 2 + sigma_data**2) ** 0.5
         | 
| 48 | 
            +
                return c_skip, c_out
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            # Compare LCMScheduler.step, Step 4
         | 
| 52 | 
            +
            def predicted_origin(model_output, timesteps, sample, prediction_type, alphas, sigmas):
         | 
| 53 | 
            +
                if prediction_type == "epsilon":
         | 
| 54 | 
            +
                    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 55 | 
            +
                    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 56 | 
            +
                    pred_x_0 = (sample - sigmas * model_output) / alphas
         | 
| 57 | 
            +
                elif prediction_type == "v_prediction":
         | 
| 58 | 
            +
                    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 59 | 
            +
                    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 60 | 
            +
                    pred_x_0 = alphas * sample - sigmas * model_output
         | 
| 61 | 
            +
                else:
         | 
| 62 | 
            +
                    raise ValueError(
         | 
| 63 | 
            +
                        f"Prediction type {prediction_type} currently not supported.")
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                return pred_x_0
         | 
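# Example (synthetic tensors, illustrative only): how the helpers above typically combine
# in a consistency-distillation step, f(x_t) = c_skip * x_t + c_out * pred_x_0.
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
alphas = alphas_cumprod.sqrt()           # sqrt(alpha_bar_t), indexed by timestep
sigmas = (1.0 - alphas_cumprod).sqrt()   # sqrt(1 - alpha_bar_t)

timesteps = torch.randint(0, 1000, (4,))
noisy_latents = torch.randn(4, 4, 16, 32, 32)
model_output = torch.randn_like(noisy_latents)

c_skip, c_out = scalings_for_boundary_conditions(timesteps)
c_skip = append_dims(c_skip, noisy_latents.ndim)
c_out = append_dims(c_out, noisy_latents.ndim)
pred_x_0 = predicted_origin(model_output, timesteps, noisy_latents, "epsilon", alphas, sigmas)
consistency_pred = c_skip * noisy_latents + c_out * pred_x_0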
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def scale_for_loss(timesteps, sample, prediction_type, alphas, sigmas):
         | 
| 69 | 
            +
                if prediction_type == "epsilon":
         | 
| 70 | 
            +
                    sigmas = extract_into_tensor(sigmas, timesteps, sample.shape)
         | 
| 71 | 
            +
                    alphas = extract_into_tensor(alphas, timesteps, sample.shape)
         | 
| 72 | 
            +
                    sample = sample * alphas / sigmas
         | 
| 73 | 
            +
                else:
         | 
| 74 | 
            +
                    raise ValueError(
         | 
| 75 | 
            +
                        f"Prediction type {prediction_type} currently not supported.")
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                return sample
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
            def extract_into_tensor(a, t, x_shape):
         | 
| 81 | 
            +
                b, *_ = t.shape
         | 
| 82 | 
            +
                out = a.gather(-1, t)
         | 
| 83 | 
            +
                return out.reshape(b, *((1,) * (len(x_shape) - 1)))
         | 
| 84 | 
            +
             | 
| 85 | 
            +
             | 
| 86 | 
            +
            class DDIMSolver:
         | 
| 87 | 
            +
                def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
         | 
| 88 | 
            +
                    # DDIM sampling parameters
         | 
| 89 | 
            +
                    step_ratio = timesteps // ddim_timesteps
         | 
| 90 | 
            +
                    self.ddim_timesteps = (
         | 
| 91 | 
            +
                        np.arange(1, ddim_timesteps + 1) * step_ratio).round().astype(np.int64) - 1
         | 
| 92 | 
            +
                    # self.ddim_timesteps = (torch.linspace(100**2,1000**2,30)**0.5).round().numpy().astype(np.int64) - 1
         | 
| 93 | 
            +
                    self.ddim_timesteps_prev = np.asarray(
         | 
| 94 | 
            +
                        [0] + self.ddim_timesteps[:-1].tolist()
         | 
| 95 | 
            +
                    )
         | 
| 96 | 
            +
                    self.ddim_alpha_cumprods = alpha_cumprods[self.ddim_timesteps]
         | 
| 97 | 
            +
                    self.ddim_alpha_cumprods_prev = np.asarray(
         | 
| 98 | 
            +
                        [alpha_cumprods[0]] + alpha_cumprods[self.ddim_timesteps[:-1]].tolist()
         | 
| 99 | 
            +
                    )
         | 
| 104 | 
            +
                    # convert to torch tensors
         | 
| 105 | 
            +
                    self.ddim_timesteps = torch.from_numpy(self.ddim_timesteps).long()
         | 
| 106 | 
            +
                    self.ddim_timesteps_prev = torch.from_numpy(
         | 
| 107 | 
            +
                        self.ddim_timesteps_prev).long()
         | 
| 108 | 
            +
                    self.ddim_alpha_cumprods = torch.from_numpy(self.ddim_alpha_cumprods)
         | 
| 109 | 
            +
                    self.ddim_alpha_cumprods_prev = torch.from_numpy(
         | 
| 110 | 
            +
                        self.ddim_alpha_cumprods_prev)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                def to(self, device):
         | 
| 113 | 
            +
                    self.ddim_timesteps = self.ddim_timesteps.to(device)
         | 
| 114 | 
            +
                    self.ddim_timesteps_prev = self.ddim_timesteps_prev.to(device)
         | 
| 115 | 
            +
                    self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
         | 
| 116 | 
            +
                    self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(
         | 
| 117 | 
            +
                        device)
         | 
| 118 | 
            +
                    return self
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def ddim_step(self, pred_x0, pred_noise, timestep_index):
         | 
| 121 | 
            +
                    alpha_cumprod_prev = extract_into_tensor(
         | 
| 122 | 
            +
                        self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
         | 
| 123 | 
            +
                    dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
         | 
| 124 | 
            +
                    x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
         | 
| 125 | 
            +
                    return x_prev
         | 
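# Example (synthetic schedule, illustrative only): build the solver from a cumulative
# alpha schedule and take one DDIM step from predicted x_0 and predicted noise.
alphas_cumprod_np = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0).numpy()
solver = DDIMSolver(alphas_cumprod_np, timesteps=1000, ddim_timesteps=50)

pred_x0 = torch.randn(2, 4, 16, 32, 32)
pred_noise = torch.randn_like(pred_x0)
index = torch.tensor([10, 25])  # positions within solver.ddim_timesteps
x_prev = solver.ddim_step(pred_x0, pred_noise, index)
assert x_prev.shape == pred_x0.shape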
| 126 | 
            +
             | 
| 127 | 
            +
             | 
| 128 | 
            +
            @torch.no_grad()
         | 
| 129 | 
            +
            def update_ema(target_params, source_params, rate=0.99):
         | 
| 130 | 
            +
                """
         | 
| 131 | 
            +
                Update target parameters to be closer to those of source parameters using
         | 
| 132 | 
            +
                an exponential moving average.
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                :param target_params: the target parameter sequence.
         | 
| 135 | 
            +
                :param source_params: the source parameter sequence.
         | 
| 136 | 
            +
                :param rate: the EMA rate (closer to 1 means slower).
         | 
| 137 | 
            +
                """
         | 
| 138 | 
            +
                for targ, src in zip(target_params, source_params):
         | 
| 139 | 
            +
                    targ.detach().mul_(rate).add_(src, alpha=1 - rate)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
             | 
| 142 | 
            +
            def convert_lcm_lora(unet, path, alpha=1.0):
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                if path.endswith(("ckpt",)):
         | 
| 145 | 
            +
                    state_dict = torch.load(path, map_location="cpu")
         | 
| 146 | 
            +
                else:
         | 
| 147 | 
            +
                    state_dict = {}
         | 
| 148 | 
            +
                    with safe_open(path, framework="pt", device="cpu") as f:
         | 
| 149 | 
            +
                        for key in f.keys():
         | 
| 150 | 
            +
                            state_dict[key] = f.get_tensor(key)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                num_alpha = 0
         | 
| 153 | 
            +
                for key in state_dict.keys():
         | 
| 154 | 
            +
                    if "alpha" in key:
         | 
| 155 | 
            +
                        num_alpha += 1
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                lora_keys = [k for k in state_dict.keys(
         | 
| 158 | 
            +
                ) if k.endswith("lora_down.weight")]
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                updated_state_dict = {}
         | 
| 161 | 
            +
                for key in lora_keys:
         | 
| 162 | 
            +
                    lora_name = key.split(".")[0]
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    if lora_name.startswith("lora_unet_"):
         | 
| 165 | 
            +
                        diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                        if "input.blocks" in diffusers_name:
         | 
| 168 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 169 | 
            +
                                "input.blocks", "down_blocks")
         | 
| 170 | 
            +
                        else:
         | 
| 171 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 172 | 
            +
                                "down.blocks", "down_blocks")
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                        if "middle.block" in diffusers_name:
         | 
| 175 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 176 | 
            +
                                "middle.block", "mid_block")
         | 
| 177 | 
            +
                        else:
         | 
| 178 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 179 | 
            +
                                "mid.block", "mid_block")
         | 
| 180 | 
            +
                        if "output.blocks" in diffusers_name:
         | 
| 181 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 182 | 
            +
                                "output.blocks", "up_blocks")
         | 
| 183 | 
            +
                        else:
         | 
| 184 | 
            +
                            diffusers_name = diffusers_name.replace(
         | 
| 185 | 
            +
                                "up.blocks", "up_blocks")
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                        diffusers_name = diffusers_name.replace(
         | 
| 188 | 
            +
                            "transformer.blocks", "transformer_blocks")
         | 
| 189 | 
            +
                        diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
         | 
| 190 | 
            +
                        diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
         | 
| 191 | 
            +
                        diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
         | 
| 192 | 
            +
                        diffusers_name = diffusers_name.replace(
         | 
| 193 | 
            +
                            "to.out.0.lora", "to_out_lora")
         | 
| 194 | 
            +
                        diffusers_name = diffusers_name.replace("proj.in", "proj_in")
         | 
| 195 | 
            +
                        diffusers_name = diffusers_name.replace("proj.out", "proj_out")
         | 
| 196 | 
            +
                        diffusers_name = diffusers_name.replace(
         | 
| 197 | 
            +
                            "time.emb.proj", "time_emb_proj")
         | 
| 198 | 
            +
                        diffusers_name = diffusers_name.replace(
         | 
| 199 | 
            +
                            "conv.shortcut", "conv_shortcut")
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                        updated_state_dict[diffusers_name] = state_dict[key]
         | 
| 202 | 
            +
                        up_diffusers_name = diffusers_name.replace(".down.", ".up.")
         | 
| 203 | 
            +
                        up_key = key.replace("lora_down.weight", "lora_up.weight")
         | 
| 204 | 
            +
                        updated_state_dict[up_diffusers_name] = state_dict[up_key]
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                state_dict = updated_state_dict
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                num_lora = 0
         | 
| 209 | 
            +
                for key in state_dict:
         | 
| 210 | 
            +
                    if "up." in key:
         | 
| 211 | 
            +
                        continue
         | 
| 212 | 
            +
                    up_key = key.replace(".down.", ".up.")
         | 
| 213 | 
            +
                    model_key = key.replace("processor.", "").replace("_lora", "").replace(
         | 
| 214 | 
            +
                        "down.", "").replace("up.", "").replace(".lora", "")
         | 
| 215 | 
            +
                    model_key = model_key.replace("to_out.", "to_out.0.")
         | 
| 216 | 
            +
                    layer_infos = model_key.split(".")[:-1]
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    curr_layer = unet
         | 
| 219 | 
            +
                    while len(layer_infos) > 0:
         | 
| 220 | 
            +
                        temp_name = layer_infos.pop(0)
         | 
| 221 | 
            +
                        curr_layer = curr_layer.__getattr__(temp_name)
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                    weight_down = state_dict[key].to(
         | 
| 224 | 
            +
                        curr_layer.weight.data.device, curr_layer.weight.data.dtype)
         | 
| 225 | 
            +
                    weight_up = state_dict[up_key].to(
         | 
| 226 | 
            +
                        curr_layer.weight.data.device, curr_layer.weight.data.dtype)
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                    if weight_up.ndim == 2:
         | 
| 229 | 
            +
                        curr_layer.weight.data += 1/8 * alpha * \
         | 
| 230 | 
            +
                            torch.mm(weight_up, weight_down)
         | 
| 231 | 
            +
                    else:
         | 
| 232 | 
            +
                        assert weight_up.ndim == 4
         | 
| 233 | 
            +
                        curr_layer.weight.data += 1/8 * alpha * torch.mm(weight_up.flatten(
         | 
| 234 | 
            +
                            start_dim=1), weight_down.flatten(start_dim=1)).reshape(curr_layer.weight.data.shape)
         | 
| 235 | 
            +
                    num_lora += 1
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                return unet
         | 
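A sketch of applying convert_lcm_lora to the project's 3D UNet, loaded the way app.py (further below) loads it; the LCM LoRA path is the checkpoint app.py points at, and the call itself is illustrative rather than a documented entry point:

from omegaconf import OmegaConf
from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.utils.lcm_utils import convert_lcm_lora

inference_config = OmegaConf.load("configs/inference.yaml")
unet = UNet3DConditionModel.from_pretrained_2d(
    "models/StableDiffusion/stable-diffusion-v1-5",
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
)
# Merge the LCM LoRA deltas directly into the UNet weights (scaled by alpha and the
# fixed 1/8 factor used in the function above).
unet = convert_lcm_lora(unet, "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors", alpha=1.0)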
    	
        animatelcm/utils/util.py
    ADDED
    
    | @@ -0,0 +1,153 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import imageio
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from typing import Union
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
            import torchvision
         | 
| 8 | 
            +
            import torch.distributed as dist
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from safetensors import safe_open
         | 
| 11 | 
            +
            from tqdm import tqdm
         | 
| 12 | 
            +
            from einops import rearrange
         | 
| 13 | 
            +
            from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
         | 
| 14 | 
            +
            from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def zero_rank_print(s):
         | 
| 18 | 
            +
    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
         | 
| 19 | 
            +
             | 
| 20 | 
            +
             | 
| 21 | 
            +
            def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
         | 
| 22 | 
            +
                videos = rearrange(videos, "b c t h w -> t b c h w")
         | 
| 23 | 
            +
                outputs = []
         | 
| 24 | 
            +
                for x in videos:
         | 
| 25 | 
            +
                    x = torchvision.utils.make_grid(x, nrow=n_rows)
         | 
| 26 | 
            +
                    x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
         | 
| 27 | 
            +
                    if rescale:
         | 
| 28 | 
            +
                        x = (x + 1.0) / 2.0  # -1,1 -> 0,1
         | 
| 29 | 
            +
                    x = (x * 255).numpy().astype(np.uint8)
         | 
| 30 | 
            +
                    outputs.append(x)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                os.makedirs(os.path.dirname(path), exist_ok=True)
         | 
| 33 | 
            +
                imageio.mimsave(path, outputs, fps=fps)
         | 
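# Example (random tensor, illustrative only): write a 2-sample, 8-frame batch in [0, 1]
# to an animated GIF grid with 2 videos per row.
dummy_videos = torch.rand(2, 3, 8, 256, 256)  # b c t h w
save_videos_grid(dummy_videos, "samples/demo/grid.gif", n_rows=2, fps=8)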
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            # DDIM Inversion
         | 
| 37 | 
            +
            @torch.no_grad()
         | 
| 38 | 
            +
            def init_prompt(prompt, pipeline):
         | 
| 39 | 
            +
                uncond_input = pipeline.tokenizer(
         | 
| 40 | 
            +
                    [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length,
         | 
| 41 | 
            +
                    return_tensors="pt"
         | 
| 42 | 
            +
                )
         | 
| 43 | 
            +
                uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0]
         | 
| 44 | 
            +
                text_input = pipeline.tokenizer(
         | 
| 45 | 
            +
                    [prompt],
         | 
| 46 | 
            +
                    padding="max_length",
         | 
| 47 | 
            +
                    max_length=pipeline.tokenizer.model_max_length,
         | 
| 48 | 
            +
                    truncation=True,
         | 
| 49 | 
            +
                    return_tensors="pt",
         | 
| 50 | 
            +
                )
         | 
| 51 | 
            +
                text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0]
         | 
| 52 | 
            +
                context = torch.cat([uncond_embeddings, text_embeddings])
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                return context
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
         | 
| 58 | 
            +
                          sample: Union[torch.FloatTensor, np.ndarray], ddim_scheduler):
         | 
| 59 | 
            +
                timestep, next_timestep = min(
         | 
| 60 | 
            +
                    timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep
         | 
| 61 | 
            +
                alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod
         | 
| 62 | 
            +
                alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep]
         | 
| 63 | 
            +
                beta_prod_t = 1 - alpha_prod_t
         | 
| 64 | 
            +
                next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
         | 
| 65 | 
            +
                next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
         | 
| 66 | 
            +
                next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
         | 
| 67 | 
            +
                return next_sample
         | 
| 68 | 
            +
             | 
| 69 | 
            +
             | 
| 70 | 
            +
            def get_noise_pred_single(latents, t, context, unet):
         | 
| 71 | 
            +
                noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"]
         | 
| 72 | 
            +
                return noise_pred
         | 
| 73 | 
            +
             | 
| 74 | 
            +
             | 
| 75 | 
            +
            @torch.no_grad()
         | 
| 76 | 
            +
            def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
         | 
| 77 | 
            +
                context = init_prompt(prompt, pipeline)
         | 
| 78 | 
            +
                uncond_embeddings, cond_embeddings = context.chunk(2)
         | 
| 79 | 
            +
                all_latent = [latent]
         | 
| 80 | 
            +
                latent = latent.clone().detach()
         | 
| 81 | 
            +
                for i in tqdm(range(num_inv_steps)):
         | 
| 82 | 
            +
                    t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1]
         | 
| 83 | 
            +
                    noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet)
         | 
| 84 | 
            +
                    latent = next_step(noise_pred, t, latent, ddim_scheduler)
         | 
| 85 | 
            +
                    all_latent.append(latent)
         | 
| 86 | 
            +
                return all_latent
         | 
| 87 | 
            +
             | 
| 88 | 
            +
             | 
| 89 | 
            +
            @torch.no_grad()
         | 
| 90 | 
            +
            def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
         | 
| 91 | 
            +
                ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
         | 
| 92 | 
            +
                return ddim_latents
         | 
| 93 | 
            +
             | 
| 94 | 
            +
            def load_weights(
         | 
| 95 | 
            +
                animation_pipeline,
         | 
| 96 | 
            +
                motion_module_path         = "",
         | 
| 97 | 
            +
                motion_module_lora_configs = [],
         | 
| 98 | 
            +
                dreambooth_model_path = "",
         | 
| 99 | 
            +
                lora_model_path       = "",
         | 
| 100 | 
            +
                lora_alpha            = 0.8,
         | 
| 101 | 
            +
            ):
         | 
| 102 | 
            +
                unet_state_dict = {}
         | 
| 103 | 
            +
                if motion_module_path != "":
         | 
| 104 | 
            +
                    print(f"load motion module from {motion_module_path}")
         | 
| 105 | 
            +
                    motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
         | 
| 106 | 
            +
                    motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
         | 
| 107 | 
            +
                    unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
         | 
| 108 | 
            +
                
         | 
| 109 | 
            +
                missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
         | 
| 110 | 
            +
                assert len(unexpected) == 0
         | 
| 111 | 
            +
                del unet_state_dict
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                if dreambooth_model_path != "":
         | 
| 114 | 
            +
                    print(f"load dreambooth model from {dreambooth_model_path}")
         | 
| 115 | 
            +
                    if dreambooth_model_path.endswith(".safetensors"):
         | 
| 116 | 
            +
                        dreambooth_state_dict = {}
         | 
| 117 | 
            +
                        with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
         | 
| 118 | 
            +
                            for key in f.keys():
         | 
| 119 | 
            +
                                dreambooth_state_dict[key] = f.get_tensor(key)
         | 
| 120 | 
            +
                    elif dreambooth_model_path.endswith(".ckpt"):
         | 
| 121 | 
            +
                        dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
         | 
| 122 | 
            +
                        
         | 
| 123 | 
            +
                    converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
         | 
| 124 | 
            +
                    animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
         | 
| 127 | 
            +
                    animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
         | 
| 130 | 
            +
                    del dreambooth_state_dict
         | 
| 131 | 
            +
                    
         | 
| 132 | 
            +
                if lora_model_path != "":
         | 
| 133 | 
            +
                    print(f"load lora model from {lora_model_path}")
         | 
| 134 | 
            +
                    assert lora_model_path.endswith(".safetensors")
         | 
| 135 | 
            +
                    lora_state_dict = {}
         | 
| 136 | 
            +
                    with safe_open(lora_model_path, framework="pt", device="cpu") as f:
         | 
| 137 | 
            +
                        for key in f.keys():
         | 
| 138 | 
            +
                            lora_state_dict[key] = f.get_tensor(key)
         | 
| 139 | 
            +
                            
         | 
| 140 | 
            +
                    animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
         | 
| 141 | 
            +
                    del lora_state_dict
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
                for motion_module_lora_config in motion_module_lora_configs:
         | 
| 145 | 
            +
                    path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
         | 
| 146 | 
            +
                    print(f"load motion LoRA from {path}")
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                    motion_lora_state_dict = torch.load(path, map_location="cpu")
         | 
| 149 | 
            +
                    motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                    animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                return animation_pipeline
         | 
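A hedged sketch of driving load_weights end to end. The pipeline assembly mirrors the calls visible in app.py below; the motion-module and DreamBooth filenames are placeholders, and the scheduler is built with its defaults for brevity:

from omegaconf import OmegaConf
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer

from animatelcm.models.unet import UNet3DConditionModel
from animatelcm.pipelines.pipeline_animation import AnimationPipeline
from animatelcm.scheduler.lcm_scheduler import LCMScheduler
from animatelcm.utils.util import load_weights

sd_path = "models/StableDiffusion/stable-diffusion-v1-5"
config = OmegaConf.load("configs/inference.yaml")

tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet3DConditionModel.from_pretrained_2d(
    sd_path, subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))

pipeline = AnimationPipeline(
    vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
    scheduler=LCMScheduler())  # scheduler kwargs omitted for brevity in this sketch

pipeline = load_weights(
    pipeline,
    motion_module_path="models/Motion_Module/your_motion_module.ckpt",           # placeholder
    dreambooth_model_path="models/DreamBooth_LoRA/your_dreambooth.safetensors",  # placeholder
    lora_model_path="",
    lora_alpha=0.8,
)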
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,392 @@ | |
| 1 | 
            +
             | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import json
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import random
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import gradio as gr
         | 
| 8 | 
            +
            from glob import glob
         | 
| 9 | 
            +
            from omegaconf import OmegaConf
         | 
| 10 | 
            +
            from datetime import datetime
         | 
| 11 | 
            +
            from safetensors import safe_open
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from diffusers import AutoencoderKL
         | 
| 14 | 
            +
            from diffusers.utils.import_utils import is_xformers_available
         | 
| 15 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            from animatelcm.scheduler.lcm_scheduler import LCMScheduler
         | 
| 18 | 
            +
            from animatelcm.models.unet import UNet3DConditionModel
         | 
| 19 | 
            +
            from animatelcm.pipelines.pipeline_animation import AnimationPipeline
         | 
| 20 | 
            +
            from animatelcm.utils.util import save_videos_grid
         | 
| 21 | 
            +
            from animatelcm.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
         | 
| 22 | 
            +
            from animatelcm.utils.convert_lora_safetensor_to_diffusers import convert_lora
         | 
| 23 | 
            +
            from animatelcm.utils.lcm_utils import convert_lcm_lora
         | 
| 24 | 
            +
            import copy
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            sample_idx = 0
         | 
| 27 | 
            +
            scheduler_dict = {
         | 
| 28 | 
            +
                "LCM": LCMScheduler,
         | 
| 29 | 
            +
            }
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            css = """
         | 
| 32 | 
            +
            .toolbutton {
         | 
| 33 | 
            +
    margin-bottom: 0em;
         | 
| 34 | 
            +
                max-width: 2.5em;
         | 
| 35 | 
            +
                min-width: 2.5em !important;
         | 
| 36 | 
            +
                height: 2.5em;
         | 
| 37 | 
            +
            }
         | 
| 38 | 
            +
            """
         | 
| 39 | 
            +
             | 
| 40 | 
            +
             | 
| 41 | 
            +
            class AnimateController:
         | 
| 42 | 
            +
                def __init__(self):
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    # config dirs
         | 
| 45 | 
            +
                    self.basedir = os.getcwd()
         | 
| 46 | 
            +
                    self.stable_diffusion_dir = os.path.join(
         | 
| 47 | 
            +
                        self.basedir, "models", "StableDiffusion")
         | 
| 48 | 
            +
                    self.motion_module_dir = os.path.join(
         | 
| 49 | 
            +
                        self.basedir, "models", "Motion_Module")
         | 
| 50 | 
            +
                    self.personalized_model_dir = os.path.join(
         | 
| 51 | 
            +
                        self.basedir, "models", "DreamBooth_LoRA")
         | 
| 52 | 
            +
                    self.savedir = os.path.join(
         | 
| 53 | 
            +
                        self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
         | 
| 54 | 
            +
                    self.savedir_sample = os.path.join(self.savedir, "sample")
         | 
| 55 | 
            +
                    self.lcm_lora_path = "models/LCM_LoRA/sd15_t2v_beta_lora.safetensors"
         | 
| 56 | 
            +
                    os.makedirs(self.savedir, exist_ok=True)
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    self.stable_diffusion_list = []
         | 
| 59 | 
            +
                    self.motion_module_list = []
         | 
| 60 | 
            +
                    self.personalized_model_list = []
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    self.refresh_stable_diffusion()
         | 
| 63 | 
            +
                    self.refresh_motion_module()
         | 
| 64 | 
            +
                    self.refresh_personalized_model()
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    # config models
         | 
| 67 | 
            +
                    self.tokenizer = None
         | 
| 68 | 
            +
                    self.text_encoder = None
         | 
| 69 | 
            +
                    self.vae = None
         | 
| 70 | 
            +
                    self.unet = None
         | 
| 71 | 
            +
                    self.pipeline = None
         | 
| 72 | 
            +
                    self.lora_model_state_dict = {}
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                    self.inference_config = OmegaConf.load("configs/inference.yaml")
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def refresh_stable_diffusion(self):
         | 
| 77 | 
            +
                    self.stable_diffusion_list = glob(
         | 
| 78 | 
            +
                        os.path.join(self.stable_diffusion_dir, "*/"))
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                def refresh_motion_module(self):
         | 
| 81 | 
            +
                    motion_module_list = glob(os.path.join(
         | 
| 82 | 
            +
                        self.motion_module_dir, "*.ckpt"))
         | 
| 83 | 
            +
                    self.motion_module_list = [
         | 
| 84 | 
            +
                        os.path.basename(p) for p in motion_module_list]
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def refresh_personalized_model(self):
         | 
| 87 | 
            +
                    personalized_model_list = glob(os.path.join(
         | 
| 88 | 
            +
                        self.personalized_model_dir, "*.safetensors"))
         | 
| 89 | 
            +
                    self.personalized_model_list = [
         | 
| 90 | 
            +
                        os.path.basename(p) for p in personalized_model_list]
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                def update_stable_diffusion(self, stable_diffusion_dropdown):
         | 
| 93 | 
            +
                    stable_diffusion_dropdown = os.path.join(self.stable_diffusion_dir,stable_diffusion_dropdown)
         | 
| 94 | 
            +
                    self.tokenizer = CLIPTokenizer.from_pretrained(
         | 
| 95 | 
            +
                        stable_diffusion_dropdown, subfolder="tokenizer")
         | 
| 96 | 
            +
                    self.text_encoder = CLIPTextModel.from_pretrained(
         | 
| 97 | 
            +
                        stable_diffusion_dropdown, subfolder="text_encoder").cuda()
         | 
| 98 | 
            +
                    self.vae = AutoencoderKL.from_pretrained(
         | 
| 99 | 
            +
                        stable_diffusion_dropdown, subfolder="vae").cuda()
         | 
| 100 | 
            +
                    self.unet = UNet3DConditionModel.from_pretrained_2d(
         | 
| 101 | 
            +
            stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(
                self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(
                motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(
                motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(
                self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(
                base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(
                base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            # self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        lora_model_dropdown = os.path.join(
            self.personalized_model_dir, lora_model_dropdown)
        self.lora_model_state_dict = {}
        if lora_model_dropdown == "none":
            pass
        else:
            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        spatial_lora_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        if self.unet is None:
            raise gr.Error(f"Please select a pretrained model path.")
        if motion_module_dropdown == "":
            raise gr.Error(f"Please select a motion module.")
        if base_model_dropdown == "":
            raise gr.Error(f"Please select a base DreamBooth model.")

        if is_xformers_available():
            self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](
                **OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        if self.lora_model_state_dict != {}:
            pipeline = convert_lora(
                pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)

        pipeline.unet = convert_lcm_lora(copy.deepcopy(
            self.unet), self.lcm_lora_path, spatial_lora_slider)

        pipeline.to("cuda")

        if seed_textbox != -1 and seed_textbox != "":
            torch.manual_seed(int(seed_textbox))
        else:
            torch.seed()
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt=negative_prompt_textbox,
            num_inference_steps=sample_step_slider,
            guidance_scale=cfg_scale_slider,
            width=width_slider,
            height=height_slider,
            video_length=length_slider,
        ).videos

        save_sample_path = os.path.join(
            self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")
        return gr.Video.update(value=save_sample_path)


controller = AnimateController()


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateLCM: Accelerating the Animation of Personalized Diffusion Models and Adapters with Decoupled Consistency Learning](https://arxiv.org/abs/2402.00769)
            Fu-Yun Wang, Zhaoyang Huang (*Corresponding Author), Xiaoyu Shi, Weikang Bian, Guanglu Song, Yu Liu, Hongsheng Li (*Corresponding Author)<br>
            [arXiv Report](https://arxiv.org/abs/2402.00769) | [Project Page](https://animatelcm.github.io/) | [Github](https://github.com/G-U-N/AnimateLCM)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[
                                                 stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_stable_diffusion():
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
                stable_diffusion_refresh_button.click(
                    fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[
                                              motion_module_dropdown], outputs=[motion_module_dropdown])

                motion_module_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_motion_module():
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)
                motion_module_refresh_button.click(
                    fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (required)",
                    choices=controller.personalized_model_list,
                    interactive=True,
                )
                base_model_dropdown.change(fn=controller.update_base_model, inputs=[
                                           base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select LoRA model (optional)",
                    choices=["none"],
                    value="none",
                    interactive=True,
                )
                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[
                                           lora_model_dropdown], outputs=[lora_model_dropdown])

                lora_alpha_slider = gr.Slider(
                    label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                spatial_lora_slider = gr.Slider(
                    label="LCM LoRA alpha", value=0.8, minimum=0.0, maximum=1.0, interactive=True)

                personalized_refresh_button = gr.Button(
                    value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    controller.refresh_personalized_model()
                    return [
                        gr.Dropdown.update(
                            choices=controller.personalized_model_list),
                        gr.Dropdown.update(
                            choices=["none"] + controller.personalized_model_list)
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[
                                                  base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateLCM.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(
                label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(
                            scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(
                            label="Sampling steps", value=4, minimum=1, maximum=25, step=1)

                    width_slider = gr.Slider(
                        label="Width",            value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(
                        label="Height",           value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(
                        label="Animation length", value=16,  minimum=12,   maximum=20,   step=1)
                    cfg_scale_slider = gr.Slider(
                        label="CFG Scale",        value=1, minimum=1,   maximum=2)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(
                            value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(fn=lambda: gr.Textbox.update(
                            value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(
                        value="Generate", variant='primary')

                result_video = gr.Video(
                    label="Generated Animation", interactive=False)

            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    spatial_lora_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo


if __name__ == "__main__":
    demo = ui()
    # gr.close_all()
    demo.queue(concurrency_count=3, max_size=20)
    demo.launch(share=True, server_name="127.0.0.1")
    	
        models/.DS_Store
    ADDED
    
    Binary file (6.15 kB)
    	
        models/DreamBooth_LoRA/cartoon2d.safetensors
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:cbfba64e662370f59d4aa2aa69bf16749fce93846ccce20506aee5df01169859
    size 4244124028
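
    (Each of the checkpoint entries in this upload is a Git LFS pointer stub like the three lines above: the repository tracks only this small text file, while the actual multi-gigabyte weights are stored in LFS and retrieved by their SHA-256 object id and byte size.)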
    	
        models/DreamBooth_LoRA/cartoon3d.safetensors
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:a6b4c0392d7486bfa4fd1a31c7b7d2679f743f8ea8d9f219c82b5c33db31ddb9
    size 2132625644
    	
        models/DreamBooth_LoRA/realistic1.safetensors
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:c0d1994c73d784a17a5b335ae8bda02dcc8dd2fc5f5dbf55169d5aab385e53f2
    size 2132650523
    	
        models/DreamBooth_LoRA/realistic2.safetensors
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:a38fa861a24f4f4c6e0f68289101e645dd9ca1e93e1049cc8a4f2a77513fad52
    size 2400040290
    	
        models/LCM_LoRA/Put LCMLoRA checkpoints here.txt
    ADDED
    
    File without changes
    	
        models/LCM_LoRA/sd15_t2v_beta_lora.safetensors
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:8f90d840e075ff588a58e22c6586e2ae9a6f7922996ee6649a7f01072333afe4
    size 134621556
    	
        models/Motion_Module/Put motion module checkpoints here.txt
    ADDED
    
    File without changes
    	
        models/Motion_Module/sd15_t2v_beta_motion.ckpt
    ADDED
    
    version https://git-lfs.github.com/spec/v1
    oid sha256:b46c3de62e5696af72c4056e3cdcbea12fbc19581c0aad7b6f2b027851148f5f
    size 1813041929
    	
        models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt
    ADDED
    
    File without changes
    	
        models/StableDiffusion/stable-diffusion-v1-5/.gitattributes
    ADDED
    
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
v1-5-pruned-emaonly.ckpt filter=lfs diff=lfs merge=lfs -text
v1-5-pruned.ckpt filter=lfs diff=lfs merge=lfs -text
    	
        models/StableDiffusion/stable-diffusion-v1-5/README.md
    ADDED
    
---
license: creativeml-openrail-m
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
inference: true
extra_gated_prompt: |-
  This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.
  The CreativeML OpenRAIL License specifies:

  1. You can't use the model to deliberately produce nor share illegal or harmful outputs or content
  2. CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license
  3. You may re-distribute the weights and use the model commercially and/or as a service. If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)
  Please read the full license carefully here: https://huggingface.co/spaces/CompVis/stable-diffusion-license

extra_gated_heading: Please read the LICENSE to access this model
---

# Stable Diffusion v1-5 Model Card

Stable Diffusion is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input.
For more information about how Stable Diffusion functions, please have a look at [🤗's Stable Diffusion blog](https://huggingface.co/blog/stable_diffusion).

The **Stable-Diffusion-v1-5** checkpoint was initialized with the weights of the [Stable-Diffusion-v1-2](https://huggingface.co/CompVis/stable-diffusion-v1-2)
checkpoint and subsequently fine-tuned on 595k steps at resolution 512x512 on "laion-aesthetics v2 5+" and 10% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).

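For reference, dropping the text conditioning during training is what gives the model an unconditional prediction; at sampling time, classifier-free guidance combines the two predictions as `noise_pred = noise_uncond + s * (noise_text - noise_uncond)`, where `s` is the guidance scale (this is the standard formulation from the linked paper, restated here for convenience rather than taken from this model card).
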
            +
            You can use this both with the [🧨Diffusers library](https://github.com/huggingface/diffusers) and the [RunwayML GitHub repository](https://github.com/runwayml/stable-diffusion).
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            ### Diffusers
         | 
| 31 | 
            +
            ```py
         | 
| 32 | 
            +
            from diffusers import StableDiffusionPipeline
         | 
| 33 | 
            +
            import torch
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            model_id = "runwayml/stable-diffusion-v1-5"
         | 
| 36 | 
            +
            pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
         | 
| 37 | 
            +
            pipe = pipe.to("cuda")
         | 
| 38 | 
            +
             | 
| 39 | 
            +
            prompt = "a photo of an astronaut riding a horse on mars"
         | 
| 40 | 
            +
            image = pipe(prompt).images[0]  
         | 
| 41 | 
            +
                
         | 
| 42 | 
            +
            image.save("astronaut_rides_horse.png")
         | 
| 43 | 
            +
            ```
         | 
| 44 | 
            +
            For more detailed instructions, use-cases and examples in JAX follow the instructions [here](https://github.com/huggingface/diffusers#text-to-image-generation-with-stable-diffusion)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
            ### Original GitHub Repository
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            1. Download the weights 
         | 
| 49 | 
            +
               - [v1-5-pruned-emaonly.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt) - 4.27GB, ema-only weight. uses less VRAM - suitable for inference
         | 
| 50 | 
            +
               - [v1-5-pruned.ckpt](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt) - 7.7GB, ema+non-ema weights. uses more VRAM - suitable for fine-tuning
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            2. Follow instructions [here](https://github.com/runwayml/stable-diffusion).
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            ## Model Details
         | 
| 55 | 
            +
            - **Developed by:** Robin Rombach, Patrick Esser
         | 
| 56 | 
            +
            - **Model type:** Diffusion-based text-to-image generation model
         | 
| 57 | 
            +
            - **Language(s):** English
         | 
| 58 | 
            +
            - **License:** [The CreativeML OpenRAIL M license](https://huggingface.co/spaces/CompVis/stable-diffusion-license) is an [Open RAIL M license](https://www.licenses.ai/blog/2022/8/18/naming-convention-of-responsible-ai-licenses), adapted from the work that [BigScience](https://bigscience.huggingface.co/) and [the RAIL Initiative](https://www.licenses.ai/) are jointly carrying in the area of responsible AI licensing. See also [the article about the BLOOM Open RAIL license](https://bigscience.huggingface.co/blog/the-bigscience-rail-license) on which our license is based.
         | 
| 59 | 
            +
            - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
         | 
| 60 | 
            +
            - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
         | 
| 61 | 
            +
            - **Cite as:**
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                  @InProceedings{Rombach_2022_CVPR,
         | 
| 64 | 
            +
                      author    = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
         | 
| 65 | 
            +
                      title     = {High-Resolution Image Synthesis With Latent Diffusion Models},
         | 
| 66 | 
            +
                      booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
         | 
| 67 | 
            +
                      month     = {June},
         | 
| 68 | 
            +
                      year      = {2022},
         | 
| 69 | 
            +
                      pages     = {10684-10695}
         | 
| 70 | 
            +
                  }
         | 
| 71 | 
            +
             | 
| 72 | 
            +
            # Uses
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            ## Direct Use 
         | 
| 75 | 
            +
            The model is intended for research purposes only. Possible research areas and
         | 
| 76 | 
            +
            tasks include
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            - Safe deployment of models which have the potential to generate harmful content.
         | 
| 79 | 
            +
            - Probing and understanding the limitations and biases of generative models.
         | 
| 80 | 
            +
            - Generation of artworks and use in design and other artistic processes.
         | 
| 81 | 
            +
            - Applications in educational or creative tools.
         | 
| 82 | 
            +
            - Research on generative models.
         | 
| 83 | 
            +
             | 
| 84 | 
            +
            Excluded uses are described below.
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             ### Misuse, Malicious Use, and Out-of-Scope Use
         | 
| 87 | 
            +
            _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            #### Out-of-Scope Use
         | 
| 93 | 
            +
            The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model.
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            #### Misuse and Malicious Use
         | 
| 96 | 
            +
            Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
         | 
| 97 | 
            +
             | 
| 98 | 
            +
            - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
         | 
| 99 | 
            +
            - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
         | 
| 100 | 
            +
            - Impersonating individuals without their consent.
         | 
| 101 | 
            +
            - Sexual content without consent of the people who might see it.
         | 
| 102 | 
            +
            - Mis- and disinformation
         | 
| 103 | 
            +
            - Representations of egregious violence and gore
         | 
| 104 | 
            +
            - Sharing of copyrighted or licensed material in violation of its terms of use.
         | 
| 105 | 
            +
            - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            ## Limitations and Bias
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            ### Limitations
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            - The model does not achieve perfect photorealism
         | 
| 112 | 
            +
            - The model cannot render legible text
         | 
| 113 | 
            +
            - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
         | 
| 114 | 
            +
            - Faces and people in general may not be generated properly.
         | 
| 115 | 
            +
            - The model was trained mainly with English captions and will not work as well in other languages.
         | 
| 116 | 
            +
            - The autoencoding part of the model is lossy
         | 
| 117 | 
            +
            - The model was trained on a large-scale dataset
         | 
| 118 | 
            +
              [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
         | 
| 119 | 
            +
              and is not fit for product use without additional safety mechanisms and
         | 
| 120 | 
            +
              considerations.
         | 
| 121 | 
            +
            - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
         | 
| 122 | 
            +
              The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
         | 
| 123 | 
            +
             | 
| 124 | 
            +
            ### Bias
         | 
| 125 | 
            +
             | 
| 126 | 
            +
            While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 
         | 
| 127 | 
            +
            Stable Diffusion v1 was trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), 
         | 
| 128 | 
            +
            which consists of images that are primarily limited to English descriptions. 
         | 
| 129 | 
            +
            Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. 
         | 
| 130 | 
            +
            This affects the overall output of the model, as white and western cultures are often set as the default. Further, the 
         | 
| 131 | 
            +
            ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
         | 
| 132 | 
            +
             | 
| 133 | 
            +
            ### Safety Module
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers. 
         | 
| 136 | 
            +
            This checker works by checking model outputs against known hard-coded NSFW concepts.
         | 
| 137 | 
            +
            The concepts are intentionally hidden to reduce the likelihood of reverse-engineering this filter.
         | 
| 138 | 
            +
            Specifically, the checker compares the class probability of harmful concepts in the embedding space of the `CLIPTextModel` *after generation* of the images. 
         | 
| 139 | 
            +
            The concepts are passed into the model with the generated image and compared to a hand-engineered weight for each NSFW concept.
         | 
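
A minimal sketch of this kind of concept check (illustrative only; the function and tensor names below are hypothetical, and this is not the actual `StableDiffusionSafetyChecker` implementation):

```py
import torch

def flags_image(image_embedding, concept_embeddings, concept_thresholds):
    # Cosine similarity between the generated image's CLIP embedding and each
    # pre-computed NSFW concept embedding; flag the image if any similarity
    # exceeds its hand-tuned threshold. All inputs here are placeholders.
    img = image_embedding / image_embedding.norm()
    concepts = concept_embeddings / concept_embeddings.norm(dim=-1, keepdim=True)
    similarity = concepts @ img
    return bool((similarity > concept_thresholds).any())
```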


## Training

**Training Data**
The model developers used the following dataset for training the model:

- LAION-2B (en) and subsets thereof (see next section)

**Training Procedure**
Stable Diffusion v1-5 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,

- Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4
- Text prompts are encoded through a ViT-L/14 text-encoder.
- The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
- The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet (see the sketch below).

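A minimal sketch of one training step following the procedure above (illustrative only, written against diffusers-style components; the function signature, variable names, and the 0.18215 latent scaling are assumptions of this sketch, not taken from the actual training code):

```py
import torch
import torch.nn.functional as F

def training_step(vae, text_encoder, unet, noise_scheduler, images, token_ids):
    # Encode images to latents: with downsampling factor f=8, a 512x512x3
    # image becomes a 64x64x4 latent.
    latents = vae.encode(images).latent_dist.sample() * 0.18215
    # Sample noise and a random timestep, then noise the latents.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],), device=latents.device)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    # Non-pooled CLIP text embeddings condition the UNet via cross-attention.
    text_emb = text_encoder(token_ids)[0]
    # The UNet predicts the added noise; the loss is the reconstruction (MSE)
    # objective between that prediction and the true noise.
    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_emb).sample
    return F.mse_loss(noise_pred, noise)
```
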
Currently six Stable Diffusion checkpoints are provided, which were trained as follows.
- [`stable-diffusion-v1-1`](https://huggingface.co/CompVis/stable-diffusion-v1-1): 237,000 steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
  194,000 steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
- [`stable-diffusion-v1-2`](https://huggingface.co/CompVis/stable-diffusion-v1-2): Resumed from `stable-diffusion-v1-1`.
  515,000 steps at resolution `512x512` on "laion-improved-aesthetics" (a subset of laion2B-en,
filtered to images with an original size `>= 512x512`, estimated aesthetics score `> 5.0`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the LAION-5B metadata, the aesthetics score is estimated using an [improved aesthetics estimator](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
- [`stable-diffusion-v1-3`](https://huggingface.co/CompVis/stable-diffusion-v1-3): Resumed from `stable-diffusion-v1-2` - 195,000 steps at resolution `512x512` on "laion-improved-aesthetics" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
- [`stable-diffusion-v1-4`](https://huggingface.co/CompVis/stable-diffusion-v1-4) Resumed from `stable-diffusion-v1-2` - 225,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
- [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) Resumed from `stable-diffusion-v1-2` - 595,000 steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10 % dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
- [`stable-diffusion-inpainting`](https://huggingface.co/runwayml/stable-diffusion-inpainting) Resumed from `stable-diffusion-v1-5` - then 440,000 steps of inpainting training at resolution 512x512 on “laion-aesthetics v2 5+” and 10% dropping of the text-conditioning. For inpainting, the UNet has 5 additional input channels (4 for the encoded masked-image and 1 for the mask itself) whose weights were zero-initialized after restoring the non-inpainting checkpoint. During training, we generate synthetic masks and in 25% mask everything.

- **Hardware:** 32 x 8 x A100 GPUs
- **Optimizer:** AdamW
- **Gradient Accumulations**: 2
- **Batch:** 32 x 8 x 2 x 4 = 2048
- **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant

## Evaluation Results
Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0) and 50 PNDM/PLMS sampling
steps show the relative improvements of the checkpoints:

*(figure: evaluation plot comparing the checkpoints across guidance scales; image not reproduced here)*

Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution.  Not optimized for FID scores.
## Environmental Impact

**Stable Diffusion v1** **Estimated Emissions**
Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact.

- **Hardware Type:** A100 PCIe 40GB
- **Hours used:** 150000
- **Cloud Provider:** AWS
- **Compute Region:** US-east
- **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.


## Citation

```bibtex
    @InProceedings{Rombach_2022_CVPR,
        author    = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
        title     = {High-Resolution Image Synthesis With Latent Diffusion Models},
        booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
        month     = {June},
        year      = {2022},
        pages     = {10684-10695}
    }
```

*This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
    	
        models/StableDiffusion/stable-diffusion-v1-5/feature_extractor/preprocessor_config.json
    ADDED
    
{
  "crop_size": 224,
  "do_center_crop": true,
  "do_convert_rgb": true,
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "CLIPFeatureExtractor",
  "image_mean": [
    0.48145466,
    0.4578275,
    0.40821073
  ],
  "image_std": [
    0.26862954,
    0.26130258,
    0.27577711
  ],
  "resample": 3,
  "size": 224
}
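
(For context: these are the standard CLIP image-normalization statistics, `resample: 3` corresponds to bicubic resampling in PIL, and the 224x224 resize/center-crop matches the CLIP vision encoder used by the safety checker's feature extractor.)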
    	
models/StableDiffusion/stable-diffusion-v1-5/model_index.json
ADDED
@@ -0,0 +1,32 @@
+{
+  "_class_name": "StableDiffusionPipeline",
+  "_diffusers_version": "0.6.0",
+  "feature_extractor": [
+    "transformers",
+    "CLIPImageProcessor"
+  ],
+  "safety_checker": [
+    "stable_diffusion",
+    "StableDiffusionSafetyChecker"
+  ],
+  "scheduler": [
+    "diffusers",
+    "PNDMScheduler"
+  ],
+  "text_encoder": [
+    "transformers",
+    "CLIPTextModel"
+  ],
+  "tokenizer": [
+    "transformers",
+    "CLIPTokenizer"
+  ],
+  "unet": [
+    "diffusers",
+    "UNet2DConditionModel"
+  ],
+  "vae": [
+    "diffusers",
+    "AutoencoderKL"
+  ]
+}
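`model_index.json` is what `diffusers` reads first: it maps each pipeline component to the library (`transformers`, `diffusers`, or the `stable_diffusion` pipeline module) and class to instantiate from the subfolder of the same name. A minimal sketch of assembling the pipeline from this local copy; paths and prompt are illustrative, only the safety checker's config is bundled in this upload (so it is disabled here), and the text-encoder weights are shipped as `model.safetensors`, so the installed `transformers`/`safetensors` combination must be able to read that format:

```python
import torch
from diffusers import StableDiffusionPipeline

# from_pretrained walks model_index.json and loads unet/, vae/, text_encoder/,
# tokenizer/, scheduler/ and feature_extractor/ from the matching subfolders.
pipe = StableDiffusionPipeline.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5",
    safety_checker=None,      # only the checker's config.json is included in this upload
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("sample.png")
```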
    	
models/StableDiffusion/stable-diffusion-v1-5/safety_checker/config.json
ADDED
@@ -0,0 +1,175 @@
+{
+  "_commit_hash": "4bb648a606ef040e7685bde262611766a5fdd67b",
+  "_name_or_path": "CompVis/stable-diffusion-safety-checker",
+  "architectures": [
+    "StableDiffusionSafetyChecker"
+  ],
+  "initializer_factor": 1.0,
+  "logit_scale_init_value": 2.6592,
+  "model_type": "clip",
+  "projection_dim": 768,
+  "text_config": {
+    "_name_or_path": "",
+    "add_cross_attention": false,
+    "architectures": null,
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "bos_token_id": 0,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.0,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": 2,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 768,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 3072,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "max_position_embeddings": 77,
+    "min_length": 0,
+    "model_type": "clip_text_model",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 12,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_hidden_layers": 12,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": 1,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.22.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false,
+    "vocab_size": 49408
+  },
+  "text_config_dict": {
+    "hidden_size": 768,
+    "intermediate_size": 3072,
+    "num_attention_heads": 12,
+    "num_hidden_layers": 12
+  },
+  "torch_dtype": "float32",
+  "transformers_version": null,
+  "vision_config": {
+    "_name_or_path": "",
+    "add_cross_attention": false,
+    "architectures": null,
+    "attention_dropout": 0.0,
+    "bad_words_ids": null,
+    "bos_token_id": null,
+    "chunk_size_feed_forward": 0,
+    "cross_attention_hidden_size": null,
+    "decoder_start_token_id": null,
+    "diversity_penalty": 0.0,
+    "do_sample": false,
+    "dropout": 0.0,
+    "early_stopping": false,
+    "encoder_no_repeat_ngram_size": 0,
+    "eos_token_id": null,
+    "exponential_decay_length_penalty": null,
+    "finetuning_task": null,
+    "forced_bos_token_id": null,
+    "forced_eos_token_id": null,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "id2label": {
+      "0": "LABEL_0",
+      "1": "LABEL_1"
+    },
+    "image_size": 224,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "is_decoder": false,
+    "is_encoder_decoder": false,
+    "label2id": {
+      "LABEL_0": 0,
+      "LABEL_1": 1
+    },
+    "layer_norm_eps": 1e-05,
+    "length_penalty": 1.0,
+    "max_length": 20,
+    "min_length": 0,
+    "model_type": "clip_vision_model",
+    "no_repeat_ngram_size": 0,
+    "num_attention_heads": 16,
+    "num_beam_groups": 1,
+    "num_beams": 1,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "num_return_sequences": 1,
+    "output_attentions": false,
+    "output_hidden_states": false,
+    "output_scores": false,
+    "pad_token_id": null,
+    "patch_size": 14,
+    "prefix": null,
+    "problem_type": null,
+    "pruned_heads": {},
+    "remove_invalid_values": false,
+    "repetition_penalty": 1.0,
+    "return_dict": true,
+    "return_dict_in_generate": false,
+    "sep_token_id": null,
+    "task_specific_params": null,
+    "temperature": 1.0,
+    "tf_legacy_loss": false,
+    "tie_encoder_decoder": false,
+    "tie_word_embeddings": true,
+    "tokenizer_class": null,
+    "top_k": 50,
+    "top_p": 1.0,
+    "torch_dtype": null,
+    "torchscript": false,
+    "transformers_version": "4.22.0.dev0",
+    "typical_p": 1.0,
+    "use_bfloat16": false
+  },
+  "vision_config_dict": {
+    "hidden_size": 1024,
+    "intermediate_size": 4096,
+    "num_attention_heads": 16,
+    "num_hidden_layers": 24,
+    "patch_size": 14
+  }
+}
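The safety checker is itself a CLIP-based classifier: a ViT-L/14 vision tower (hidden size 1024, 24 layers, patch size 14) whose pooled image embedding is compared against fixed NSFW concept embeddings. Only its `config.json` is bundled in this upload, so the hedged sketch below pulls the weights from the `CompVis/stable-diffusion-safety-checker` repo that `_name_or_path` points to; this is an outline of typical usage, not necessarily this Space's exact wiring:

```python
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPFeatureExtractor

# Weights come from the hub repo referenced by "_name_or_path" above.
safety_checker = StableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-safety-checker"
)
feature_extractor = CLIPFeatureExtractor.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5/feature_extractor"
)

# At inference time the pipeline runs decoded frames through the extractor and the
# checker, which returns the (possibly blacked-out) images plus per-image flags:
#   images, has_nsfw = safety_checker(
#       images=np_images,
#       clip_input=feature_extractor(pil_images, return_tensors="pt").pixel_values,
#   )
```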
    	
models/StableDiffusion/stable-diffusion-v1-5/scheduler/scheduler_config.json
ADDED
@@ -0,0 +1,13 @@
+{
+  "_class_name": "PNDMScheduler",
+  "_diffusers_version": "0.6.0",
+  "beta_end": 0.012,
+  "beta_schedule": "scaled_linear",
+  "beta_start": 0.00085,
+  "num_train_timesteps": 1000,
+  "set_alpha_to_one": false,
+  "skip_prk_steps": true,
+  "steps_offset": 1,
+  "trained_betas": null,
+  "clip_sample": false
+}
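The scheduler config pins the noise schedule: scaled-linear betas from 0.00085 to 0.012 over 1000 training timesteps, with a step offset of 1. PNDM is only the repository default; this Space ships its own consistency-model scheduler in `animatelcm/scheduler/lcm_scheduler.py`, which is expected to be swapped in at runtime on top of the same kind of schedule. A minimal sketch of reading the default config (local path as above):

```python
from diffusers import PNDMScheduler

# Instantiate the default scheduler from scheduler_config.json above.
scheduler = PNDMScheduler.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5/scheduler"
)
print(scheduler.config.beta_schedule)                          # "scaled_linear"
print(scheduler.config.beta_start, scheduler.config.beta_end)  # 0.00085 0.012
print(scheduler.config.num_train_timesteps)                    # 1000
```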
    	
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/config.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "_name_or_path": "openai/clip-vit-large-patch14",
+  "architectures": [
+    "CLIPTextModel"
+  ],
+  "attention_dropout": 0.0,
+  "bos_token_id": 0,
+  "dropout": 0.0,
+  "eos_token_id": 2,
+  "hidden_act": "quick_gelu",
+  "hidden_size": 768,
+  "initializer_factor": 1.0,
+  "initializer_range": 0.02,
+  "intermediate_size": 3072,
+  "layer_norm_eps": 1e-05,
+  "max_position_embeddings": 77,
+  "model_type": "clip_text_model",
+  "num_attention_heads": 12,
+  "num_hidden_layers": 12,
+  "pad_token_id": 1,
+  "projection_dim": 768,
+  "torch_dtype": "float32",
+  "transformers_version": "4.22.0.dev0",
+  "vocab_size": 49408
+}
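The text encoder is the CLIP ViT-L/14 text transformer (12 layers, hidden size 768, 77-token context); its per-token hidden states are the cross-attention conditioning the UNet consumes (`cross_attention_dim: 768` further below). A minimal sketch of producing prompt embeddings, assuming this Space's local paths; note the weights sit in `model.safetensors`, so the installed `transformers` must be able to load safetensors checkpoints:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

root = "models/StableDiffusion/stable-diffusion-v1-5"
tokenizer = CLIPTokenizer.from_pretrained(f"{root}/tokenizer")
text_encoder = CLIPTextModel.from_pretrained(f"{root}/text_encoder")

tokens = tokenizer(
    "a corgi surfing a wave",               # illustrative prompt
    padding="max_length",
    max_length=tokenizer.model_max_length,  # 77
    truncation=True,
    return_tensors="pt",
)
with torch.no_grad():
    prompt_embeds = text_encoder(tokens.input_ids)[0]
print(prompt_embeds.shape)  # torch.Size([1, 77, 768])
```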
    	
models/StableDiffusion/stable-diffusion-v1-5/text_encoder/model.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d008943c017f0092921106440254dbbe00b6a285f7883ec8ba160c3faad88334
+size 492265874
    	
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/merges.txt
ADDED
(The diff for this file is too large to render. See raw diff.)
    	
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
+{
+  "bos_token": {
+    "content": "<|startoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": "<|endoftext|>",
+  "unk_token": {
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}
    	
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/tokenizer_config.json
ADDED
@@ -0,0 +1,34 @@
+{
+  "add_prefix_space": false,
+  "bos_token": {
+    "__type": "AddedToken",
+    "content": "<|startoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "do_lower_case": true,
+  "eos_token": {
+    "__type": "AddedToken",
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "errors": "replace",
+  "model_max_length": 77,
+  "name_or_path": "openai/clip-vit-large-patch14",
+  "pad_token": "<|endoftext|>",
+  "special_tokens_map_file": "./special_tokens_map.json",
+  "tokenizer_class": "CLIPTokenizer",
+  "unk_token": {
+    "__type": "AddedToken",
+    "content": "<|endoftext|>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}
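Together with `special_tokens_map.json`, `merges.txt`, and `vocab.json`, this configures a lower-casing BPE tokenizer with a hard 77-token context, where `<|endoftext|>` doubles as the pad and unknown token. A small sketch of the resulting behaviour (local path as above, prompt illustrative):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5/tokenizer"
)

ids = tokenizer("A Corgi Surfing").input_ids
# do_lower_case lower-cases the prompt; every sequence is wrapped in
# <|startoftext|> ... <|endoftext|> and later padded out to model_max_length.
print(ids[0] == tokenizer.bos_token_id, ids[-1] == tokenizer.eos_token_id)  # True True
print(tokenizer.model_max_length)  # 77
```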
    	
models/StableDiffusion/stable-diffusion-v1-5/tokenizer/vocab.json
ADDED
(The diff for this file is too large to render. See raw diff.)
    	
models/StableDiffusion/stable-diffusion-v1-5/unet/config.json
ADDED
@@ -0,0 +1,36 @@
+{
+  "_class_name": "UNet2DConditionModel",
+  "_diffusers_version": "0.6.0",
+  "act_fn": "silu",
+  "attention_head_dim": 8,
+  "block_out_channels": [
+    320,
+    640,
+    1280,
+    1280
+  ],
+  "center_input_sample": false,
+  "cross_attention_dim": 768,
+  "down_block_types": [
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "CrossAttnDownBlock2D",
+    "DownBlock2D"
+  ],
+  "downsample_padding": 1,
+  "flip_sin_to_cos": true,
+  "freq_shift": 0,
+  "in_channels": 4,
+  "layers_per_block": 2,
+  "mid_block_scale_factor": 1,
+  "norm_eps": 1e-05,
+  "norm_num_groups": 32,
+  "out_channels": 4,
+  "sample_size": 64,
+  "up_block_types": [
+    "UpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D",
+    "CrossAttnUpBlock2D"
+  ]
+}
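These are the stock SD 1.5 UNet hyperparameters: 4 latent channels in and out, block widths 320/640/1280/1280, and 768-dimensional cross-attention matching the CLIP text encoder. In this Space the weights are expected to be re-mapped into the inflated video UNet in `animatelcm/models/unet.py` (which interleaves the motion module from `animatelcm/models/motion_module.py`), but a plain 2D load is the simplest way to sanity-check shapes:

```python
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5", subfolder="unet"
)

# 64x64x4 latents (512x512 pixels after the 8x VAE downsampling), a diffusion
# timestep, and 77x768 text embeddings in; a noise prediction of the same shape out.
latents = torch.randn(1, 4, 64, 64)
timestep = torch.tensor([999])
prompt_embeds = torch.randn(1, 77, 768)
with torch.no_grad():
    noise_pred = unet(latents, timestep, encoder_hidden_states=prompt_embeds).sample
print(noise_pred.shape)  # torch.Size([1, 4, 64, 64])
```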
    	
models/StableDiffusion/stable-diffusion-v1-5/unet/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c7da0e21ba7ea50637bee26e81c220844defdf01aafca02b2c42ecdadb813de4
+size 3438354725
    	
models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml
ADDED
@@ -0,0 +1,70 @@
+model:
+  base_learning_rate: 1.0e-04
+  target: ldm.models.diffusion.ddpm.LatentDiffusion
+  params:
+    linear_start: 0.00085
+    linear_end: 0.0120
+    num_timesteps_cond: 1
+    log_every_t: 200
+    timesteps: 1000
+    first_stage_key: "jpg"
+    cond_stage_key: "txt"
+    image_size: 64
+    channels: 4
+    cond_stage_trainable: false   # Note: different from the one we trained before
+    conditioning_key: crossattn
+    monitor: val/loss_simple_ema
+    scale_factor: 0.18215
+    use_ema: False
+
+    scheduler_config: # 10000 warmup steps
+      target: ldm.lr_scheduler.LambdaLinearScheduler
+      params:
+        warm_up_steps: [ 10000 ]
+        cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
+        f_start: [ 1.e-6 ]
+        f_max: [ 1. ]
+        f_min: [ 1. ]
+
+    unet_config:
+      target: ldm.modules.diffusionmodules.openaimodel.UNetModel
+      params:
+        image_size: 32 # unused
+        in_channels: 4
+        out_channels: 4
+        model_channels: 320
+        attention_resolutions: [ 4, 2, 1 ]
+        num_res_blocks: 2
+        channel_mult: [ 1, 2, 4, 4 ]
+        num_heads: 8
+        use_spatial_transformer: True
+        transformer_depth: 1
+        context_dim: 768
+        use_checkpoint: True
+        legacy: False
+
+    first_stage_config:
+      target: ldm.models.autoencoder.AutoencoderKL
+      params:
+        embed_dim: 4
+        monitor: val/rec_loss
+        ddconfig:
+          double_z: true
+          z_channels: 4
+          resolution: 256
+          in_channels: 3
+          out_ch: 3
+          ch: 128
+          ch_mult:
+          - 1
+          - 2
+          - 4
+          - 4
+          num_res_blocks: 2
+          attn_resolutions: []
+          dropout: 0.0
+        lossconfig:
+          target: torch.nn.Identity
+
+    cond_stage_config:
+      target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
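`v1-inference.yaml` is the original CompVis/ldm training config kept alongside the diffusers layout; conversion helpers such as `animatelcm/utils/convert_from_ckpt.py` rely on this kind of config when turning single-file `.ckpt`/`.safetensors` checkpoints (for example the DreamBooth models in `models/DreamBooth_LoRA/`) into diffusers components. A hedged sketch of reading the values that matter most, using the `omegaconf` dependency from `requirements.txt`:

```python
from omegaconf import OmegaConf

config = OmegaConf.load(
    "models/StableDiffusion/stable-diffusion-v1-5/v1-inference.yaml"
)
params = config.model.params
print(params.scale_factor)                                    # 0.18215 latent scaling
print(params.unet_config.params.context_dim)                  # 768, CLIP text width
print(params.first_stage_config.params.ddconfig.z_channels)   # 4 latent channels
print(params.linear_start, params.linear_end)                 # 0.00085 0.012 beta range
```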
    	
models/StableDiffusion/stable-diffusion-v1-5/vae/config.json
ADDED
@@ -0,0 +1,29 @@
+{
+  "_class_name": "AutoencoderKL",
+  "_diffusers_version": "0.6.0",
+  "act_fn": "silu",
+  "block_out_channels": [
+    128,
+    256,
+    512,
+    512
+  ],
+  "down_block_types": [
+    "DownEncoderBlock2D",
+    "DownEncoderBlock2D",
+    "DownEncoderBlock2D",
+    "DownEncoderBlock2D"
+  ],
+  "in_channels": 3,
+  "latent_channels": 4,
+  "layers_per_block": 2,
+  "norm_num_groups": 32,
+  "out_channels": 3,
+  "sample_size": 512,
+  "up_block_types": [
+    "UpDecoderBlock2D",
+    "UpDecoderBlock2D",
+    "UpDecoderBlock2D",
+    "UpDecoderBlock2D"
+  ]
+}
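The VAE encodes 512x512 RGB frames into 4-channel latents at 1/8 resolution and decodes them back; the `scale_factor` of 0.18215 from `v1-inference.yaml` is applied around the latent space. A minimal round-trip sketch (a random tensor stands in for a real frame):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "models/StableDiffusion/stable-diffusion-v1-5", subfolder="vae"
)

image = torch.rand(1, 3, 512, 512) * 2 - 1   # pixel values are expected in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * 0.18215   # (1, 4, 64, 64)
    decoded = vae.decode(latents / 0.18215).sample               # (1, 3, 512, 512)
print(latents.shape, decoded.shape)
```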
    	
models/StableDiffusion/stable-diffusion-v1-5/vae/diffusion_pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1b134cded8eb78b184aefb8805b6b572f36fa77b255c483665dda931fa0130c5
+size 334707217
    	
requirements.txt
ADDED
@@ -0,0 +1,15 @@
+torch==1.13.1
+torchvision==0.14.1
+torchaudio==0.13.1
+diffusers==0.11.1
+transformers==4.25.1
+xformers==0.0.16
+imageio==2.27.0
+gradio==3.48.0
+gdown
+einops
+omegaconf
+safetensors
+imageio[ffmpeg]
+imageio[pyav]
+accelerate
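The pins match the `sdk_version: 3.48.0` declared in the README and an older torch/diffusers/transformers combination that the bundled `animatelcm` modules were presumably written against, with `xformers` for memory-efficient attention and `imageio[ffmpeg]`/`imageio[pyav]` for video I/O. A quick, optional runtime check that the resolved environment matches the pins (purely illustrative):

```python
# Optional startup check: fail fast if the environment drifted from the pins above.
import diffusers
import gradio
import torch
import transformers

assert diffusers.__version__ == "0.11.1", diffusers.__version__
assert transformers.__version__ == "4.25.1", transformers.__version__
print("torch", torch.__version__, "| gradio", gradio.__version__)
```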
