yuxinjiang11 committed
Commit f4281df · verified · 1 Parent(s): 16b5bff

Upload ip_adapter_anomagic.py

Files changed (1)
  1. ip_adapter/ip_adapter_anomagic.py +11 -135
ip_adapter/ip_adapter_anomagic.py CHANGED
@@ -11,7 +11,6 @@ import torch.nn as nn
11
  import math
12
  from .utils import is_torch2_available, get_generator
13
  import numpy as np
14
-
15
  if is_torch2_available():
16
  from .attention_processor import (
17
  AttnProcessor2_0 as AttnProcessor,
@@ -21,45 +20,33 @@ if is_torch2_available():
21
  else:
22
  from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
23
  from .resampler import Resampler
24
-
25
-
26
  def load_lora_model(unet, device, diffusion_model_learning_rate, dtype):
27
  for param in unet.parameters():
28
  param.requires_grad_(False)
29
-
30
  unet_lora_config = LoraConfig(
31
  r=16,
32
  lora_alpha=16,
33
  init_lora_weights="gaussian",
34
  target_modules=["to_k", "to_q", "to_v", "to_out.0"],
35
  )
36
-
37
  unet.add_adapter(unet_lora_config)
38
  lora_layers = filter(lambda p: p.requires_grad, unet.parameters())
39
-
40
  optimizer = torch.optim.AdamW(
41
  lora_layers,
42
  lr=diffusion_model_learning_rate,
43
  )
44
-
45
  # Ensure the LoRA layers use the correct dtype
46
  for layer in lora_layers:
47
  layer.data = layer.data.to(dtype)
48
-
49
  return unet, lora_layers
50
-
51
-
52
  class ImageProjModel(torch.nn.Module):
53
  """Projection Model"""
54
-
55
  def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
56
  super().__init__()
57
-
58
  self.cross_attention_dim = cross_attention_dim
59
  self.clip_extra_context_tokens = clip_extra_context_tokens
60
  self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
61
  self.norm = torch.nn.LayerNorm(cross_attention_dim)
62
-
63
  def forward(self, image_embeds):
64
  embeds = image_embeds
65
  b = embeds.shape[0]
@@ -68,26 +55,19 @@ class ImageProjModel(torch.nn.Module):
68
  )
69
  clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
70
  return clip_extra_context_tokens
71
-
72
-
73
  class MLPProjModel(torch.nn.Module):
74
  """SD model with image prompt"""
75
-
76
  def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
77
  super().__init__()
78
-
79
  self.proj = torch.nn.Sequential(
80
  torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
81
  torch.nn.GELU(),
82
  torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
83
  torch.nn.LayerNorm(cross_attention_dim)
84
  )
85
-
86
  def forward(self, image_embeds):
87
  clip_extra_context_tokens = self.proj(image_embeds)
88
  return clip_extra_context_tokens
89
-
90
-
91
  class SelfAttention(nn.Module):
92
  def __init__(self, in_channels, device, dtype=torch.float16):
93
  super(SelfAttention, self).__init__()
@@ -98,98 +78,80 @@ class SelfAttention(nn.Module):
98
  self.gamma = nn.Parameter(torch.zeros(1, dtype=dtype, device=device))
99
  self.softmax = nn.Softmax(dim=-1)
100
  self.proj_out = nn.Linear(1280, 1024).to(device, dtype=dtype)
101
-
102
  def forward(self, x, mask=None):
103
  # Cast the input to the model dtype
104
  x = x.to(dtype=self.dtype)
105
-
106
  x = x.permute(0, 2, 1)
107
  batch_size, channels, h = x.size()
108
  height = int(math.sqrt(h))
109
  width = height
110
  x = x.view(batch_size, channels, width, height)
111
  batch_size, channels, height, width = x.size()
112
-
113
  # Compute query, key, value
114
  q = self.query(x).view(batch_size, -1, height * width).permute(0, 2, 1)
115
  k = self.key(x).view(batch_size, -1, height * width)
116
  v = self.value(x).view(batch_size, -1, height * width)
117
-
118
  # Compute attention scores
119
  attention_scores = torch.bmm(q, k)
120
-
121
  if mask is not None:
122
  # Cast the mask to the correct dtype and move it to the right device
123
  mask = mask.to(device=x.device, dtype=self.dtype)
124
-
125
  # Resize the mask to match x
126
  mask = nn.functional.interpolate(mask, size=(height, width), mode='nearest')
127
  mask = mask.view(batch_size, 1, height * width)
128
-
129
  # Apply the mask
130
  large_constant = torch.tensor(1e6, dtype=self.dtype, device=x.device)
131
  attention_scores = attention_scores - (1 - mask) * large_constant
132
-
133
  # Compute the attention weights
134
  attention_weights = self.softmax(attention_scores)
135
-
136
  # Apply the attention weights
137
  out = torch.bmm(v, attention_weights.permute(0, 2, 1))
138
  out = out.view(batch_size, channels, height, width)
139
-
140
  # Weighted sum (residual connection)
141
  out = self.gamma * out + x
142
  out = out.view(batch_size, channels, height * width)
143
  out = out.permute(0, 2, 1)
144
  out = self.proj_out(out)
145
-
146
  return out
147
  import requests
148
  import io
149
  class Anomagic:
150
  def __init__(self, sd_pipe, image_encoder, ip_ckpt_url, att_ckpt_url, device, num_tokens=4, dtype=torch.float16):
151
  self.device = device
152
- self.dtype = dtype
153
- if torch.device(device).type == 'cpu':
+ self.num_tokens = num_tokens
+ if str(device).startswith('cpu'):
154
  self.dtype = torch.float32
155
-
+ else:
+ self.dtype = dtype
 
156
  # 1. Initialize the attention module (unified dtype)
157
  self.attention_module = SelfAttention(1280, device, dtype=self.dtype)
158
-
159
  # 2. Initialize the SD pipeline (unified dtype)
160
  self.pipe = sd_pipe.to(self.device, dtype=self.dtype)
161
  self.set_anomagic()
162
-
163
  # 3. Handle image_encoder (prefer the passed-in model instead of reloading)
164
  self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
165
- "yuxinjiang11/image_encoder", # 完整仓库路径
166
  torch_dtype=self.dtype,
167
  ).to(self.device, dtype=self.dtype)
168
-
169
  self.clip_image_processor = CLIPImageProcessor()
170
-
171
  # 4. Initialize the image_proj model (unified dtype)
172
  self.image_proj_model = self.init_proj()
173
-
174
  # 5. Load the weights from URL into memory (core fix)
175
  self.ip_state_dict = self.load_weight_from_url(ip_ckpt_url)
176
  self.att_state_dict = self.load_weight_from_url(att_ckpt_url)
177
-
178
  # 6. Load the weights into the models
179
  self.load_anomagic()
180
-
181
  def load_weight_from_url(self, url):
182
  """从URL下载权重到内存并返回state_dict(处理异常)"""
183
  try:
184
  response = requests.get(url, stream=True, timeout=30)
185
- response.raise_for_status() # raise on HTTP errors
+ response.raise_for_status() # raise on HTTP errors
186
  buffer = io.BytesIO(response.content)
187
  return torch.load(buffer, map_location="cpu")
188
  except requests.exceptions.RequestException as e:
189
  raise RuntimeError(f"权重URL请求失败: {str(e)}")
190
  except Exception as e:
191
  raise RuntimeError(f"权重加载失败: {str(e)}")
192
-
193
  def init_proj(self):
194
  """初始化image_proj模型(绑定dtype和device)"""
195
  image_proj_model = ImageProjModel(
@@ -198,16 +160,13 @@ class Anomagic:
198
  clip_extra_context_tokens=self.num_tokens,
199
  ).to(self.device, dtype=self.dtype)
200
  return image_proj_model
201
-
202
  def set_anomagic(self):
203
  """配置UNet的Attention处理器和LoRA"""
204
  unet = self.pipe.unet
205
  attn_procs = {}
206
-
207
  for name in unet.attn_processors.keys():
208
  # Determine whether this is cross attention
209
  cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
210
-
211
  # Get the hidden_size of the corresponding block
212
  if name.startswith("mid_block"):
213
  hidden_size = unet.config.block_out_channels[-1]
@@ -218,8 +177,7 @@ class Anomagic:
218
  block_id = int(name[len("down_blocks.")])
219
  hidden_size = unet.config.block_out_channels[block_id]
220
  else:
221
- hidden_size = unet.config.cross_attention_dim # fallback
+ hidden_size = unet.config.cross_attention_dim # fallback
222
-
223
  # Assign the attention processor
224
  if cross_attention_dim is None:
225
  attn_procs[name] = AttnProcessor()
@@ -230,11 +188,9 @@ class Anomagic:
230
  scale=1.0,
231
  num_tokens=self.num_tokens,
232
  ).to(self.device, dtype=self.dtype)
233
-
234
  # Apply the processors and load LoRA
235
  unet.set_attn_processor(attn_procs)
236
  unet, lora_layers = load_lora_model(unet, self.device, 4e-4, self.dtype)
237
-
238
  # Handle ControlNet (if present)
239
  if hasattr(self.pipe, "controlnet"):
240
  if isinstance(self.pipe.controlnet, MultiControlNetModel):
@@ -244,7 +200,6 @@ class Anomagic:
244
  else:
245
  self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
246
  self.pipe.controlnet.to(self.device, dtype=self.dtype)
247
-
248
  def load_anomagic(self):
249
  """统一加载IP Adapter和Attention权重(修复类型和冗余问题)"""
250
  # ========== IP-Adapter weights ==========
@@ -253,35 +208,28 @@ class Anomagic:
253
  state_dict = self.ip_state_dict
254
  # Convert tensor precision (handles nested dicts)
255
  self._convert_state_dict_dtype(state_dict)
256
-
257
  # Load into the corresponding modules (done once; redundant code removed)
258
  def print_param_shapes(model, state_dict, prefix=""):
259
  """打印模型和state_dict的参数形状"""
260
  print(f"\n===== {prefix} 参数形状对比 =====")
261
-
262
  # 1. Print the model's parameter shapes
263
  print("【模型参数】")
264
  for name, param in model.named_parameters():
265
- print(f" {name}: {param.shape}")
266
-
267
  # 2. Print the state_dict's parameter shapes
268
  print("\n【StateDict参数】")
269
  for key, tensor in state_dict.items():
270
- print(f" {key}: {tensor.shape}")
271
-
272
  # Called right before self.image_proj_model.load_state_dict(state_dict["image_proj"])
273
  print_param_shapes(self.image_proj_model, state_dict["image_proj"], "image_proj_model")
274
-
275
  self.image_proj_model.load_state_dict(state_dict["image_proj"])
276
  ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
277
  ip_layers.load_state_dict(state_dict["ip_adapter"])
278
-
279
  # Load extra UNet weights (if any)
280
  if "unet" in state_dict:
281
  self.pipe.unet.load_state_dict(state_dict["unet"], strict=False)
282
  else:
283
  raise TypeError("ip_state_dict必须是内存中的权重字典,而非文件路径")
284
-
285
  # ========== Attention module weights ==========
286
  if isinstance(self.att_state_dict, dict):
287
  att_state_dict = self.att_state_dict.get("att", self.att_state_dict)
@@ -290,7 +238,6 @@ class Anomagic:
290
  self.attention_module.load_state_dict(att_state_dict, strict=True)
291
  else:
292
  raise TypeError("att_state_dict必须是内存中的权重字典")
293
-
294
  def _convert_state_dict_dtype(self, state_dict):
295
  """递归转换state_dict中所有张量的dtype(工具函数)"""
296
  for key in list(state_dict.keys()):
@@ -298,8 +245,7 @@ class Anomagic:
298
  if isinstance(value, torch.Tensor):
299
  state_dict[key] = value.to(self.dtype)
300
  elif isinstance(value, dict):
301
- self._convert_state_dict_dtype(value) # recurse into nested dicts
+ self._convert_state_dict_dtype(value) # recurse into nested dicts
302
-
303
  @torch.inference_mode()
304
  def get_image_embeds(self, pil_image=None, clip_image_embeds=None, mask_image_0=None):
305
  if pil_image is not None:
@@ -307,40 +253,33 @@ class Anomagic:
307
  pil_image = [pil_image]
308
  clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
309
  clip_image = clip_image.to(self.device, dtype=self.dtype)
310
-
311
  outputs = self.image_encoder(clip_image)
312
  clip_image_embeds = outputs.image_embeds
313
  last_feature_layer_output = outputs.last_hidden_state
314
  else:
315
  clip_image_embeds = clip_image_embeds.to(self.device, dtype=self.dtype)
316
-
317
  # Process mask_image_0
318
  if mask_image_0 is not None:
319
  mask_image_0 = mask_image_0.resize((64, 64))
320
  mask_image_0 = mask_image_0.convert('L')
321
  mask_image_0 = torch.tensor(np.array(mask_image_0), dtype=self.dtype, device=self.device)
322
  mask_image_0 = (mask_image_0 > 0.5).float()
323
- mask_image_0 = mask_image_0.unsqueeze(0).unsqueeze(0) # add batch and channel dims
+ mask_image_0 = mask_image_0.unsqueeze(0).unsqueeze(0) # add batch and channel dims
324
  else:
325
  mask_image_0 = None
326
-
327
  # Process the features with the unified dtype
328
  image_embeds = self.attention_module(
329
  last_feature_layer_output[:, :256, :],
330
  mask_image_0
331
  )
332
-
333
  # Generate image_prompt_embeds
334
  image_prompt_embeds = self.image_proj_model(image_embeds)
335
  uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(image_embeds))
336
-
337
  return image_prompt_embeds, uncond_image_prompt_embeds
338
-
339
  def set_scale(self, scale):
340
  for attn_processor in self.pipe.unet.attn_processors.values():
341
  if isinstance(attn_processor, IPAttnProcessor):
342
  attn_processor.scale = scale
343
-
344
  def encode_long_text(self,
345
  input_ids: torch.Tensor,
346
  tokenizer: CLIPTokenizer,
@@ -349,28 +288,21 @@ class Anomagic:
349
  device: str = None
350
  ) -> torch.Tensor:
351
  device = device or self.device
352
-
353
  if input_ids.dim() == 1:
354
  input_ids = input_ids.unsqueeze(0)
355
-
356
  batch_size = input_ids.size(0)
357
  hidden_dim = text_encoder.config.hidden_size
358
-
359
  combined_embeddings = torch.zeros(batch_size, hidden_dim, device=device, dtype=self.dtype)
360
-
361
  for batch_idx in range(batch_size):
362
  current_input_ids = input_ids[batch_idx]
363
-
364
  chunks = [
365
  current_input_ids[i:i + max_length]
366
  for i in range(0, len(current_input_ids), max_length)
367
  ]
368
-
369
  embeddings = []
370
  for chunk in chunks:
371
  chunk_len = len(chunk)
372
  padding_len = max_length - chunk_len
373
-
374
  chunk_input = {
375
  "input_ids": torch.cat([
376
  chunk.unsqueeze(0).to(device),
@@ -381,16 +313,12 @@ class Anomagic:
381
  torch.zeros(1, padding_len, dtype=torch.long, device=device)
382
  ], dim=1)
383
  }
384
-
385
  with torch.no_grad():
386
  chunk_emb = text_encoder(**chunk_input).last_hidden_state
387
  embeddings.append(chunk_emb[:, :chunk_len, :].mean(dim=1))
388
-
389
  if embeddings:
390
  combined_embeddings[batch_idx] = torch.mean(torch.cat(embeddings, dim=0), dim=0)
391
-
392
  return combined_embeddings.unsqueeze(1)
393
-
394
  def generate(
395
  self,
396
  pil_image=None,
@@ -406,32 +334,26 @@ class Anomagic:
406
  **kwargs,
407
  ):
408
  self.set_scale(scale)
409
-
410
  if pil_image is not None:
411
  num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
412
  else:
413
  num_prompts = clip_image_embeds.size(0) if clip_image_embeds is not None else 1
414
-
415
  if prompt is None:
416
  prompt = "best quality, high quality"
417
  if negative_prompt is None:
418
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
419
-
420
  if not isinstance(prompt, List):
421
  prompt = [prompt] * num_prompts
422
  if not isinstance(negative_prompt, List):
423
  negative_prompt = [negative_prompt] * num_prompts
424
-
425
  image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
426
  pil_image=pil_image, clip_image_embeds=clip_image_embeds, mask_image_0=mask_image_0,
427
  )
428
-
429
  bs_embed, seq_len, _ = image_prompt_embeds.shape
430
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
431
  image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
432
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
433
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
434
-
435
  with torch.inference_mode():
436
  # Encode the text prompts
437
  prompt_embeds_list = []
@@ -444,7 +366,6 @@ class Anomagic:
444
  return_tensors="pt"
445
  )
446
  input_ids = inputs.input_ids.to(self.device)
447
-
448
  prompt_embed = self.encode_long_text(
449
  input_ids=input_ids,
450
  tokenizer=self.pipe.tokenizer,
@@ -452,9 +373,7 @@ class Anomagic:
452
  device=self.device
453
  )
454
  prompt_embeds_list.append(prompt_embed)
455
-
456
  prompt_embeds = torch.cat(prompt_embeds_list, dim=0)
457
-
458
  # Encode the negative prompts
459
  negative_prompt_embeds_list = []
460
  for p in negative_prompt:
@@ -466,7 +385,6 @@ class Anomagic:
466
  return_tensors="pt"
467
  )
468
  input_ids = inputs.input_ids.to(self.device)
469
-
470
  negative_prompt_embed = self.encode_long_text(
471
  input_ids=input_ids,
472
  tokenizer=self.pipe.tokenizer,
@@ -474,15 +392,11 @@ class Anomagic:
474
  device=self.device
475
  )
476
  negative_prompt_embeds_list.append(negative_prompt_embed)
477
-
478
  negative_prompt_embeds = torch.cat(negative_prompt_embeds_list, dim=0)
479
-
480
  # Concatenate the image embeddings with the text embeddings
481
  prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
482
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
483
-
484
  generator = get_generator(seed, self.device)
485
-
486
  images = self.pipe(
487
  prompt_embeds=prompt_embeds,
488
  negative_prompt_embeds=negative_prompt_embeds,
@@ -490,13 +404,9 @@ class Anomagic:
490
  num_inference_steps=num_inference_steps,
491
  generator=generator, **kwargs,
492
  ).images
493
-
494
  return images
495
-
496
-
497
  class AnomagicXL(Anomagic):
498
  """SDXL"""
499
-
500
  def generate(
501
  self,
502
  pil_image,
@@ -508,26 +418,21 @@ class AnomagicXL(Anomagic):
508
  num_inference_steps=30, **kwargs,
509
  ):
510
  self.set_scale(scale)
511
-
512
  num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
513
-
514
  if prompt is None:
515
  prompt = "best quality, high quality"
516
  if negative_prompt is None:
517
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
518
-
519
  if not isinstance(prompt, List):
520
  prompt = [prompt] * num_prompts
521
  if not isinstance(negative_prompt, List):
522
  negative_prompt = [negative_prompt] * num_prompts
523
-
524
  image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
525
  bs_embed, seq_len, _ = image_prompt_embeds.shape
526
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
527
  image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
528
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
529
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
530
-
531
  with torch.inference_mode():
532
  (
533
  prompt_embeds,
@@ -542,9 +447,7 @@ class AnomagicXL(Anomagic):
542
  )
543
  prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
544
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
545
-
546
  self.generator = get_generator(seed, self.device)
547
-
548
  images = self.pipe(
549
  prompt_embeds=prompt_embeds,
550
  negative_prompt_embeds=negative_prompt_embeds,
@@ -553,13 +456,9 @@ class AnomagicXL(Anomagic):
553
  num_inference_steps=num_inference_steps,
554
  generator=self.generator, **kwargs,
555
  ).images
556
-
557
  return images
558
-
559
-
560
  class AnomagicPlus(Anomagic):
561
  """Anomagic with fine-grained features"""
562
-
563
  def init_proj(self):
564
  image_proj_model = Resampler(
565
  dim=self.pipe.unet.config.cross_attention_dim,
@@ -572,39 +471,29 @@ class AnomagicPlus(Anomagic):
572
  ff_mult=4,
573
  ).to(self.device, dtype=self.dtype)
574
  return image_proj_model
575
-
576
  @torch.inference_mode()
577
  def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
578
  if isinstance(pil_image, Image.Image):
579
  pil_image = [pil_image]
580
  clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
581
  clip_image = clip_image.to(self.device, dtype=self.dtype)
582
-
583
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
584
  image_prompt_embeds = self.image_proj_model(clip_image_embeds)
585
-
586
  uncond_clip_image_embeds = self.image_encoder(
587
  torch.zeros_like(clip_image), output_hidden_states=True
588
  ).hidden_states[-2]
589
  uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
590
-
591
  return image_prompt_embeds, uncond_image_prompt_embeds
592
-
593
-
594
  class AnomagicFull(AnomagicPlus):
595
  """Anomagic with full features"""
596
-
597
  def init_proj(self):
598
  image_proj_model = MLPProjModel(
599
  cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
600
  clip_embeddings_dim=self.image_encoder.config.hidden_size,
601
  ).to(self.device, dtype=self.dtype)
602
  return image_proj_model
603
-
604
-
605
  class AnomagicPlusXL(Anomagic):
606
  """SDXL"""
607
-
608
  def init_proj(self):
609
  image_proj_model = Resampler(
610
  dim=1280,
@@ -617,24 +506,19 @@ class AnomagicPlusXL(Anomagic):
617
  ff_mult=4,
618
  ).to(self.device, dtype=self.dtype)
619
  return image_proj_model
620
-
621
  @torch.inference_mode()
622
  def get_image_embeds(self, pil_image):
623
  if isinstance(pil_image, Image.Image):
624
  pil_image = [pil_image]
625
  clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
626
  clip_image = clip_image.to(self.device, dtype=self.dtype)
627
-
628
  clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
629
  image_prompt_embeds = self.image_proj_model(clip_image_embeds)
630
-
631
  uncond_clip_image_embeds = self.image_encoder(
632
  torch.zeros_like(clip_image), output_hidden_states=True
633
  ).hidden_states[-2]
634
  uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
635
-
636
  return image_prompt_embeds, uncond_image_prompt_embeds
637
-
638
  def generate(
639
  self,
640
  pil_image,
@@ -646,26 +530,21 @@ class AnomagicPlusXL(Anomagic):
646
  num_inference_steps=30, **kwargs,
647
  ):
648
  self.set_scale(scale)
649
-
650
  num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
651
-
652
  if prompt is None:
653
  prompt = "best quality, high quality"
654
  if negative_prompt is None:
655
  negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
656
-
657
  if not isinstance(prompt, List):
658
  prompt = [prompt] * num_prompts
659
  if not isinstance(negative_prompt, List):
660
  negative_prompt = [negative_prompt] * num_prompts
661
-
662
  image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
663
  bs_embed, seq_len, _ = image_prompt_embeds.shape
664
  image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
665
  image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
666
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
667
  uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
668
-
669
  with torch.inference_mode():
670
  (
671
  prompt_embeds,
@@ -680,9 +559,7 @@ class AnomagicPlusXL(Anomagic):
680
  )
681
  prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
682
  negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
683
-
684
  generator = get_generator(seed, self.device)
685
-
686
  images = self.pipe(
687
  prompt_embeds=prompt_embeds,
688
  negative_prompt_embeds=negative_prompt_embeds,
@@ -691,5 +568,4 @@ class AnomagicPlusXL(Anomagic):
691
  num_inference_steps=num_inference_steps,
692
  generator=generator, **kwargs,
693
  ).images
694
-
695
  return images
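
For context, here is a minimal usage sketch of the Anomagic class uploaded in this commit. It is not part of the diff itself: the pipeline ID, checkpoint URLs, and image file names below are placeholder assumptions, and the argument names follow the signatures visible in the diff above.

# Minimal usage sketch (assumptions: the repo is importable as the ip_adapter
# package, a Stable Diffusion 1.x pipeline is used, and the checkpoint URLs
# and image paths are placeholders).
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from transformers import CLIPVisionModelWithProjection
from ip_adapter.ip_adapter_anomagic import Anomagic

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
image_encoder = CLIPVisionModelWithProjection.from_pretrained("yuxinjiang11/image_encoder")

anomagic = Anomagic(
    sd_pipe=pipe,
    image_encoder=image_encoder,
    ip_ckpt_url="https://example.com/ip_adapter.bin",   # placeholder URL
    att_ckpt_url="https://example.com/attention.bin",   # placeholder URL
    device=device,
)

reference = Image.open("reference.png")   # hypothetical reference image
mask = Image.open("mask.png")             # hypothetical anomaly mask
images = anomagic.generate(
    pil_image=reference,
    mask_image_0=mask,
    prompt="best quality, high quality",
    scale=1.0,
    num_samples=1,
    seed=42,
    num_inference_steps=30,
)
images[0].save("result.png")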
 