import torch
from PIL import Image
from safetensors.torch import load_file
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from nested_attention_processor import AttnProcessor, NestedAttnProcessor
from resampler import Resampler
from utils import get_generator


def add_special_token_to_tokenizer(pipe, placeholder_token, initializer_token):
    """Register `placeholder_token` in both SDXL tokenizers and initialize its
    embedding in both text encoders from `initializer_token`'s embedding."""
    num_added_tokens1 = pipe.tokenizer.add_tokens([placeholder_token])
    num_added_tokens2 = pipe.tokenizer_2.add_tokens([placeholder_token])
    if num_added_tokens1 != 1 or num_added_tokens2 != 1:
        raise ValueError("Failed to add placeholder token to tokenizer")

    # The initializer must map to a single token in both tokenizers.
    token_ids1 = pipe.tokenizer.encode(initializer_token, add_special_tokens=False)
    token_ids2 = pipe.tokenizer_2.encode(initializer_token, add_special_tokens=False)
    if len(token_ids1) > 1 or len(token_ids2) > 1:
        raise ValueError("The initializer token must be a single token.")
    initializer_token_id1 = token_ids1[0]
    initializer_token_id2 = token_ids2[0]
    placeholder_token_ids1 = pipe.tokenizer.convert_tokens_to_ids([placeholder_token])
    placeholder_token_ids2 = pipe.tokenizer_2.convert_tokens_to_ids([placeholder_token])

    # Grow the embedding matrices to cover the new token, then copy the
    # initializer embedding into the placeholder slot.
    pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
    pipe.text_encoder_2.resize_token_embeddings(len(pipe.tokenizer_2))
    token_embeds1 = pipe.text_encoder.get_input_embeddings().weight.data
    token_embeds2 = pipe.text_encoder_2.get_input_embeddings().weight.data
    with torch.no_grad():
        for token_id in placeholder_token_ids1:
            token_embeds1[token_id] = token_embeds1[initializer_token_id1].clone()
        for token_id in placeholder_token_ids2:
            token_embeds2[token_id] = token_embeds2[initializer_token_id2].clone()
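
# Example usage (a sketch; assumes an SDXL pipeline `pipe` with two
# tokenizers/text encoders is already loaded, and "<v*>" / "person" are
# illustrative choices, not values fixed by this module). The resulting ids
# are what `NestedAdapterInference.generate` expects as `placeholder_token_ids`:
#
#   add_special_token_to_tokenizer(pipe, placeholder_token="<v*>", initializer_token="person")
#   placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<v*>"])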


class NestedAdapterInference:
    def __init__(
        self,
        sd_pipe,
        image_encoder_path,
        adapter_ckpt,
        resampler_num_queries,
        vq_normalize_factor,
        device,
    ):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.adapter_ckpt = adapter_ckpt
        self.vq_normalize_factor = vq_normalize_factor

        self.pipe = sd_pipe.to(self.device)
        self.set_nested_adapter()

        # Load the CLIP image encoder used to extract spatial image features.
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(
            self.image_encoder_path, use_safetensors=True
        ).to(self.device, dtype=torch.float16)
        self.clip_image_processor = CLIPImageProcessor()

        # Perceiver-style resampler ("q-former") that maps CLIP patch features
        # to tokens in the UNet's cross-attention dimension.
        self.qformer = Resampler(
            dim=self.pipe.unet.config.cross_attention_dim,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=resampler_num_queries,
            embedding_dim=self.image_encoder.config.hidden_size,
            output_dim=self.pipe.unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self.device, dtype=torch.float16)

        if adapter_ckpt is not None:
            self.load_nested_adapter()

    def set_nested_adapter(self):
        # Replace every UNet attention processor: self-attention ("attn1")
        # keeps the default processor; cross-attention gets the nested one.
        unet = self.pipe.unet
        attn_procs = {}
        for name in unet.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else unet.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                attn_procs[name] = NestedAttnProcessor(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    normalize_factor=self.vq_normalize_factor,
                ).to(self.device, dtype=torch.float16)
        unet.set_attn_processor(attn_procs)

    def load_nested_adapter(self):
        # Split the flat safetensors checkpoint into the two sub-module
        # state dicts, then load each into its module.
        state_dict = {"adapter_modules": {}, "qformer": {}}
        f = load_file(self.adapter_ckpt)
        for key in f.keys():
            if key.startswith("adapter_modules."):
                state_dict["adapter_modules"][key.replace("adapter_modules.", "")] = f[key]
            elif key.startswith("spatial_features_model."):
                state_dict["qformer"][key.replace("spatial_features_model.", "")] = f[key]
        self.qformer.load_state_dict(state_dict["qformer"])
        adapter_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
        adapter_layers.load_state_dict(state_dict["adapter_modules"])
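
    # Expected checkpoint layout (inferred from the prefixes handled above; the
    # exact key names depend on how the adapter was trained and saved):
    #   adapter_modules.*        -> NestedAttnProcessor weights, in the order
    #                               returned by unet.attn_processors.values()
    #   spatial_features_model.* -> Resampler ("qformer") weights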

    @torch.inference_mode()
    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if pil_image is not None:
            if isinstance(pil_image, Image.Image):
                pil_image = [pil_image]
            clip_image = self.clip_image_processor(
                images=pil_image, return_tensors="pt"
            ).pixel_values
            clip_image_embeds = self.image_encoder(
                clip_image.to(self.device, dtype=torch.float16)
            )
            # Keep the per-patch features and drop the leading CLS token.
            spatial_clip_image_embeds = clip_image_embeds.last_hidden_state[:, 1:]
        else:
            # Assume precomputed spatial (patch-level) CLIP features were passed in.
            spatial_clip_image_embeds = clip_image_embeds.to(
                self.device, dtype=torch.float16
            )
        return spatial_clip_image_embeds

    def generate(
        self,
        pil_image=None,
        clip_image_embeds=None,
        prompt=None,
        placeholder_token_ids=None,
        negative_prompt=None,
        scale=1.0,  # NOTE: currently unused
        num_samples=4,
        seed=None,
        guidance_scale=5.0,
        num_inference_steps=30,
        multiple_images=False,
        special_token_weight=1.0,
        **kwargs,
    ):
        if pil_image is not None:
            num_prompts = (
                1
                if isinstance(pil_image, Image.Image) or multiple_images
                else len(pil_image)
            )
        else:
            num_prompts = clip_image_embeds.size(0)

        if prompt is None:
            prompt = "best quality, high quality"
        if negative_prompt is None:
            negative_prompt = (
                "monochrome, lowres, bad anatomy, worst quality, low quality"
            )

        if not isinstance(prompt, list):
            prompt = [prompt] * num_prompts
        if not isinstance(negative_prompt, list):
            negative_prompt = [negative_prompt] * num_prompts

        # Locate the placeholder token position(s) in the tokenized prompt so
        # the nested attention processors know where to inject subject tokens.
        text_input_ids = self.pipe.tokenizer(
            prompt,
            max_length=self.pipe.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids
        special_token_indices = (text_input_ids == placeholder_token_ids[0]).nonzero()[:, 1]

        spatial_clip_image_embeds = self.get_image_embeds(
            pil_image=pil_image, clip_image_embeds=clip_image_embeds
        )

        with torch.no_grad():
            (
                prompt_embeds,
                negative_prompt_embeds,
                pooled_prompt_embeds,
                negative_pooled_prompt_embeds,
            ) = self.pipe.encode_prompt(
                prompt,
                num_images_per_prompt=num_samples,
                do_classifier_free_guidance=True,
                negative_prompt=negative_prompt,
            )

        with torch.no_grad():
            qformer_tokens_out = self.qformer(spatial_clip_image_embeds)

        if multiple_images:
            # Fold several reference images into a single, longer token sequence.
            b, num_tokens, d = qformer_tokens_out.shape
            qformer_tokens_out = qformer_tokens_out.reshape(1, num_tokens * b, d)

        bs_embed, num_tokens, _ = qformer_tokens_out.shape

        # Tile the subject tokens once per generated sample, mirroring how
        # encode_prompt repeats the text embeddings per prompt.
        qformer_tokens_out = qformer_tokens_out.repeat(1, num_samples, 1)
        qformer_tokens_out = qformer_tokens_out.view(
            bs_embed * num_samples, num_tokens, -1
        )
        # Duplicate for classifier-free guidance: the pipeline concatenates the
        # unconditional and conditional batches along dim 0.
        qformer_tokens_out = qformer_tokens_out.repeat(2, 1, 1)

        cross_attention_kwargs = {
            "qformer_tokens_out": qformer_tokens_out,
            "special_token_indices": special_token_indices,
            "special_token_weight": special_token_weight,
            "inference_mode": True,
        }
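
        # These kwargs are forwarded by the diffusers pipeline to every attention
        # processor; NestedAttnProcessor consumes them, while the self-attention
        # AttnProcessor (from nested_attention_processor) presumably ignores them.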

        generator = get_generator(seed, self.device)

        images = self.pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            cross_attention_kwargs=cross_attention_kwargs,
            **kwargs,
        ).images

        return images
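

# Minimal end-to-end usage sketch. Assumptions: the SDXL base checkpoint and
# fp16 dtype are illustrative; "image_encoder/", "adapter.safetensors",
# "subject.jpg", num_queries=16, and vq_normalize_factor=1.0 are hypothetical
# paths/values, not ones fixed by this module:
#
#   from diffusers import StableDiffusionXLPipeline
#
#   pipe = StableDiffusionXLPipeline.from_pretrained(
#       "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
#   )
#   add_special_token_to_tokenizer(pipe, "<v*>", "person")
#   placeholder_token_ids = pipe.tokenizer.convert_tokens_to_ids(["<v*>"])
#
#   adapter = NestedAdapterInference(
#       pipe,
#       image_encoder_path="image_encoder/",
#       adapter_ckpt="adapter.safetensors",
#       resampler_num_queries=16,
#       vq_normalize_factor=1.0,
#       device="cuda",
#   )
#   images = adapter.generate(
#       pil_image=Image.open("subject.jpg"),
#       prompt="a photo of <v*> person at the beach",
#       placeholder_token_ids=placeholder_token_ids,
#       num_samples=2,
#       seed=42,
#   )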