model_index.json CHANGED
@@ -30,7 +30,7 @@
     "T5Tokenizer"
   ],
   "transformer": [
-    "pipelines.sd3_model",
+    "sd3_model",
     "SD3Transformer2DModel"
   ],
   "vae": [
pipelines/__init__.py DELETED
File without changes
pipelines/sd3_teefusion_pipeline.py DELETED
@@ -1,264 +0,0 @@
- # Copyright (C) 2025 AIDC-AI
- # This project is licensed under the Attribution-NonCommercial 4.0 International
- # License (SPDX-License-Identifier: CC-BY-NC-4.0).
-
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import os
- import torch
- import torch.nn as nn
-
- from typing import Union, List, Any, Optional
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from PIL import Image
- from diffusers import DiffusionPipeline, AutoencoderKL
- from transformers import CLIPTextModelWithProjection, T5EncoderModel, CLIPTokenizer, T5Tokenizer
-
- def get_noise(
-     num_samples: int,
-     channel: int,
-     height: int,
-     width: int,
-     device: torch.device,
-     dtype: torch.dtype,
-     seed: int,
- ):
-     return torch.randn(
-         num_samples,
-         channel,
-         height // 8,
-         width // 8,
-         device=device,
-         dtype=dtype,
-         generator=torch.Generator(device=device).manual_seed(seed),
-     )
-
- def get_clip_prompt_embeds(
-     clip_tokenizers,
-     clip_text_encoders,
-     prompt: Union[str, List[str]],
-     num_images_per_prompt: int = 1,
-     device: Optional[torch.device] = None,
-     clip_skip: Optional[int] = None,
-     clip_model_index: int = 0,
- ):
-
-     tokenizer_max_length = 77
-     tokenizer = clip_tokenizers[clip_model_index]
-     text_encoder = clip_text_encoders[clip_model_index]
-
-     batch_size = len(prompt)
-
-     text_inputs = tokenizer(
-         prompt,
-         padding="max_length",
-         max_length=tokenizer_max_length,
-         truncation=True,
-         return_tensors="pt",
-     )
-
-     text_input_ids = text_inputs.input_ids
-     untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-     if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-         removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
-
-     prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
-     pooled_prompt_embeds = prompt_embeds[0]
-
-     if clip_skip is None:
-         prompt_embeds = prompt_embeds.hidden_states[-2]
-     else:
-         prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
-
-     prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-     _, seq_len, _ = prompt_embeds.shape
-     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-     pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-     pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
-
-     return prompt_embeds, pooled_prompt_embeds
-
- def get_t5_prompt_embeds(
-     tokenizer_3,
-     text_encoder_3,
-     prompt: Union[str, List[str]] = None,
-     num_images_per_prompt: int = 1,
-     max_sequence_length: int = 256,
-     device: Optional[torch.device] = None,
-     dtype: Optional[torch.dtype] = None,
- ):
-
-     tokenizer_max_length = 77
-     batch_size = len(prompt)
-
-     text_inputs = tokenizer_3(
-         prompt,
-         padding="max_length",
-         max_length=max_sequence_length,
-         truncation=True,
-         add_special_tokens=True,
-         return_tensors="pt",
-     )
-     text_input_ids = text_inputs.input_ids
-     untruncated_ids = tokenizer_3(prompt, padding="longest", return_tensors="pt").input_ids
-
-     if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-         removed_text = tokenizer_3.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
-
-     prompt_embeds = text_encoder_3(text_input_ids.to(device))[0]
-
-     dtype = text_encoder_3.dtype
-     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-     _, seq_len, _ = prompt_embeds.shape
-
-     # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-     return prompt_embeds
-
-
- @torch.no_grad()
- def encode_text(clip_tokenizers, clip_text_encoders, tokenizer_3, text_encoder_3, prompt, device, max_sequence_length=256):
-
-     prompt_embed, pooled_prompt_embed = get_clip_prompt_embeds(clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=0)
-     prompt_2_embed, pooled_prompt_2_embed = get_clip_prompt_embeds(clip_tokenizers, clip_text_encoders, prompt=prompt, device=device, clip_model_index=1)
-     clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
-
-     t5_prompt_embed = get_t5_prompt_embeds(tokenizer_3, text_encoder_3, prompt=prompt, max_sequence_length=max_sequence_length, device=device)
-
-     clip_prompt_embeds = torch.nn.functional.pad(clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]))
-
-     prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
-     pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
-
-     return prompt_embeds, pooled_prompt_embeds
-
-
- class TeEFusionSD3Pipeline(DiffusionPipeline, ConfigMixin):
-
-     @register_to_config
-     def __init__(
-         self,
-         transformer: nn.Module,
-         text_encoder: CLIPTextModelWithProjection,
-         text_encoder_2: CLIPTextModelWithProjection,
-         text_encoder_3: T5EncoderModel,
-         tokenizer: CLIPTokenizer,
-         tokenizer_2: CLIPTokenizer,
-         tokenizer_3: T5Tokenizer,
-         vae: AutoencoderKL,
-         scheduler: Any
-     ):
-         super().__init__()
-
-         self.register_modules(
-             transformer=transformer,
-             text_encoder=text_encoder,
-             text_encoder_2=text_encoder_2,
-             text_encoder_3=text_encoder_3,
-             tokenizer=tokenizer,
-             tokenizer_2=tokenizer_2,
-             tokenizer_3=tokenizer_3,
-             vae=vae,
-             scheduler=scheduler
-         )
-
-
-     @classmethod
-     def from_pretrained(
-         cls,
-         pretrained_model_name_or_path: Union[str, os.PathLike],
-         **kwargs,
-     ) -> "TeEFusionSD3Pipeline":
-
-         return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
-
-     def save_pretrained(self, save_directory: Union[str, os.PathLike]):
-         super().save_pretrained(save_directory)
-
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Union[str, List[str]],
-         num_inference_steps: int = 50,
-         guidance_scale: float = 7.5,
-         latents: torch.FloatTensor = None,
-         height: int = 1024,
-         width: int = 1024,
-         seed: int = 0,
-     ):
-         if isinstance(prompt, str):
-             prompt = [prompt]
-
-         device = self.transformer.device
-
-         clip_tokenizers = [self.tokenizer, self.tokenizer_2]
-         clip_text_encoders = [self.text_encoder, self.text_encoder_2]
-
-         prompt_embeds, pooled_prompt_embeds = encode_text(clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, prompt, device)
-
-         _, negative_pooled_prompt_embeds = encode_text(clip_tokenizers, clip_text_encoders, self.tokenizer_3, self.text_encoder_3, [''], device)
-
-
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-         timesteps = self.scheduler.timesteps
-
-         bs = len(prompt)
-         channels = self.transformer.config.in_channels
-         height = 16 * (height // 16)
-         width = 16 * (width // 16)
-
-         # prepare input
-         if latents is None:
-             latents = get_noise(
-                 bs,
-                 channels,
-                 height,
-                 width,
-                 device=device,
-                 dtype=self.transformer.dtype,
-                 seed=seed,
-             )
-
-         for i, t in enumerate(timesteps):
-             noise_pred = self.transformer(
-                 hidden_states=latents,
-                 timestep=t.reshape(1),
-                 encoder_hidden_states=prompt_embeds,
-                 pooled_projections=pooled_prompt_embeds,
-                 return_dict=False,
-                 txt_align_guidance=torch.tensor(data=(guidance_scale,), dtype=self.transformer.dtype, device=self.transformer.device) * 1000.,
-                 txt_align_vec=pooled_prompt_embeds - negative_pooled_prompt_embeds
-             )[0]
-
-             latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
-
-         x = latents.float()
-
-         with torch.no_grad():
-             with torch.autocast(device_type=device.type, dtype=torch.float32):
-                 if hasattr(self.vae.config, 'scaling_factor') and self.vae.config.scaling_factor is not None:
-                     x = x / self.vae.config.scaling_factor
-                 if hasattr(self.vae.config, 'shift_factor') and self.vae.config.shift_factor is not None:
-                     x = x + self.vae.config.shift_factor
-                 x = self.vae.decode(x, return_dict=False)[0]
-
-         # bring into PIL format and save
-         x = (x / 2 + 0.5).clamp(0, 1)
-         x = x.cpu().permute(0, 2, 3, 1).float().numpy()
-         images = (x * 255).round().astype("uint8")
-         pil_images = [Image.fromarray(image) for image in images]
-
-         return pil_images
-
-
-
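For reference, a minimal usage sketch of the pipeline removed above, based on its __call__ signature; the old import path from pipelines/ applies only before this commit, and the repo id, prompt, and output path are illustrative (depending on the diffusers version, trust_remote_code=True may also be needed when loading). Note that guidance_scale is not applied as a second unconditional forward pass: the pipeline folds it into the single transformer call via txt_align_guidance and txt_align_vec.

import torch
from pipelines.sd3_teefusion_pipeline import TeEFusionSD3Pipeline  # pre-commit layout

# Load the removed pipeline class; repo id is a placeholder, not from this diff.
pipe = TeEFusionSD3Pipeline.from_pretrained("ORG/REPO", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# __call__ returns a list of PIL images.
images = pipe(
    prompt="a photograph of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,
    height=1024,
    width=1024,
    seed=0,
)
images[0].save("sample.png")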
{pipelines → transformer}/sd3_model.py RENAMED
File without changes
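Taken together, this commit moves the custom transformer module from pipelines/sd3_model.py to transformer/sd3_model.py, points the model_index.json entry at sd3_model, and drops the bundled TeEFusionSD3Pipeline. A minimal loading sketch for the restructured repo, assuming it is consumed through diffusers' remote-code mechanism for custom components; the repo id is a placeholder and exact kwargs may vary with the diffusers version:

import torch
from diffusers import DiffusionPipeline

# sd3_model.SD3Transformer2DModel now ships inside the repo's transformer/ subfolder,
# so loading the component requires allowing remote code.
pipe = DiffusionPipeline.from_pretrained(
    "ORG/REPO",              # placeholder repo id, not from this diff
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
pipe.to("cuda")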