Spaces:
Configuration error
Configuration error
| from torch import Tensor | |
| import folder_paths | |
| from nodes import VAEEncode | |
| import comfy.utils | |
| from comfy.sd import VAE | |
| from .utils import TimestepKeyframeGroup | |
| from .control_sparsectrl import SparseMethod, SparseIndexMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper | |
| from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced | |
| # node for SparseCtrl loading | |
class SparseCtrlLoaderAdvanced:
    """Node that loads a SparseCtrl model chosen from the "controlnet" folder.

    The node exposes motion-related settings plus an optional sparse method
    and timestep keyframe, bundles them into a ``SparseSettings`` object and
    hands everything to ``load_sparsectrl``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself,
        without an instance.
        """
        return {
            "required": {
                "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
                "use_motion": ("BOOLEAN", {"default": True}, ),
                "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            },
            "optional": {
                "sparse_method": ("SPARSE_METHOD", ),
                "tk_optional": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    CATEGORY = "Adv-ControlNet ππ π π /SparseCtrl"

    def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float,
                        sparse_method: SparseMethod=None, tk_optional: TimestepKeyframeGroup=None):
        """Load the selected SparseCtrl model.

        Args:
            sparsectrl_name: File name of the model inside the "controlnet" folder.
            use_motion: Forwarded to ``SparseSettings``.
            motion_strength: Forwarded to ``SparseSettings``.
            motion_scale: Forwarded to ``SparseSettings``.
            sparse_method: Sparse method to use; when omitted (``None``) a new
                ``SparseSpreadMethod()`` is created for this call.
            tk_optional: Optional timestep keyframe group passed to the loader.

        Returns:
            A one-element tuple holding the loaded SparseCtrl object.
        """
        # Build the default per call instead of in the signature, so that a
        # single default instance is never shared between executions.
        if sparse_method is None:
            sparse_method = SparseSpreadMethod()
        sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
        sparse_settings = SparseSettings(
            sparse_method=sparse_method,
            use_motion=use_motion,
            motion_strength=motion_strength,
            motion_scale=motion_scale,
        )
        sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
        return (sparsectrl,)
class SparseCtrlMergedLoaderAdvanced:
    """Experimental node that merges a regular ControlNet with SparseCtrl motion weights.

    The regular ControlNet's state dict is combined with the state dict of
    the SparseCtrl model's motion wrapper, and the result is loaded again as
    a SparseCtrl model with ``merged=True`` settings.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself,
        without an instance.
        """
        return {
            "required": {
                "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
                "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
                "use_motion": ("BOOLEAN", {"default": True}, ),
                "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
                "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
            },
            "optional": {
                "sparse_method": ("SPARSE_METHOD", ),
                "tk_optional": ("TIMESTEP_KEYFRAME", ),
            }
        }

    RETURN_TYPES = ("CONTROL_NET", )
    FUNCTION = "load_controlnet"

    CATEGORY = "Adv-ControlNet ππ π π /SparseCtrl/experimental"

    def load_controlnet(self, sparsectrl_name: str, control_net_name: str, use_motion: bool, motion_strength: float, motion_scale: float,
                        sparse_method: SparseMethod=None, tk_optional: TimestepKeyframeGroup=None):
        """Load a ControlNet and a SparseCtrl model and merge them.

        Args:
            sparsectrl_name: File name of the SparseCtrl model in the "controlnet" folder.
            control_net_name: File name of the regular ControlNet in the "controlnet" folder.
            use_motion: Forwarded to ``SparseSettings``.
            motion_strength: Forwarded to ``SparseSettings``.
            motion_scale: Forwarded to ``SparseSettings``.
            sparse_method: Sparse method to use; when omitted (``None``) a new
                ``SparseSpreadMethod()`` is created for this call.
            tk_optional: Optional timestep keyframe group passed to the loaders.

        Returns:
            A one-element tuple holding the merged SparseCtrl object.

        Raises:
            ValueError: If ``control_net_name`` does not load as exactly a
                ``ControlNetAdvanced``.
        """
        # Build the default per call instead of in the signature, so that a
        # single default instance is never shared between executions.
        if sparse_method is None:
            sparse_method = SparseSpreadMethod()
        sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
        controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
        sparse_settings = SparseSettings(
            sparse_method=sparse_method,
            use_motion=use_motion,
            motion_strength=motion_strength,
            motion_scale=motion_scale,
            merged=True,
        )
        # first, load normal controlnet
        # (the bare name resolves to the module-level load_controlnet import, not this method)
        controlnet = load_controlnet(controlnet_path, timestep_keyframe=tk_optional)
        # confirm that controlnet is ControlNetAdvanced
        # NOTE(review): the exact-type comparison (rather than isinstance) is kept on
        # purpose; presumably it rejects subclasses of ControlNetAdvanced - confirm.
        if controlnet is None or type(controlnet) is not ControlNetAdvanced:
            raise ValueError(f"controlnet_path must point to a normal ControlNet, but instead: {type(controlnet).__name__}")
        # next, load sparsectrl, making sure to load motion portion
        sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=SparseSettings.default())
        # now, combine state dicts; motion entries overwrite same-named controlnet entries
        new_state_dict = controlnet.control_model.state_dict()
        new_state_dict.update(sparsectrl.control_model.motion_holder.motion_wrapper.state_dict())
        # now, reload sparsectrl with real settings
        sparsectrl = load_sparsectrl(sparsectrl_path, controlnet_data=new_state_dict, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
        return (sparsectrl,)
class SparseIndexMethodNode:
    """Node that builds a ``SparseIndexMethod`` from a comma-separated index string."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself,
        without an instance.
        """
        return {
            "required": {
                "indexes": ("STRING", {"default": "0"}),
            }
        }

    RETURN_TYPES = ("SPARSE_METHOD",)
    FUNCTION = "get_method"

    CATEGORY = "Adv-ControlNet ππ π π /SparseCtrl"

    def get_method(self, indexes: str):
        """Parse ``indexes`` and wrap the result in a ``SparseIndexMethod``.

        Args:
            indexes: Comma-separated integers, e.g. ``"0, 4, 8"``. Whitespace
                around each entry is ignored.

        Returns:
            A one-element tuple holding ``SparseIndexMethod(idxs)``, where
            ``idxs`` preserves the order given in the string.

        Raises:
            ValueError: If the string is blank, an entry is not an integer,
                or an index appears more than once.
        """
        cleaned = indexes.strip()
        # str.split always yields at least one element, so blank input has to
        # be detected up front to produce the "no indexes" message.
        if not cleaned:
            raise ValueError("No indexes were listed in Sparse Index Method.")
        idxs = []
        unique_idxs = set()
        # get indices from string
        for str_idx in (part.strip() for part in cleaned.split(",")):
            # Only the int() conversion is guarded, so that the duplicate error
            # raised below is not mistaken for a parse failure.
            try:
                idx = int(str_idx)
            except ValueError:
                raise ValueError(f"'{str_idx}' is not a valid integer index.") from None
            if idx in unique_idxs:
                raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
            idxs.append(idx)
            unique_idxs.add(idx)
        return (SparseIndexMethod(idxs),)
class SparseSpreadMethodNode:
    """Node that builds a ``SparseSpreadMethod`` from one of the options in ``SparseSpreadMethod.LIST``."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself,
        without an instance.
        """
        return {
            "required": {
                "spread": (SparseSpreadMethod.LIST,),
            }
        }

    RETURN_TYPES = ("SPARSE_METHOD",)
    FUNCTION = "get_method"

    CATEGORY = "Adv-ControlNet ππ π π /SparseCtrl"

    def get_method(self, spread: str):
        """Create the sparse method for the chosen spread option.

        Args:
            spread: The selected entry of ``SparseSpreadMethod.LIST``.

        Returns:
            A one-element tuple holding ``SparseSpreadMethod(spread=spread)``.
        """
        return (SparseSpreadMethod(spread=spread),)
class RgbSparseCtrlPreprocessor:
    """Node that resizes images to a latent's size and VAE-encodes them for RGB SparseCtrl.

    The output is declared as "IMAGE", but the returned value is a
    ``PreprocSparseRGBWrapper`` carrying the encoded result as ``condhint``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification for this node.

        Declared as a classmethod so it can be called on the class itself,
        without an instance.
        """
        return {
            "required": {
                "image": ("IMAGE", ),
                "vae": ("VAE", ),
                "latent_size": ("LATENT", ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("proc_IMAGE",)
    FUNCTION = "preprocess_images"

    CATEGORY = "Adv-ControlNet ππ π π /SparseCtrl/preprocess"

    def preprocess_images(self, vae: VAE, image: Tensor, latent_size: dict):
        """Resize ``image`` to match ``latent_size`` and encode it with ``vae``.

        Args:
            vae: VAE used for encoding.
            image: Image tensor; the last dimension is moved to position 1
                for resizing and back afterwards.
            latent_size: Mapping whose ``"samples"`` entry provides the
                reference shape; only dimensions 2 and 3 of that shape are read.

        Returns:
            A one-element tuple holding a ``PreprocSparseRGBWrapper`` whose
            ``condhint`` is the VAE encoding of the resized image.
        """
        # first, resize image to match latents
        # NOTE(review): assumes samples are laid out (batch, channels, height, width)
        # and that the VAE downscales by a factor of 8 - confirm.
        latent_shape = latent_size["samples"].shape
        target_width = latent_shape[3] * 8
        target_height = latent_shape[2] * 8
        image = image.movedim(-1, 1)
        image = comfy.utils.common_upscale(image, target_width, target_height, 'nearest-exact', "center")
        image = image.movedim(1, -1)
        # then, vae encode
        # Best-effort: any failure of the VAE's own crop helper falls back to the
        # one on VAEEncode (presumably for ComfyUI versions where the helper
        # lives there instead - confirm).
        try:
            image = vae.vae_encode_crop_pixels(image)
        except Exception:
            image = VAEEncode.vae_encode_crop_pixels(image)
        # only the first three channels of the last dimension are encoded
        encoded = vae.encode(image[:, :, :, :3])
        return (PreprocSparseRGBWrapper(condhint=encoded),)