SingleZombie committed on
Commit 7baddd0 · verified · 1 Parent(s): 973b68e

Create my_controlnet.py

Files changed (1)
  1. controlnet/my_controlnet.py +238 -0
controlnet/my_controlnet.py ADDED
@@ -0,0 +1,238 @@
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from diffusers import ControlNetModel, ModelMixin
from diffusers.configuration_utils import register_to_config
from diffusers.models.controlnet import ControlNetOutput


def zero_module(module):
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


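# ControlNet variant whose condition is consumed directly at latent resolution:
# the image-space conditioning embedding is replaced with nn.Identity() and a
# zero-initialized conv (`conv_in2`) projects the condition onto the first UNet block.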
class MyControlNetModel(ControlNetModel, ModelMixin):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 4,
        conditioning_channels: int = 3,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        down_block_types: Tuple[str, ...] = (
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "CrossAttnDownBlock2D",
            "DownBlock2D",
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
        layers_per_block: int = 2,
        downsample_padding: int = 1,
        mid_block_scale_factor: float = 1,
        act_fn: str = "silu",
        norm_num_groups: Optional[int] = 32,
        norm_eps: float = 1e-5,
        cross_attention_dim: int = 1280,
        transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
        encoder_hid_dim: Optional[int] = None,
        encoder_hid_dim_type: Optional[str] = None,
        attention_head_dim: Union[int, Tuple[int, ...]] = 8,
        num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
        use_linear_projection: bool = False,
        class_embed_type: Optional[str] = None,
        addition_embed_type: Optional[str] = None,
        addition_time_embed_dim: Optional[int] = None,
        num_class_embeds: Optional[int] = None,
        upcast_attention: bool = False,
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        controlnet_conditioning_channel_order: str = "rgb",
        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
    ):
        super().__init__(
            in_channels, conditioning_channels, flip_sin_to_cos, freq_shift,
            down_block_types, mid_block_type, only_cross_attention,
            block_out_channels, layers_per_block, downsample_padding,
            mid_block_scale_factor, act_fn, norm_num_groups, norm_eps,
            cross_attention_dim, transformer_layers_per_block, encoder_hid_dim,
            encoder_hid_dim_type, attention_head_dim, num_attention_heads,
            use_linear_projection, class_embed_type, addition_embed_type,
            addition_time_embed_dim, num_class_embeds, upcast_attention,
            resnet_time_scale_shift, projection_class_embeddings_input_dim,
            controlnet_conditioning_channel_order, conditioning_embedding_out_channels,
            global_pool_conditions, addition_embed_type_num_heads,
        )
        # Drop the image-space conditioning embedding; the condition is used as-is.
        self.controlnet_cond_embedding = nn.Identity()
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        # Zero-initialized projection from the latent-sized condition to the first block width.
        self.conv_in2 = nn.Conv2d(
            in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
        )
        zero_module(self.conv_in2)

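    # Note: unlike the parent ControlNetModel, forward() expects `controlnet_cond`
    # to already be a latent-resolution tensor with `in_channels` channels and the
    # same spatial size as `sample`; it is added to the conv_in features instead of
    # being embedded from an image-space hint.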
    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        class_labels: Optional[torch.Tensor] = None,
        timestep_cond: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guess_mode: bool = False,
        return_dict: bool = True,
    ) -> Union[ControlNetOutput, Tuple[Tuple[torch.Tensor, ...], torch.Tensor]]:
        # check channel order
        channel_order = self.config.controlnet_conditioning_channel_order

        if channel_order == "rgb":
            # in rgb order by default
            ...
        elif channel_order == "bgr":
            controlnet_cond = torch.flip(controlnet_cond, dims=[1])
        else:
            raise ValueError(
                f"unknown `controlnet_conditioning_channel_order`: {channel_order}")

        # prepare attention_mask
        if attention_mask is not None:
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor(
                [timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError(
                    "class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

            class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
            emb = emb + class_emb

        if self.config.addition_embed_type is not None:
            if self.config.addition_embed_type == "text":
                aug_emb = self.add_embedding(encoder_hidden_states)

            elif self.config.addition_embed_type == "text_time":
                if "text_embeds" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                    )
                text_embeds = added_cond_kwargs.get("text_embeds")
                if "time_ids" not in added_cond_kwargs:
                    raise ValueError(
                        f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                    )
                time_ids = added_cond_kwargs.get("time_ids")
                time_embeds = self.add_time_proj(time_ids.flatten())
                time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

                add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
                add_embeds = add_embeds.to(emb.dtype)
                aug_emb = self.add_embedding(add_embeds)

        emb = emb + aug_emb if aug_emb is not None else emb

        # 2. pre-process
        sample = self.conv_in(sample)
        # project the latent-resolution condition with the zero-initialized conv_in2
        # (the stock ControlNet would run an image-space hint through controlnet_cond_embedding here)
        controlnet_cond = self.conv_in2(controlnet_cond)

        sample = sample + controlnet_cond

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample, temb=emb)

            down_block_res_samples += res_samples

        # 4. mid
        if self.mid_block is not None:
            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
                sample = self.mid_block(
                    sample,
                    emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                sample = self.mid_block(sample, emb)

        # 5. Control net blocks
        controlnet_down_block_res_samples = ()

        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples = controlnet_down_block_res_samples + (down_block_res_sample,)

        down_block_res_samples = controlnet_down_block_res_samples

        mid_block_res_sample = self.controlnet_mid_block(sample)

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
            # 0.1 to 1.0
            scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)
            scales = scales * conditioning_scale
            down_block_res_samples = [
                sample * scale for sample, scale in zip(down_block_res_samples, scales)]
            mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
        else:
            down_block_res_samples = [
                sample * conditioning_scale for sample in down_block_res_samples]
            mid_block_res_sample = mid_block_res_sample * conditioning_scale

        if self.config.global_pool_conditions:
            down_block_res_samples = [
                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in down_block_res_samples
            ]
            mid_block_res_sample = torch.mean(
                mid_block_res_sample, dim=(2, 3), keepdim=True)

        if not return_dict:
            return (down_block_res_samples, mid_block_res_sample)

        return ControlNetOutput(
            down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
        )
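A minimal usage sketch (not part of the commit; the import path, shapes, and default configuration are assumptions). Because the conditioning embedding is an `nn.Identity()` and the condition goes through the zero-initialized `conv_in2`, `controlnet_cond` must already be a latent-resolution tensor with `in_channels` (default 4) channels rather than an RGB hint image:

# hypothetical usage sketch, not part of the commit
import torch

from controlnet.my_controlnet import MyControlNetModel  # adjust to where the file lives

controlnet = MyControlNetModel()  # default, SD-sized configuration

latents = torch.randn(2, 4, 64, 64)      # noisy latent sample
cond = torch.randn(2, 4, 64, 64)         # latent-resolution condition, in_channels channels
text_states = torch.randn(2, 77, 1280)   # matches the default cross_attention_dim of 1280

with torch.no_grad():
    down_res, mid_res = controlnet(
        sample=latents,
        timestep=torch.tensor([500, 500]),
        encoder_hidden_states=text_states,
        controlnet_cond=cond,
        conditioning_scale=1.0,
        return_dict=False,
    )

print(len(down_res), mid_res.shape)  # 12 down-block residuals and one mid-block residual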