#credit to huchenlei for this module
#from https://github.com/huchenlei/ComfyUI-IC-Light-Native
import torch
import numpy as np
from typing import Tuple, TypedDict, Callable

import comfy.model_management
from comfy.sd import load_unet
from comfy.ldm.models.autoencoder import AutoencoderKL
from comfy.model_base import BaseModel
from comfy.model_patcher import ModelPatcher
from PIL import Image
from nodes import VAEEncode

from ..libs.image import np2tensor, pil2tensor


class UnetParams(TypedDict):
    input: torch.Tensor
    timestep: torch.Tensor
    c: dict
    cond_or_uncond: torch.Tensor


class VAEEncodeArgMax(VAEEncode):
    def encode(self, vae, pixels):
        """Encode with deterministic (mode) latents instead of stochastic sampling."""
        assert isinstance(
            vae.first_stage_model, AutoencoderKL
        ), "ArgMax only supported for AutoencoderKL"
        original_sample_mode = vae.first_stage_model.regularization.sample
        vae.first_stage_model.regularization.sample = False
        ret = super().encode(vae, pixels)
        vae.first_stage_model.regularization.sample = original_sample_mode
        return ret


class ICLight:
    @staticmethod
    def apply_c_concat(params: UnetParams, concat_conds) -> UnetParams:
        """Apply c_concat on unet call."""
        sample = params["input"]
        # Repeat the concat conditioning to match the sample's batch size.
        params["c"]["c_concat"] = torch.cat(
            [concat_conds.to(sample.device)]
            * (sample.shape[0] // concat_conds.shape[0]),
            dim=0,
        )
        return params

    @staticmethod
    def create_custom_conv(
        original_conv: torch.nn.Module,
        dtype: torch.dtype,
        device: torch.device,
    ) -> torch.nn.Module:
        with torch.no_grad():
            # Widen conv_in from 4 to 8 input channels; the extra (concat)
            # channels start zero-initialized, so behavior is unchanged until
            # the IC-Light weights are patched in.
            new_conv_in = torch.nn.Conv2d(
                8,
                original_conv.out_channels,
                original_conv.kernel_size,
                original_conv.stride,
                original_conv.padding,
            )
            new_conv_in.weight.zero_()
            new_conv_in.weight[:, :4, :, :].copy_(original_conv.weight)
            new_conv_in.bias = original_conv.bias
            return new_conv_in.to(dtype=dtype, device=device)

    def generate_lighting_image(self, original_image, direction):
        _, image_height, image_width, _ = original_image.shape
        match direction:
            case 'Left Light':
                gradient = np.linspace(255, 0, image_width)
                image = np.tile(gradient, (image_height, 1))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Right Light':
                gradient = np.linspace(0, 255, image_width)
                image = np.tile(gradient, (image_height, 1))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Top Light':
                gradient = np.linspace(255, 0, image_height)[:, None]
                image = np.tile(gradient, (1, image_width))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Bottom Light':
                gradient = np.linspace(0, 255, image_height)[:, None]
                image = np.tile(gradient, (1, image_width))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Circle Light':
                # Radial gradient: white at the center fading to black at the edges.
                x = np.linspace(-1, 1, image_width)
                y = np.linspace(-1, 1, image_height)
                x, y = np.meshgrid(x, y)
                r = np.sqrt(x ** 2 + y ** 2)
                r = r / r.max()
                color1 = np.array([0, 0, 0])[np.newaxis, np.newaxis, :]
                color2 = np.array([255, 255, 255])[np.newaxis, np.newaxis, :]
                gradient = (
                    color1 * r[..., np.newaxis] + color2 * (1 - r)[..., np.newaxis]
                ).astype(np.uint8)
                return pil2tensor(Image.fromarray(gradient))
            case _:
                return pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))
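    # Illustrative note (not in the original source): each directional case above
    # builds a grayscale linear ramp and broadcasts it to three identical RGB
    # channels. For example, 'Left Light' on a 4-pixel-wide image uses
    # np.linspace(255, 0, 4) == [255., 170., 85., 0.], tiled over every row, so
    # the background fades from white at the left edge to black at the right.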
    def generate_source_image(self, original_image, source):
        batch_size, image_height, image_width, _ = original_image.shape
        match source:
            case 'Use Flipped Background Image':
                if batch_size < 2:
                    raise ValueError('Must have at least 2 images to use a flipped background image.')
                # Flip the second image in the batch horizontally (width axis).
                original_image = [img.unsqueeze(0) for img in original_image]
                return torch.flip(original_image[1], [2])
            case 'Ambient':
                input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
                return np2tensor(input_bg)
            case 'Left Light':
                gradient = np.linspace(224, 32, image_width)
                image = np.tile(gradient, (image_height, 1))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Right Light':
                gradient = np.linspace(32, 224, image_width)
                image = np.tile(gradient, (image_height, 1))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Top Light':
                gradient = np.linspace(224, 32, image_height)[:, None]
                image = np.tile(gradient, (1, image_width))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case 'Bottom Light':
                gradient = np.linspace(32, 224, image_height)[:, None]
                image = np.tile(gradient, (1, image_width))
                input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
                return np2tensor(input_bg)
            case _:
                return pil2tensor(Image.new('RGB', (1, 1), (0, 0, 0)))

    def apply(
        self, ic_model_path, model, c_concat: dict, ic_model=None
    ) -> Tuple[ModelPatcher, ModelPatcher]:
        device = comfy.model_management.get_torch_device()
        dtype = comfy.model_management.unet_dtype()
        work_model = model.clone()

        # Apply scale factor.
        base_model: BaseModel = work_model.model
        scale_factor = base_model.model_config.latent_format.scale_factor

        # [B, 4, H, W]
        concat_conds: torch.Tensor = c_concat["samples"] * scale_factor
        # [1, 4 * B, H, W]
        concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)

        def unet_dummy_apply(unet_apply: Callable, params: UnetParams):
            """A dummy unet apply wrapper serving as the endpoint of the wrapper chain."""
            return unet_apply(x=params["input"], t=params["timestep"], **params["c"])

        # Compose on top of any existing model function wrapper.
        existing_wrapper = work_model.model_options.get(
            "model_function_wrapper", unet_dummy_apply
        )

        def wrapper_func(unet_apply: Callable, params: UnetParams):
            return existing_wrapper(
                unet_apply, params=self.apply_c_concat(params, concat_conds)
            )

        work_model.set_model_unet_function_wrapper(wrapper_func)

        if not ic_model:
            ic_model = load_unet(ic_model_path)

        # Patch IC-Light weights onto the work model as weight diffs; conv_in
        # ('input_blocks.0.0.weight') must be zero-padded from 4 to 8 input
        # channels first, hence pad_weight.
        ic_model_state_dict = ic_model.model.diffusion_model.state_dict()
        work_model.add_patches(
            patches={
                ("diffusion_model." + key): (
                    'diff',
                    [
                        value.to(dtype=dtype, device=device),
                        {"pad_weight": key == 'input_blocks.0.0.weight'},
                    ],
                )
                for key, value in ic_model_state_dict.items()
            }
        )

        return (work_model, ic_model)
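
# A minimal standalone sketch (not part of the original module) of what
# ICLight.create_custom_conv produces: IC-Light's UNet takes 8 input channels
# (4 noise-latent + 4 concat-conditioning), so the pretrained 4-channel conv_in
# is widened with zero-initialized weights on the new channels. Feeding zeros
# on the concat channels should then reproduce the original conv's output. The
# 320 output channels and 64x64 latent size below are illustrative assumptions,
# not values read from a real model.
def _demo_create_custom_conv():
    original_conv = torch.nn.Conv2d(4, 320, kernel_size=3, padding=1)
    expanded_conv = ICLight.create_custom_conv(
        original_conv, dtype=torch.float32, device=torch.device("cpu")
    )
    latent = torch.randn(1, 4, 64, 64)
    # Zero-pad the input from 4 to 8 channels to match the widened conv.
    padded = torch.cat([latent, torch.zeros(1, 4, 64, 64)], dim=1)
    # The zero-initialized extra channels contribute nothing to the output.
    assert torch.allclose(original_conv(latent), expanded_conv(padded), atol=1e-5)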