Spaces:
Configuration error
Configuration error
Upload 578 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +8 -0
- ComfyUI-Advanced-ControlNet/LICENSE +674 -0
- ComfyUI-Advanced-ControlNet/README.md +202 -0
- ComfyUI-Advanced-ControlNet/__init__.py +3 -0
- ComfyUI-Advanced-ControlNet/adv_control/control.py +860 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py +254 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_reference.py +833 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py +949 -0
- ComfyUI-Advanced-ControlNet/adv_control/control_svd.py +517 -0
- ComfyUI-Advanced-ControlNet/adv_control/logger.py +36 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes.py +235 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py +71 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py +461 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py +67 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py +90 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py +163 -0
- ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py +224 -0
- ComfyUI-Advanced-ControlNet/adv_control/utils.py +915 -0
- ComfyUI-Advanced-ControlNet/pyproject.toml +15 -0
- ComfyUI-Advanced-ControlNet/requirements.txt +0 -0
- ComfyUI-BrushNet/BIG_IMAGE.md +6 -0
- ComfyUI-BrushNet/CN.md +39 -0
- ComfyUI-BrushNet/LICENSE +201 -0
- ComfyUI-BrushNet/PARAMS.md +47 -0
- ComfyUI-BrushNet/RAUNET.md +39 -0
- ComfyUI-BrushNet/README.md +261 -0
- ComfyUI-BrushNet/__init__.py +32 -0
- ComfyUI-BrushNet/brushnet/brushnet.json +58 -0
- ComfyUI-BrushNet/brushnet/brushnet.py +948 -0
- ComfyUI-BrushNet/brushnet/brushnet_ca.py +956 -0
- ComfyUI-BrushNet/brushnet/brushnet_xl.json +63 -0
- ComfyUI-BrushNet/brushnet/powerpaint.json +57 -0
- ComfyUI-BrushNet/brushnet/powerpaint_utils.py +496 -0
- ComfyUI-BrushNet/brushnet/unet_2d_blocks.py +0 -0
- ComfyUI-BrushNet/brushnet/unet_2d_condition.py +1355 -0
- ComfyUI-BrushNet/brushnet_nodes.py +1080 -0
- ComfyUI-BrushNet/model_patch.py +134 -0
- ComfyUI-BrushNet/raunet_nodes.py +158 -0
- ComfyUI-BrushNet/requirements.txt +3 -0
- ComfyUI-Easy-Use/LICENSE +674 -0
- ComfyUI-Easy-Use/README.ZH_CN.md +459 -0
- ComfyUI-Easy-Use/README.en.md +422 -0
- ComfyUI-Easy-Use/README.md +448 -0
- ComfyUI-Easy-Use/__init__.py +92 -0
- ComfyUI-Easy-Use/config.yaml +5 -0
- ComfyUI-Easy-Use/install.bat +18 -0
- ComfyUI-Easy-Use/prestartup_script.py +37 -0
- ComfyUI-Easy-Use/py/__init__.py +0 -0
- ComfyUI-Easy-Use/py/api.py +293 -0
- ComfyUI-Easy-Use/py/bitsandbytes_NF4/__init__.py +167 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,11 @@ comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
|
| 39 |
custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
|
| 39 |
custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
ComfyUI-Easy-Use/py/kolors/chatglm/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
ComfyUI-Easy-Use/resources/OpenSans-Medium.ttf filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
ComfyUI-KJNodes/docs/images/2024-04-03_20_49_29-ComfyUI.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
ComfyUI-KJNodes/fonts/FreeMono.ttf filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
ComfyUI-KJNodes/fonts/FreeMonoBoldOblique.otf filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
ComfyUI-KJNodes/fonts/TTNorms-Black.otf filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
ComfyUI-Kolors-MZ/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
ComfyUI-KwaiKolorsWrapper/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
|
ComfyUI-Advanced-ControlNet/LICENSE
ADDED
|
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GNU GENERAL PUBLIC LICENSE
|
| 2 |
+
Version 3, 29 June 2007
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
| 5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
| 6 |
+
of this license document, but changing it is not allowed.
|
| 7 |
+
|
| 8 |
+
Preamble
|
| 9 |
+
|
| 10 |
+
The GNU General Public License is a free, copyleft license for
|
| 11 |
+
software and other kinds of works.
|
| 12 |
+
|
| 13 |
+
The licenses for most software and other practical works are designed
|
| 14 |
+
to take away your freedom to share and change the works. By contrast,
|
| 15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
| 16 |
+
share and change all versions of a program--to make sure it remains free
|
| 17 |
+
software for all its users. We, the Free Software Foundation, use the
|
| 18 |
+
GNU General Public License for most of our software; it applies also to
|
| 19 |
+
any other work released this way by its authors. You can apply it to
|
| 20 |
+
your programs, too.
|
| 21 |
+
|
| 22 |
+
When we speak of free software, we are referring to freedom, not
|
| 23 |
+
price. Our General Public Licenses are designed to make sure that you
|
| 24 |
+
have the freedom to distribute copies of free software (and charge for
|
| 25 |
+
them if you wish), that you receive source code or can get it if you
|
| 26 |
+
want it, that you can change the software or use pieces of it in new
|
| 27 |
+
free programs, and that you know you can do these things.
|
| 28 |
+
|
| 29 |
+
To protect your rights, we need to prevent others from denying you
|
| 30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
| 31 |
+
certain responsibilities if you distribute copies of the software, or if
|
| 32 |
+
you modify it: responsibilities to respect the freedom of others.
|
| 33 |
+
|
| 34 |
+
For example, if you distribute copies of such a program, whether
|
| 35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
| 36 |
+
freedoms that you received. You must make sure that they, too, receive
|
| 37 |
+
or can get the source code. And you must show them these terms so they
|
| 38 |
+
know their rights.
|
| 39 |
+
|
| 40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
| 41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
| 42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
| 43 |
+
|
| 44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
| 45 |
+
that there is no warranty for this free software. For both users' and
|
| 46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
| 47 |
+
changed, so that their problems will not be attributed erroneously to
|
| 48 |
+
authors of previous versions.
|
| 49 |
+
|
| 50 |
+
Some devices are designed to deny users access to install or run
|
| 51 |
+
modified versions of the software inside them, although the manufacturer
|
| 52 |
+
can do so. This is fundamentally incompatible with the aim of
|
| 53 |
+
protecting users' freedom to change the software. The systematic
|
| 54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
| 55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
| 56 |
+
have designed this version of the GPL to prohibit the practice for those
|
| 57 |
+
products. If such problems arise substantially in other domains, we
|
| 58 |
+
stand ready to extend this provision to those domains in future versions
|
| 59 |
+
of the GPL, as needed to protect the freedom of users.
|
| 60 |
+
|
| 61 |
+
Finally, every program is threatened constantly by software patents.
|
| 62 |
+
States should not allow patents to restrict development and use of
|
| 63 |
+
software on general-purpose computers, but in those that do, we wish to
|
| 64 |
+
avoid the special danger that patents applied to a free program could
|
| 65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
| 66 |
+
patents cannot be used to render the program non-free.
|
| 67 |
+
|
| 68 |
+
The precise terms and conditions for copying, distribution and
|
| 69 |
+
modification follow.
|
| 70 |
+
|
| 71 |
+
TERMS AND CONDITIONS
|
| 72 |
+
|
| 73 |
+
0. Definitions.
|
| 74 |
+
|
| 75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
| 76 |
+
|
| 77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
| 78 |
+
works, such as semiconductor masks.
|
| 79 |
+
|
| 80 |
+
"The Program" refers to any copyrightable work licensed under this
|
| 81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
| 82 |
+
"recipients" may be individuals or organizations.
|
| 83 |
+
|
| 84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
| 85 |
+
in a fashion requiring copyright permission, other than the making of an
|
| 86 |
+
exact copy. The resulting work is called a "modified version" of the
|
| 87 |
+
earlier work or a work "based on" the earlier work.
|
| 88 |
+
|
| 89 |
+
A "covered work" means either the unmodified Program or a work based
|
| 90 |
+
on the Program.
|
| 91 |
+
|
| 92 |
+
To "propagate" a work means to do anything with it that, without
|
| 93 |
+
permission, would make you directly or secondarily liable for
|
| 94 |
+
infringement under applicable copyright law, except executing it on a
|
| 95 |
+
computer or modifying a private copy. Propagation includes copying,
|
| 96 |
+
distribution (with or without modification), making available to the
|
| 97 |
+
public, and in some countries other activities as well.
|
| 98 |
+
|
| 99 |
+
To "convey" a work means any kind of propagation that enables other
|
| 100 |
+
parties to make or receive copies. Mere interaction with a user through
|
| 101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
| 102 |
+
|
| 103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
| 104 |
+
to the extent that it includes a convenient and prominently visible
|
| 105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
| 106 |
+
tells the user that there is no warranty for the work (except to the
|
| 107 |
+
extent that warranties are provided), that licensees may convey the
|
| 108 |
+
work under this License, and how to view a copy of this License. If
|
| 109 |
+
the interface presents a list of user commands or options, such as a
|
| 110 |
+
menu, a prominent item in the list meets this criterion.
|
| 111 |
+
|
| 112 |
+
1. Source Code.
|
| 113 |
+
|
| 114 |
+
The "source code" for a work means the preferred form of the work
|
| 115 |
+
for making modifications to it. "Object code" means any non-source
|
| 116 |
+
form of a work.
|
| 117 |
+
|
| 118 |
+
A "Standard Interface" means an interface that either is an official
|
| 119 |
+
standard defined by a recognized standards body, or, in the case of
|
| 120 |
+
interfaces specified for a particular programming language, one that
|
| 121 |
+
is widely used among developers working in that language.
|
| 122 |
+
|
| 123 |
+
The "System Libraries" of an executable work include anything, other
|
| 124 |
+
than the work as a whole, that (a) is included in the normal form of
|
| 125 |
+
packaging a Major Component, but which is not part of that Major
|
| 126 |
+
Component, and (b) serves only to enable use of the work with that
|
| 127 |
+
Major Component, or to implement a Standard Interface for which an
|
| 128 |
+
implementation is available to the public in source code form. A
|
| 129 |
+
"Major Component", in this context, means a major essential component
|
| 130 |
+
(kernel, window system, and so on) of the specific operating system
|
| 131 |
+
(if any) on which the executable work runs, or a compiler used to
|
| 132 |
+
produce the work, or an object code interpreter used to run it.
|
| 133 |
+
|
| 134 |
+
The "Corresponding Source" for a work in object code form means all
|
| 135 |
+
the source code needed to generate, install, and (for an executable
|
| 136 |
+
work) run the object code and to modify the work, including scripts to
|
| 137 |
+
control those activities. However, it does not include the work's
|
| 138 |
+
System Libraries, or general-purpose tools or generally available free
|
| 139 |
+
programs which are used unmodified in performing those activities but
|
| 140 |
+
which are not part of the work. For example, Corresponding Source
|
| 141 |
+
includes interface definition files associated with source files for
|
| 142 |
+
the work, and the source code for shared libraries and dynamically
|
| 143 |
+
linked subprograms that the work is specifically designed to require,
|
| 144 |
+
such as by intimate data communication or control flow between those
|
| 145 |
+
subprograms and other parts of the work.
|
| 146 |
+
|
| 147 |
+
The Corresponding Source need not include anything that users
|
| 148 |
+
can regenerate automatically from other parts of the Corresponding
|
| 149 |
+
Source.
|
| 150 |
+
|
| 151 |
+
The Corresponding Source for a work in source code form is that
|
| 152 |
+
same work.
|
| 153 |
+
|
| 154 |
+
2. Basic Permissions.
|
| 155 |
+
|
| 156 |
+
All rights granted under this License are granted for the term of
|
| 157 |
+
copyright on the Program, and are irrevocable provided the stated
|
| 158 |
+
conditions are met. This License explicitly affirms your unlimited
|
| 159 |
+
permission to run the unmodified Program. The output from running a
|
| 160 |
+
covered work is covered by this License only if the output, given its
|
| 161 |
+
content, constitutes a covered work. This License acknowledges your
|
| 162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
| 163 |
+
|
| 164 |
+
You may make, run and propagate covered works that you do not
|
| 165 |
+
convey, without conditions so long as your license otherwise remains
|
| 166 |
+
in force. You may convey covered works to others for the sole purpose
|
| 167 |
+
of having them make modifications exclusively for you, or provide you
|
| 168 |
+
with facilities for running those works, provided that you comply with
|
| 169 |
+
the terms of this License in conveying all material for which you do
|
| 170 |
+
not control copyright. Those thus making or running the covered works
|
| 171 |
+
for you must do so exclusively on your behalf, under your direction
|
| 172 |
+
and control, on terms that prohibit them from making any copies of
|
| 173 |
+
your copyrighted material outside their relationship with you.
|
| 174 |
+
|
| 175 |
+
Conveying under any other circumstances is permitted solely under
|
| 176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
| 177 |
+
makes it unnecessary.
|
| 178 |
+
|
| 179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
| 180 |
+
|
| 181 |
+
No covered work shall be deemed part of an effective technological
|
| 182 |
+
measure under any applicable law fulfilling obligations under article
|
| 183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
| 184 |
+
similar laws prohibiting or restricting circumvention of such
|
| 185 |
+
measures.
|
| 186 |
+
|
| 187 |
+
When you convey a covered work, you waive any legal power to forbid
|
| 188 |
+
circumvention of technological measures to the extent such circumvention
|
| 189 |
+
is effected by exercising rights under this License with respect to
|
| 190 |
+
the covered work, and you disclaim any intention to limit operation or
|
| 191 |
+
modification of the work as a means of enforcing, against the work's
|
| 192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
| 193 |
+
technological measures.
|
| 194 |
+
|
| 195 |
+
4. Conveying Verbatim Copies.
|
| 196 |
+
|
| 197 |
+
You may convey verbatim copies of the Program's source code as you
|
| 198 |
+
receive it, in any medium, provided that you conspicuously and
|
| 199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
| 200 |
+
keep intact all notices stating that this License and any
|
| 201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
| 202 |
+
keep intact all notices of the absence of any warranty; and give all
|
| 203 |
+
recipients a copy of this License along with the Program.
|
| 204 |
+
|
| 205 |
+
You may charge any price or no price for each copy that you convey,
|
| 206 |
+
and you may offer support or warranty protection for a fee.
|
| 207 |
+
|
| 208 |
+
5. Conveying Modified Source Versions.
|
| 209 |
+
|
| 210 |
+
You may convey a work based on the Program, or the modifications to
|
| 211 |
+
produce it from the Program, in the form of source code under the
|
| 212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
| 213 |
+
|
| 214 |
+
a) The work must carry prominent notices stating that you modified
|
| 215 |
+
it, and giving a relevant date.
|
| 216 |
+
|
| 217 |
+
b) The work must carry prominent notices stating that it is
|
| 218 |
+
released under this License and any conditions added under section
|
| 219 |
+
7. This requirement modifies the requirement in section 4 to
|
| 220 |
+
"keep intact all notices".
|
| 221 |
+
|
| 222 |
+
c) You must license the entire work, as a whole, under this
|
| 223 |
+
License to anyone who comes into possession of a copy. This
|
| 224 |
+
License will therefore apply, along with any applicable section 7
|
| 225 |
+
additional terms, to the whole of the work, and all its parts,
|
| 226 |
+
regardless of how they are packaged. This License gives no
|
| 227 |
+
permission to license the work in any other way, but it does not
|
| 228 |
+
invalidate such permission if you have separately received it.
|
| 229 |
+
|
| 230 |
+
d) If the work has interactive user interfaces, each must display
|
| 231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
| 232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
| 233 |
+
work need not make them do so.
|
| 234 |
+
|
| 235 |
+
A compilation of a covered work with other separate and independent
|
| 236 |
+
works, which are not by their nature extensions of the covered work,
|
| 237 |
+
and which are not combined with it such as to form a larger program,
|
| 238 |
+
in or on a volume of a storage or distribution medium, is called an
|
| 239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
| 240 |
+
used to limit the access or legal rights of the compilation's users
|
| 241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
| 242 |
+
in an aggregate does not cause this License to apply to the other
|
| 243 |
+
parts of the aggregate.
|
| 244 |
+
|
| 245 |
+
6. Conveying Non-Source Forms.
|
| 246 |
+
|
| 247 |
+
You may convey a covered work in object code form under the terms
|
| 248 |
+
of sections 4 and 5, provided that you also convey the
|
| 249 |
+
machine-readable Corresponding Source under the terms of this License,
|
| 250 |
+
in one of these ways:
|
| 251 |
+
|
| 252 |
+
a) Convey the object code in, or embodied in, a physical product
|
| 253 |
+
(including a physical distribution medium), accompanied by the
|
| 254 |
+
Corresponding Source fixed on a durable physical medium
|
| 255 |
+
customarily used for software interchange.
|
| 256 |
+
|
| 257 |
+
b) Convey the object code in, or embodied in, a physical product
|
| 258 |
+
(including a physical distribution medium), accompanied by a
|
| 259 |
+
written offer, valid for at least three years and valid for as
|
| 260 |
+
long as you offer spare parts or customer support for that product
|
| 261 |
+
model, to give anyone who possesses the object code either (1) a
|
| 262 |
+
copy of the Corresponding Source for all the software in the
|
| 263 |
+
product that is covered by this License, on a durable physical
|
| 264 |
+
medium customarily used for software interchange, for a price no
|
| 265 |
+
more than your reasonable cost of physically performing this
|
| 266 |
+
conveying of source, or (2) access to copy the
|
| 267 |
+
Corresponding Source from a network server at no charge.
|
| 268 |
+
|
| 269 |
+
c) Convey individual copies of the object code with a copy of the
|
| 270 |
+
written offer to provide the Corresponding Source. This
|
| 271 |
+
alternative is allowed only occasionally and noncommercially, and
|
| 272 |
+
only if you received the object code with such an offer, in accord
|
| 273 |
+
with subsection 6b.
|
| 274 |
+
|
| 275 |
+
d) Convey the object code by offering access from a designated
|
| 276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
| 277 |
+
Corresponding Source in the same way through the same place at no
|
| 278 |
+
further charge. You need not require recipients to copy the
|
| 279 |
+
Corresponding Source along with the object code. If the place to
|
| 280 |
+
copy the object code is a network server, the Corresponding Source
|
| 281 |
+
may be on a different server (operated by you or a third party)
|
| 282 |
+
that supports equivalent copying facilities, provided you maintain
|
| 283 |
+
clear directions next to the object code saying where to find the
|
| 284 |
+
Corresponding Source. Regardless of what server hosts the
|
| 285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
| 286 |
+
available for as long as needed to satisfy these requirements.
|
| 287 |
+
|
| 288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
| 289 |
+
you inform other peers where the object code and Corresponding
|
| 290 |
+
Source of the work are being offered to the general public at no
|
| 291 |
+
charge under subsection 6d.
|
| 292 |
+
|
| 293 |
+
A separable portion of the object code, whose source code is excluded
|
| 294 |
+
from the Corresponding Source as a System Library, need not be
|
| 295 |
+
included in conveying the object code work.
|
| 296 |
+
|
| 297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
| 298 |
+
tangible personal property which is normally used for personal, family,
|
| 299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
| 300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
| 301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
| 302 |
+
product received by a particular user, "normally used" refers to a
|
| 303 |
+
typical or common use of that class of product, regardless of the status
|
| 304 |
+
of the particular user or of the way in which the particular user
|
| 305 |
+
actually uses, or expects or is expected to use, the product. A product
|
| 306 |
+
is a consumer product regardless of whether the product has substantial
|
| 307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
| 308 |
+
the only significant mode of use of the product.
|
| 309 |
+
|
| 310 |
+
"Installation Information" for a User Product means any methods,
|
| 311 |
+
procedures, authorization keys, or other information required to install
|
| 312 |
+
and execute modified versions of a covered work in that User Product from
|
| 313 |
+
a modified version of its Corresponding Source. The information must
|
| 314 |
+
suffice to ensure that the continued functioning of the modified object
|
| 315 |
+
code is in no case prevented or interfered with solely because
|
| 316 |
+
modification has been made.
|
| 317 |
+
|
| 318 |
+
If you convey an object code work under this section in, or with, or
|
| 319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
| 320 |
+
part of a transaction in which the right of possession and use of the
|
| 321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
| 322 |
+
fixed term (regardless of how the transaction is characterized), the
|
| 323 |
+
Corresponding Source conveyed under this section must be accompanied
|
| 324 |
+
by the Installation Information. But this requirement does not apply
|
| 325 |
+
if neither you nor any third party retains the ability to install
|
| 326 |
+
modified object code on the User Product (for example, the work has
|
| 327 |
+
been installed in ROM).
|
| 328 |
+
|
| 329 |
+
The requirement to provide Installation Information does not include a
|
| 330 |
+
requirement to continue to provide support service, warranty, or updates
|
| 331 |
+
for a work that has been modified or installed by the recipient, or for
|
| 332 |
+
the User Product in which it has been modified or installed. Access to a
|
| 333 |
+
network may be denied when the modification itself materially and
|
| 334 |
+
adversely affects the operation of the network or violates the rules and
|
| 335 |
+
protocols for communication across the network.
|
| 336 |
+
|
| 337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
| 338 |
+
in accord with this section must be in a format that is publicly
|
| 339 |
+
documented (and with an implementation available to the public in
|
| 340 |
+
source code form), and must require no special password or key for
|
| 341 |
+
unpacking, reading or copying.
|
| 342 |
+
|
| 343 |
+
7. Additional Terms.
|
| 344 |
+
|
| 345 |
+
"Additional permissions" are terms that supplement the terms of this
|
| 346 |
+
License by making exceptions from one or more of its conditions.
|
| 347 |
+
Additional permissions that are applicable to the entire Program shall
|
| 348 |
+
be treated as though they were included in this License, to the extent
|
| 349 |
+
that they are valid under applicable law. If additional permissions
|
| 350 |
+
apply only to part of the Program, that part may be used separately
|
| 351 |
+
under those permissions, but the entire Program remains governed by
|
| 352 |
+
this License without regard to the additional permissions.
|
| 353 |
+
|
| 354 |
+
When you convey a copy of a covered work, you may at your option
|
| 355 |
+
remove any additional permissions from that copy, or from any part of
|
| 356 |
+
it. (Additional permissions may be written to require their own
|
| 357 |
+
removal in certain cases when you modify the work.) You may place
|
| 358 |
+
additional permissions on material, added by you to a covered work,
|
| 359 |
+
for which you have or can give appropriate copyright permission.
|
| 360 |
+
|
| 361 |
+
Notwithstanding any other provision of this License, for material you
|
| 362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
| 363 |
+
that material) supplement the terms of this License with terms:
|
| 364 |
+
|
| 365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
| 366 |
+
terms of sections 15 and 16 of this License; or
|
| 367 |
+
|
| 368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
| 369 |
+
author attributions in that material or in the Appropriate Legal
|
| 370 |
+
Notices displayed by works containing it; or
|
| 371 |
+
|
| 372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
| 373 |
+
requiring that modified versions of such material be marked in
|
| 374 |
+
reasonable ways as different from the original version; or
|
| 375 |
+
|
| 376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
| 377 |
+
authors of the material; or
|
| 378 |
+
|
| 379 |
+
e) Declining to grant rights under trademark law for use of some
|
| 380 |
+
trade names, trademarks, or service marks; or
|
| 381 |
+
|
| 382 |
+
f) Requiring indemnification of licensors and authors of that
|
| 383 |
+
material by anyone who conveys the material (or modified versions of
|
| 384 |
+
it) with contractual assumptions of liability to the recipient, for
|
| 385 |
+
any liability that these contractual assumptions directly impose on
|
| 386 |
+
those licensors and authors.
|
| 387 |
+
|
| 388 |
+
All other non-permissive additional terms are considered "further
|
| 389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
| 390 |
+
received it, or any part of it, contains a notice stating that it is
|
| 391 |
+
governed by this License along with a term that is a further
|
| 392 |
+
restriction, you may remove that term. If a license document contains
|
| 393 |
+
a further restriction but permits relicensing or conveying under this
|
| 394 |
+
License, you may add to a covered work material governed by the terms
|
| 395 |
+
of that license document, provided that the further restriction does
|
| 396 |
+
not survive such relicensing or conveying.
|
| 397 |
+
|
| 398 |
+
If you add terms to a covered work in accord with this section, you
|
| 399 |
+
must place, in the relevant source files, a statement of the
|
| 400 |
+
additional terms that apply to those files, or a notice indicating
|
| 401 |
+
where to find the applicable terms.
|
| 402 |
+
|
| 403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
| 404 |
+
form of a separately written license, or stated as exceptions;
|
| 405 |
+
the above requirements apply either way.
|
| 406 |
+
|
| 407 |
+
8. Termination.
|
| 408 |
+
|
| 409 |
+
You may not propagate or modify a covered work except as expressly
|
| 410 |
+
provided under this License. Any attempt otherwise to propagate or
|
| 411 |
+
modify it is void, and will automatically terminate your rights under
|
| 412 |
+
this License (including any patent licenses granted under the third
|
| 413 |
+
paragraph of section 11).
|
| 414 |
+
|
| 415 |
+
However, if you cease all violation of this License, then your
|
| 416 |
+
license from a particular copyright holder is reinstated (a)
|
| 417 |
+
provisionally, unless and until the copyright holder explicitly and
|
| 418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
| 419 |
+
holder fails to notify you of the violation by some reasonable means
|
| 420 |
+
prior to 60 days after the cessation.
|
| 421 |
+
|
| 422 |
+
Moreover, your license from a particular copyright holder is
|
| 423 |
+
reinstated permanently if the copyright holder notifies you of the
|
| 424 |
+
violation by some reasonable means, this is the first time you have
|
| 425 |
+
received notice of violation of this License (for any work) from that
|
| 426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
| 427 |
+
your receipt of the notice.
|
| 428 |
+
|
| 429 |
+
Termination of your rights under this section does not terminate the
|
| 430 |
+
licenses of parties who have received copies or rights from you under
|
| 431 |
+
this License. If your rights have been terminated and not permanently
|
| 432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
| 433 |
+
material under section 10.
|
| 434 |
+
|
| 435 |
+
9. Acceptance Not Required for Having Copies.
|
| 436 |
+
|
| 437 |
+
You are not required to accept this License in order to receive or
|
| 438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
| 439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
| 440 |
+
to receive a copy likewise does not require acceptance. However,
|
| 441 |
+
nothing other than this License grants you permission to propagate or
|
| 442 |
+
modify any covered work. These actions infringe copyright if you do
|
| 443 |
+
not accept this License. Therefore, by modifying or propagating a
|
| 444 |
+
covered work, you indicate your acceptance of this License to do so.
|
| 445 |
+
|
| 446 |
+
10. Automatic Licensing of Downstream Recipients.
|
| 447 |
+
|
| 448 |
+
Each time you convey a covered work, the recipient automatically
|
| 449 |
+
receives a license from the original licensors, to run, modify and
|
| 450 |
+
propagate that work, subject to this License. You are not responsible
|
| 451 |
+
for enforcing compliance by third parties with this License.
|
| 452 |
+
|
| 453 |
+
An "entity transaction" is a transaction transferring control of an
|
| 454 |
+
organization, or substantially all assets of one, or subdividing an
|
| 455 |
+
organization, or merging organizations. If propagation of a covered
|
| 456 |
+
work results from an entity transaction, each party to that
|
| 457 |
+
transaction who receives a copy of the work also receives whatever
|
| 458 |
+
licenses to the work the party's predecessor in interest had or could
|
| 459 |
+
give under the previous paragraph, plus a right to possession of the
|
| 460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
| 461 |
+
the predecessor has it or can get it with reasonable efforts.
|
| 462 |
+
|
| 463 |
+
You may not impose any further restrictions on the exercise of the
|
| 464 |
+
rights granted or affirmed under this License. For example, you may
|
| 465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
| 466 |
+
rights granted under this License, and you may not initiate litigation
|
| 467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
| 468 |
+
any patent claim is infringed by making, using, selling, offering for
|
| 469 |
+
sale, or importing the Program or any portion of it.
|
| 470 |
+
|
| 471 |
+
11. Patents.
|
| 472 |
+
|
| 473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
| 474 |
+
License of the Program or a work on which the Program is based. The
|
| 475 |
+
work thus licensed is called the contributor's "contributor version".
|
| 476 |
+
|
| 477 |
+
A contributor's "essential patent claims" are all patent claims
|
| 478 |
+
owned or controlled by the contributor, whether already acquired or
|
| 479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
| 480 |
+
by this License, of making, using, or selling its contributor version,
|
| 481 |
+
but do not include claims that would be infringed only as a
|
| 482 |
+
consequence of further modification of the contributor version. For
|
| 483 |
+
purposes of this definition, "control" includes the right to grant
|
| 484 |
+
patent sublicenses in a manner consistent with the requirements of
|
| 485 |
+
this License.
|
| 486 |
+
|
| 487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
| 488 |
+
patent license under the contributor's essential patent claims, to
|
| 489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
| 490 |
+
propagate the contents of its contributor version.
|
| 491 |
+
|
| 492 |
+
In the following three paragraphs, a "patent license" is any express
|
| 493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
| 494 |
+
(such as an express permission to practice a patent or covenant not to
|
| 495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
| 496 |
+
party means to make such an agreement or commitment not to enforce a
|
| 497 |
+
patent against the party.
|
| 498 |
+
|
| 499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
| 500 |
+
and the Corresponding Source of the work is not available for anyone
|
| 501 |
+
to copy, free of charge and under the terms of this License, through a
|
| 502 |
+
publicly available network server or other readily accessible means,
|
| 503 |
+
then you must either (1) cause the Corresponding Source to be so
|
| 504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
| 505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
| 506 |
+
consistent with the requirements of this License, to extend the patent
|
| 507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
| 508 |
+
actual knowledge that, but for the patent license, your conveying the
|
| 509 |
+
covered work in a country, or your recipient's use of the covered work
|
| 510 |
+
in a country, would infringe one or more identifiable patents in that
|
| 511 |
+
country that you have reason to believe are valid.
|
| 512 |
+
|
| 513 |
+
If, pursuant to or in connection with a single transaction or
|
| 514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
| 515 |
+
covered work, and grant a patent license to some of the parties
|
| 516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
| 517 |
+
or convey a specific copy of the covered work, then the patent license
|
| 518 |
+
you grant is automatically extended to all recipients of the covered
|
| 519 |
+
work and works based on it.
|
| 520 |
+
|
| 521 |
+
A patent license is "discriminatory" if it does not include within
|
| 522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
| 523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
| 524 |
+
specifically granted under this License. You may not convey a covered
|
| 525 |
+
work if you are a party to an arrangement with a third party that is
|
| 526 |
+
in the business of distributing software, under which you make payment
|
| 527 |
+
to the third party based on the extent of your activity of conveying
|
| 528 |
+
the work, and under which the third party grants, to any of the
|
| 529 |
+
parties who would receive the covered work from you, a discriminatory
|
| 530 |
+
patent license (a) in connection with copies of the covered work
|
| 531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
| 532 |
+
for and in connection with specific products or compilations that
|
| 533 |
+
contain the covered work, unless you entered into that arrangement,
|
| 534 |
+
or that patent license was granted, prior to 28 March 2007.
|
| 535 |
+
|
| 536 |
+
Nothing in this License shall be construed as excluding or limiting
|
| 537 |
+
any implied license or other defenses to infringement that may
|
| 538 |
+
otherwise be available to you under applicable patent law.
|
| 539 |
+
|
| 540 |
+
12. No Surrender of Others' Freedom.
|
| 541 |
+
|
| 542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
| 543 |
+
otherwise) that contradict the conditions of this License, they do not
|
| 544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
| 545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
| 546 |
+
License and any other pertinent obligations, then as a consequence you may
|
| 547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
| 548 |
+
to collect a royalty for further conveying from those to whom you convey
|
| 549 |
+
the Program, the only way you could satisfy both those terms and this
|
| 550 |
+
License would be to refrain entirely from conveying the Program.
|
| 551 |
+
|
| 552 |
+
13. Use with the GNU Affero General Public License.
|
| 553 |
+
|
| 554 |
+
Notwithstanding any other provision of this License, you have
|
| 555 |
+
permission to link or combine any covered work with a work licensed
|
| 556 |
+
under version 3 of the GNU Affero General Public License into a single
|
| 557 |
+
combined work, and to convey the resulting work. The terms of this
|
| 558 |
+
License will continue to apply to the part which is the covered work,
|
| 559 |
+
but the special requirements of the GNU Affero General Public License,
|
| 560 |
+
section 13, concerning interaction through a network will apply to the
|
| 561 |
+
combination as such.
|
| 562 |
+
|
| 563 |
+
14. Revised Versions of this License.
|
| 564 |
+
|
| 565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
| 566 |
+
the GNU General Public License from time to time. Such new versions will
|
| 567 |
+
be similar in spirit to the present version, but may differ in detail to
|
| 568 |
+
address new problems or concerns.
|
| 569 |
+
|
| 570 |
+
Each version is given a distinguishing version number. If the
|
| 571 |
+
Program specifies that a certain numbered version of the GNU General
|
| 572 |
+
Public License "or any later version" applies to it, you have the
|
| 573 |
+
option of following the terms and conditions either of that numbered
|
| 574 |
+
version or of any later version published by the Free Software
|
| 575 |
+
Foundation. If the Program does not specify a version number of the
|
| 576 |
+
GNU General Public License, you may choose any version ever published
|
| 577 |
+
by the Free Software Foundation.
|
| 578 |
+
|
| 579 |
+
If the Program specifies that a proxy can decide which future
|
| 580 |
+
versions of the GNU General Public License can be used, that proxy's
|
| 581 |
+
public statement of acceptance of a version permanently authorizes you
|
| 582 |
+
to choose that version for the Program.
|
| 583 |
+
|
| 584 |
+
Later license versions may give you additional or different
|
| 585 |
+
permissions. However, no additional obligations are imposed on any
|
| 586 |
+
author or copyright holder as a result of your choosing to follow a
|
| 587 |
+
later version.
|
| 588 |
+
|
| 589 |
+
15. Disclaimer of Warranty.
|
| 590 |
+
|
| 591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
| 592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
| 593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
| 594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
| 595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
| 596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
| 597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
| 598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
| 599 |
+
|
| 600 |
+
16. Limitation of Liability.
|
| 601 |
+
|
| 602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
| 603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
| 604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
| 605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
| 606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
| 607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
| 608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
| 609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
| 610 |
+
SUCH DAMAGES.
|
| 611 |
+
|
| 612 |
+
17. Interpretation of Sections 15 and 16.
|
| 613 |
+
|
| 614 |
+
If the disclaimer of warranty and limitation of liability provided
|
| 615 |
+
above cannot be given local legal effect according to their terms,
|
| 616 |
+
reviewing courts shall apply local law that most closely approximates
|
| 617 |
+
an absolute waiver of all civil liability in connection with the
|
| 618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
| 619 |
+
copy of the Program in return for a fee.
|
| 620 |
+
|
| 621 |
+
END OF TERMS AND CONDITIONS
|
| 622 |
+
|
| 623 |
+
How to Apply These Terms to Your New Programs
|
| 624 |
+
|
| 625 |
+
If you develop a new program, and you want it to be of the greatest
|
| 626 |
+
possible use to the public, the best way to achieve this is to make it
|
| 627 |
+
free software which everyone can redistribute and change under these terms.
|
| 628 |
+
|
| 629 |
+
To do so, attach the following notices to the program. It is safest
|
| 630 |
+
to attach them to the start of each source file to most effectively
|
| 631 |
+
state the exclusion of warranty; and each file should have at least
|
| 632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
| 633 |
+
|
| 634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
| 635 |
+
Copyright (C) <year> <name of author>
|
| 636 |
+
|
| 637 |
+
This program is free software: you can redistribute it and/or modify
|
| 638 |
+
it under the terms of the GNU General Public License as published by
|
| 639 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 640 |
+
(at your option) any later version.
|
| 641 |
+
|
| 642 |
+
This program is distributed in the hope that it will be useful,
|
| 643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 645 |
+
GNU General Public License for more details.
|
| 646 |
+
|
| 647 |
+
You should have received a copy of the GNU General Public License
|
| 648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 649 |
+
|
| 650 |
+
Also add information on how to contact you by electronic and paper mail.
|
| 651 |
+
|
| 652 |
+
If the program does terminal interaction, make it output a short
|
| 653 |
+
notice like this when it starts in an interactive mode:
|
| 654 |
+
|
| 655 |
+
<program> Copyright (C) <year> <name of author>
|
| 656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
| 657 |
+
This is free software, and you are welcome to redistribute it
|
| 658 |
+
under certain conditions; type `show c' for details.
|
| 659 |
+
|
| 660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
| 661 |
+
parts of the General Public License. Of course, your program's commands
|
| 662 |
+
might be different; for a GUI interface, you would use an "about box".
|
| 663 |
+
|
| 664 |
+
You should also get your employer (if you work as a programmer) or school,
|
| 665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
| 666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
| 667 |
+
<https://www.gnu.org/licenses/>.
|
| 668 |
+
|
| 669 |
+
The GNU General Public License does not permit incorporating your program
|
| 670 |
+
into proprietary programs. If your program is a subroutine library, you
|
| 671 |
+
may consider it more useful to permit linking proprietary applications with
|
| 672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
| 673 |
+
Public License instead of this License. But first, please read
|
| 674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
ComfyUI-Advanced-ControlNet/README.md
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ComfyUI-Advanced-ControlNet
|
| 2 |
+
Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference.
|
| 3 |
+
|
| 4 |
+
Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights.
|
| 5 |
+
|
| 6 |
+
ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes.
|
| 7 |
+
|
| 8 |
+
## Features
|
| 9 |
+
- Timestep and latent strength scheduling
|
| 10 |
+
- Attention masks
|
| 11 |
+
- Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier***
|
| 12 |
+
- Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights***
|
| 13 |
+
- uncond_multiplier=0.0 gives identical results of auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting.
|
| 14 |
+
- ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows
|
| 15 |
+
- ControlLLLite support (requires model_optional to be passed into and out of Apply Advanced ControlNet node)
|
| 16 |
+
- SparseCtrl support
|
| 17 |
+
- SVD-ControlNet support
|
| 18 |
+
- Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet)
|
| 19 |
+
- Reference support
|
| 20 |
+
- Supports ```reference_attn```, ```reference_adain```, and ```refrence_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and strength of the Apply ControlNet is the balance between ref-influenced result and no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjust the style_fidelity, weight, and strength of attn and adain separately.
|
| 21 |
+
|
| 22 |
+
## Table of Contents:
|
| 23 |
+
- [Scheduling Explanation](#scheduling-explanation)
|
| 24 |
+
- [Nodes](#nodes)
|
| 25 |
+
- [Usage](#usage) (will fill this out soon)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Scheduling Explanation
|
| 29 |
+
|
| 30 |
+
The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***.
|
| 31 |
+
|
| 32 |
+
***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional.
|
| 33 |
+
|
| 34 |
+
***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same affect as a uniform mask with the chosen strength value.
|
| 35 |
+
|
| 36 |
+

|
| 37 |
+
|
| 38 |
+
# Nodes
|
| 39 |
+
|
| 40 |
+
The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
|
| 41 |
+
|
| 42 |
+
Key:
|
| 43 |
+
- 🟩 - required inputs
|
| 44 |
+
- 🟨 - optional inputs
|
| 45 |
+
- 🟦 - start as widgets, can be converted to inputs
|
| 46 |
+
- 🟥 - optional input/output, but not recommended to use unless needed
|
| 47 |
+
- 🟪 - output
|
| 48 |
+
|
| 49 |
+
## Apply Advanced ControlNet
|
| 50 |
+

|
| 51 |
+
|
| 52 |
+
Same functionality as the vanilla Apply Advanced ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions.
|
| 53 |
+
|
| 54 |
+
### Inputs
|
| 55 |
+
- 🟩***positive***: conditioning (positive).
|
| 56 |
+
- 🟩***negative***: conditioning (negative).
|
| 57 |
+
- 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type.
|
| 58 |
+
- 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must preprocessed images. If one image provided, will be used for all latents. If more images provided, will use each image separately for each latent. If not enough images to meet latent count, will repeat the images from the beginning to match vanilla ControlNet functionality.
|
| 59 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as image input, if you provide more than one mask, each can apply to a different latent.
|
| 60 |
+
- 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps.
|
| 61 |
+
- 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes is needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless if there are other latent keyframes attached to connected timestep keyframes.*
|
| 62 |
+
- 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes is needed. *NOTE: this weight will be applied to ALL timesteps, regardless if there are other weights attached to connected timestep keyframes.*
|
| 63 |
+
- 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all.
|
| 64 |
+
- 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached.
|
| 65 |
+
- 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they won't take effect once this end_percent is reached.
|
| 66 |
+
|
| 67 |
+
### Outputs
|
| 68 |
+
- 🟪***positive***: conditioning (positive) with applied controlnets
|
| 69 |
+
- 🟪***negative***: conditioning (negative) with applied controlnets
|
| 70 |
+
|
| 71 |
+
## Load Advanced ControlNet Model
|
| 72 |
+

|
| 73 |
+
|
| 74 |
+
Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead.
|
| 75 |
+
|
| 76 |
+
### Inputs
|
| 77 |
+
- 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to. Useful if this node is not attached to **Apply Advanced ControlNet** node, but still want to use Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overriden by the timestep_kf input on **Apply Advanced ControlNet** node, if one is provided there.
|
| 78 |
+
- 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed for receive the model; if you don't know what this does, you probably don't want tot use the diff version of the node.
|
| 79 |
+
|
| 80 |
+
### Outputs
|
| 81 |
+
- 🟪***CONTROL_NET***: loaded Advanced ControlNet
|
| 82 |
+
|
| 83 |
+
## Timestep Keyframe
|
| 84 |
+

|
| 85 |
+
|
| 86 |
+
Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule.
|
| 87 |
+
|
| 88 |
+
### Inputs
|
| 89 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
| 90 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
| 91 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
| 92 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
| 93 |
+
- 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule.
|
| 94 |
+
- 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0 will not have any effect during the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work.
|
| 95 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
| 96 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
| 97 |
+
- 🟦***guarantee_steps***: when 1 or greater, even if a Timestep Keyframe's start_percent ahead of this one in the schedule is closer to current sampling percentage, this Timestep Keyframe will still be used for the specified amount of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be accounted for inherit_missing purposes.
|
| 98 |
+
|
| 99 |
+
### Outputs
|
| 100 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
| 101 |
+
|
| 102 |
+
## Timestep Keyframe Interpolation
|
| 103 |
+

|
| 104 |
+
|
| 105 |
+
Allows to create Timestep Keyframe with interpolated strength values in a given percent range. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
|
| 106 |
+
|
| 107 |
+
### Inputs
|
| 108 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
| 109 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
| 110 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
| 111 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
| 112 |
+
- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
|
| 113 |
+
- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
|
| 114 |
+
- 🟦***strength_start***: strength of the Timestep Keyframe at start of range.
|
| 115 |
+
- 🟦***strength_end***: strength of the Timestep Keyframe at end of range.
|
| 116 |
+
- 🟦***interpolation***: the method of interpolation.
|
| 117 |
+
- 🟦***intervals***: the amount of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent.
|
| 118 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
| 119 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
| 120 |
+
- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
|
| 121 |
+
|
| 122 |
+
### Outputs
|
| 123 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
| 124 |
+
|
| 125 |
+
## Timestep Keyframe From List
|
| 126 |
+

|
| 127 |
+
|
| 128 |
+
Allows to create Timestep Keyframe via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
|
| 129 |
+
|
| 130 |
+
### Inputs
|
| 131 |
+
- 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
|
| 132 |
+
- 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
|
| 133 |
+
- 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
|
| 134 |
+
- 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
|
| 135 |
+
- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Timestep Keyframe; first will be assigned to start_percent, last will be assigned to end_percent, and the rest spread linearly between.
|
| 136 |
+
- 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
|
| 137 |
+
- 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
|
| 138 |
+
- 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
|
| 139 |
+
- 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
|
| 140 |
+
- 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
|
| 141 |
+
|
| 142 |
+
### Outputs
|
| 143 |
+
- 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
|
| 144 |
+
|
| 145 |
+
## Latent Keyframe
|
| 146 |
+

|
| 147 |
+
|
| 148 |
+
A singular Latent Keyframe, selects the strength for a specific batch_index. If batch_index is not present during sampling, will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule.
|
| 149 |
+
|
| 150 |
+
### Inputs
|
| 151 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_keyframes have the same batch_index as this Latent Keyframe, they will take priority over this node's value.*
|
| 152 |
+
- 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule.
|
| 153 |
+
- 🟦***strength***: strength of controlnet to apply to the corresponding latent.
|
| 154 |
+
|
| 155 |
+
### Outputs
|
| 156 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
| 157 |
+
|
| 158 |
+
## Latent Keyframe Group
|
| 159 |
+

|
| 160 |
+
|
| 161 |
+
Allows to create Latent Keyframes via individual indeces or python-style ranges.
|
| 162 |
+
|
| 163 |
+
### Inputs
|
| 164 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
| 165 |
+
- 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indeces (will be automatically converted to real values).
|
| 166 |
+
- 🟦***index_strengths***: string list of indeces or python-style ranges of indeces to assign strengths to. If latent_optional is passed in, can contain negative indeces or ranges that contain negative numbers, python-style. The different indeces must be comma separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```. Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indeces are possible when latents_optional has an input, with a string such as ```0,-4=0.25```.
|
| 167 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
| 168 |
+
|
| 169 |
+
### Outputs
|
| 170 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
| 171 |
+
|
| 172 |
+
## Latent Keyframe Interpolation
|
| 173 |
+

|
| 174 |
+
|
| 175 |
+
Allows to create Latent Keyframes with interpolated values in a range.
|
| 176 |
+
|
| 177 |
+
### Inputs
|
| 178 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
| 179 |
+
- 🟦***batch_index_from***: starting batch_index of range, included.
|
| 180 |
+
- 🟦***batch_index_to***: end batch_index of range, excluded (python-style range).
|
| 181 |
+
- 🟦***strength_from***: starting strength of interpolation.
|
| 182 |
+
- 🟦***strength_to***: end strength of interpolation.
|
| 183 |
+
- 🟦***interpolation***: the method of interpolation.
|
| 184 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
| 185 |
+
|
| 186 |
+
### Outputs
|
| 187 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
| 188 |
+
|
| 189 |
+
## Latent Keyframe From List
|
| 190 |
+

|
| 191 |
+
|
| 192 |
+
Allows to create Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes.
|
| 193 |
+
|
| 194 |
+
### Inputs
|
| 195 |
+
- 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
|
| 196 |
+
- 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list.
|
| 197 |
+
- 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
|
| 198 |
+
|
| 199 |
+
### Outputs
|
| 200 |
+
- 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
|
| 201 |
+
|
| 202 |
+
# There are more nodes to document and show usage - will add this soon! TODO
|
ComfyUI-Advanced-ControlNet/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
|
| 2 |
+
|
| 3 |
+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
|
ComfyUI-Advanced-ControlNet/adv_control/control.py
ADDED
|
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Union
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import comfy.utils
|
| 7 |
+
import comfy.model_management
|
| 8 |
+
import comfy.model_detection
|
| 9 |
+
import comfy.controlnet as comfy_cn
|
| 10 |
+
from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to
|
| 11 |
+
from comfy.model_patcher import ModelPatcher
|
| 12 |
+
|
| 13 |
+
from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
|
| 14 |
+
from .control_lllite import LLLiteModule, LLLitePatch
|
| 15 |
+
from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
|
| 16 |
+
from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, ControlWeightType, ControlWeights, WeightTypeException,
|
| 17 |
+
manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory)
|
| 18 |
+
from .logger import logger
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class ControlNetAdvanced(ControlNet, AdvancedControlBase):
|
| 22 |
+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
| 23 |
+
super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
| 24 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
|
| 25 |
+
|
| 26 |
+
def get_universal_weights(self) -> ControlWeights:
|
| 27 |
+
raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
|
| 28 |
+
return self.weights.copy_with_new_weights(raw_weights)
|
| 29 |
+
|
| 30 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
| 31 |
+
# perform special version of get_control that supports sliding context and masks
|
| 32 |
+
return self.sliding_get_control(x_noisy, t, cond, batched_number)
|
| 33 |
+
|
| 34 |
+
def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
|
| 35 |
+
control_prev = None
|
| 36 |
+
if self.previous_controlnet is not None:
|
| 37 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 38 |
+
|
| 39 |
+
if self.timestep_range is not None:
|
| 40 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
| 41 |
+
if control_prev is not None:
|
| 42 |
+
return control_prev
|
| 43 |
+
else:
|
| 44 |
+
return None
|
| 45 |
+
|
| 46 |
+
dtype = self.control_model.dtype
|
| 47 |
+
if self.manual_cast_dtype is not None:
|
| 48 |
+
dtype = self.manual_cast_dtype
|
| 49 |
+
|
| 50 |
+
output_dtype = x_noisy.dtype
|
| 51 |
+
# make cond_hint appropriate dimensions
|
| 52 |
+
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
|
| 53 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
| 54 |
+
if self.cond_hint is not None:
|
| 55 |
+
del self.cond_hint
|
| 56 |
+
self.cond_hint = None
|
| 57 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
| 58 |
+
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
|
| 59 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 60 |
+
else:
|
| 61 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 62 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
| 63 |
+
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
| 64 |
+
|
| 65 |
+
# prepare mask_cond_hint
|
| 66 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
| 67 |
+
|
| 68 |
+
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
| 69 |
+
# uses 'y' in new ComfyUI update
|
| 70 |
+
y = cond.get('y', None)
|
| 71 |
+
if y is None: # TODO: remove this in the future since no longer used by newest ComfyUI
|
| 72 |
+
y = cond.get('c_adm', None)
|
| 73 |
+
if y is not None:
|
| 74 |
+
y = y.to(dtype)
|
| 75 |
+
timestep = self.model_sampling_current.timestep(t)
|
| 76 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
| 77 |
+
|
| 78 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
| 79 |
+
return self.control_merge(None, control, control_prev, output_dtype)
|
| 80 |
+
|
| 81 |
+
def copy(self):
|
| 82 |
+
c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
| 83 |
+
self.copy_to(c)
|
| 84 |
+
self.copy_to_advanced(c)
|
| 85 |
+
return c
|
| 86 |
+
|
| 87 |
+
@staticmethod
|
| 88 |
+
def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
|
| 89 |
+
return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
|
| 90 |
+
global_average_pooling=v.global_average_pooling, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
|
| 94 |
+
def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, compression_ratio=8, upscale_algorithm="nearest_exact", device=None):
|
| 95 |
+
super().__init__(t2i_model=t2i_model, channels_in=channels_in, compression_ratio=compression_ratio, upscale_algorithm=upscale_algorithm, device=device)
|
| 96 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter())
|
| 97 |
+
|
| 98 |
+
def control_merge_inject(self, control_input, control_output, control_prev, output_dtype):
|
| 99 |
+
# if has uncond multiplier, need to make sure control shapes are the same batch size as expected
|
| 100 |
+
if self.weights.has_uncond_multiplier:
|
| 101 |
+
if control_input is not None:
|
| 102 |
+
for i in range(len(control_input)):
|
| 103 |
+
x = control_input[i]
|
| 104 |
+
if x is not None:
|
| 105 |
+
if x.size(0) < self.batch_size:
|
| 106 |
+
control_input[i] = x.repeat(self.batched_number, 1, 1, 1)[:self.batch_size]
|
| 107 |
+
if control_output is not None:
|
| 108 |
+
for i in range(len(control_output)):
|
| 109 |
+
x = control_output[i]
|
| 110 |
+
if x is not None:
|
| 111 |
+
if x.size(0) < self.batch_size:
|
| 112 |
+
control_output[i] = x.repeat(self.batched_number, 1, 1, 1)[:self.batch_size]
|
| 113 |
+
return AdvancedControlBase.control_merge_inject(self, control_input, control_output, control_prev, output_dtype)
|
| 114 |
+
|
| 115 |
+
def get_universal_weights(self) -> ControlWeights:
|
| 116 |
+
raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
|
| 117 |
+
raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
|
| 118 |
+
raw_weights = get_properly_arranged_t2i_weights(raw_weights)
|
| 119 |
+
return self.weights.copy_with_new_weights(raw_weights)
|
| 120 |
+
|
| 121 |
+
def get_calc_pow(self, idx: int, layers: int) -> int:
|
| 122 |
+
# match how T2IAdapterAdvanced deals with universal weights
|
| 123 |
+
indeces = [7 - i for i in range(8)]
|
| 124 |
+
indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]]
|
| 125 |
+
indeces = get_properly_arranged_t2i_weights(indeces)
|
| 126 |
+
return indeces[idx]
|
| 127 |
+
|
| 128 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
| 129 |
+
try:
|
| 130 |
+
# if sub indexes present, replace original hint with subsection
|
| 131 |
+
if self.sub_idxs is not None:
|
| 132 |
+
# cond hints
|
| 133 |
+
full_cond_hint_original = self.cond_hint_original
|
| 134 |
+
del self.cond_hint
|
| 135 |
+
self.cond_hint = None
|
| 136 |
+
self.cond_hint_original = full_cond_hint_original[self.sub_idxs]
|
| 137 |
+
# mask hints
|
| 138 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
|
| 139 |
+
return super().get_control(x_noisy, t, cond, batched_number)
|
| 140 |
+
finally:
|
| 141 |
+
if self.sub_idxs is not None:
|
| 142 |
+
# replace original cond hint
|
| 143 |
+
self.cond_hint_original = full_cond_hint_original
|
| 144 |
+
del full_cond_hint_original
|
| 145 |
+
|
| 146 |
+
def copy(self):
|
| 147 |
+
c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in, self.compression_ratio, self.upscale_algorithm)
|
| 148 |
+
self.copy_to(c)
|
| 149 |
+
self.copy_to_advanced(c)
|
| 150 |
+
return c
|
| 151 |
+
|
| 152 |
+
def cleanup(self):
|
| 153 |
+
super().cleanup()
|
| 154 |
+
self.cleanup_advanced()
|
| 155 |
+
|
| 156 |
+
@staticmethod
|
| 157 |
+
def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
|
| 158 |
+
return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
|
| 159 |
+
compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
|
| 163 |
+
def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None):
|
| 164 |
+
super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling, device=device)
|
| 165 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
|
| 166 |
+
# use some functions from ControlNetAdvanced
|
| 167 |
+
self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
|
| 168 |
+
self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self))
|
| 169 |
+
|
| 170 |
+
def get_universal_weights(self) -> ControlWeights:
|
| 171 |
+
raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)]
|
| 172 |
+
return self.weights.copy_with_new_weights(raw_weights)
|
| 173 |
+
|
| 174 |
+
def copy(self):
|
| 175 |
+
c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
|
| 176 |
+
self.copy_to(c)
|
| 177 |
+
self.copy_to_advanced(c)
|
| 178 |
+
return c
|
| 179 |
+
|
| 180 |
+
def cleanup(self):
|
| 181 |
+
super().cleanup()
|
| 182 |
+
self.cleanup_advanced()
|
| 183 |
+
|
| 184 |
+
@staticmethod
|
| 185 |
+
def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
|
| 186 |
+
return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
|
| 187 |
+
global_average_pooling=v.global_average_pooling, device=v.device)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
class SVDControlNetAdvanced(ControlNetAdvanced):
|
| 191 |
+
def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
| 192 |
+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
| 193 |
+
|
| 194 |
+
def set_cond_hint(self, *args, **kwargs):
|
| 195 |
+
to_return = super().set_cond_hint(*args, **kwargs)
|
| 196 |
+
# cond hint for SVD-ControlNet needs to be scaled between (-1, 1) instead of (0, 1)
|
| 197 |
+
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
|
| 198 |
+
return to_return
|
| 199 |
+
|
| 200 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
| 201 |
+
control_prev = None
|
| 202 |
+
if self.previous_controlnet is not None:
|
| 203 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 204 |
+
|
| 205 |
+
if self.timestep_range is not None:
|
| 206 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
| 207 |
+
if control_prev is not None:
|
| 208 |
+
return control_prev
|
| 209 |
+
else:
|
| 210 |
+
return None
|
| 211 |
+
|
| 212 |
+
dtype = self.control_model.dtype
|
| 213 |
+
if self.manual_cast_dtype is not None:
|
| 214 |
+
dtype = self.manual_cast_dtype
|
| 215 |
+
|
| 216 |
+
output_dtype = x_noisy.dtype
|
| 217 |
+
# make cond_hint appropriate dimensions
|
| 218 |
+
# TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
|
| 219 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
| 220 |
+
if self.cond_hint is not None:
|
| 221 |
+
del self.cond_hint
|
| 222 |
+
self.cond_hint = None
|
| 223 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
| 224 |
+
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
|
| 225 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 226 |
+
else:
|
| 227 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 228 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
| 229 |
+
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
| 230 |
+
|
| 231 |
+
# prepare mask_cond_hint
|
| 232 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
| 233 |
+
|
| 234 |
+
context = cond.get('crossattn_controlnet', cond['c_crossattn'])
|
| 235 |
+
# uses 'y' in new ComfyUI update
|
| 236 |
+
y = cond.get('y', None)
|
| 237 |
+
if y is not None:
|
| 238 |
+
y = y.to(dtype)
|
| 239 |
+
timestep = self.model_sampling_current.timestep(t)
|
| 240 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
| 241 |
+
# concat c_concat if exists (should exist for SVD), doubling channels to 8
|
| 242 |
+
if cond.get('c_concat', None) is not None:
|
| 243 |
+
x_noisy = torch.cat([x_noisy] + [cond['c_concat']], dim=1)
|
| 244 |
+
|
| 245 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, cond=cond)
|
| 246 |
+
return self.control_merge(None, control, control_prev, output_dtype)
|
| 247 |
+
|
| 248 |
+
def copy(self):
|
| 249 |
+
c = SVDControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
| 250 |
+
self.copy_to(c)
|
| 251 |
+
self.copy_to_advanced(c)
|
| 252 |
+
return c
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
class SparseCtrlAdvanced(ControlNetAdvanced):
|
| 256 |
+
def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
|
| 257 |
+
super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
| 258 |
+
self.control_model_wrapped = SparseModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
|
| 259 |
+
self.add_compatible_weight(ControlWeightType.SPARSECTRL)
|
| 260 |
+
self.control_model: SparseControlNet = self.control_model # does nothing except help with IDE hints
|
| 261 |
+
self.sparse_settings = sparse_settings if sparse_settings is not None else SparseSettings.default()
|
| 262 |
+
self.latent_format = None
|
| 263 |
+
self.preprocessed = False
|
| 264 |
+
|
| 265 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
| 266 |
+
# normal ControlNet stuff
|
| 267 |
+
control_prev = None
|
| 268 |
+
if self.previous_controlnet is not None:
|
| 269 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 270 |
+
|
| 271 |
+
if self.timestep_range is not None:
|
| 272 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
| 273 |
+
if control_prev is not None:
|
| 274 |
+
return control_prev
|
| 275 |
+
else:
|
| 276 |
+
return None
|
| 277 |
+
|
| 278 |
+
dtype = self.control_model.dtype
|
| 279 |
+
if self.manual_cast_dtype is not None:
|
| 280 |
+
dtype = self.manual_cast_dtype
|
| 281 |
+
output_dtype = x_noisy.dtype
|
| 282 |
+
# set actual input length on motion model
|
| 283 |
+
actual_length = x_noisy.size(0)//batched_number
|
| 284 |
+
full_length = actual_length if self.sub_idxs is None else self.full_latent_length
|
| 285 |
+
self.control_model.set_actual_length(actual_length=actual_length, full_length=full_length)
|
| 286 |
+
# prepare cond_hint, if needed
|
| 287 |
+
dim_mult = 1 if self.control_model.use_simplified_conditioning_embedding else 8
|
| 288 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2]*dim_mult != self.cond_hint.shape[2] or x_noisy.shape[3]*dim_mult != self.cond_hint.shape[3]:
|
| 289 |
+
# clear out cond_hint and conditioning_mask
|
| 290 |
+
if self.cond_hint is not None:
|
| 291 |
+
del self.cond_hint
|
| 292 |
+
self.cond_hint = None
|
| 293 |
+
# first, figure out which cond idxs are relevant, and where they fit in
|
| 294 |
+
cond_idxs = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length)
|
| 295 |
+
|
| 296 |
+
range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
|
| 297 |
+
hint_idxs = [] # idxs in cond_idxs
|
| 298 |
+
local_idxs = [] # idx to pun in final cond_hint
|
| 299 |
+
for i,cond_idx in enumerate(cond_idxs):
|
| 300 |
+
if cond_idx in range_idxs:
|
| 301 |
+
hint_idxs.append(i)
|
| 302 |
+
local_idxs.append(range_idxs.index(cond_idx))
|
| 303 |
+
# sub_cond_hint now contains the hints relevant to current x_noisy
|
| 304 |
+
sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
|
| 305 |
+
|
| 306 |
+
# scale cond_hints to match noisy input
|
| 307 |
+
if self.control_model.use_simplified_conditioning_embedding:
|
| 308 |
+
# RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
|
| 309 |
+
sub_cond_hint = self.latent_format.process_in(sub_cond_hint) # multiplies by model scale factor
|
| 310 |
+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(self.device)
|
| 311 |
+
else:
|
| 312 |
+
# other SparseCtrl; inputs are typical images
|
| 313 |
+
sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 314 |
+
# prepare cond_hint (b, c, h ,w)
|
| 315 |
+
cond_shape = list(sub_cond_hint.shape)
|
| 316 |
+
cond_shape[0] = len(range_idxs)
|
| 317 |
+
self.cond_hint = torch.zeros(cond_shape).to(dtype).to(self.device)
|
| 318 |
+
self.cond_hint[local_idxs] = sub_cond_hint[:]
|
| 319 |
+
# prepare cond_mask (b, 1, h, w)
|
| 320 |
+
cond_shape[1] = 1
|
| 321 |
+
cond_mask = torch.zeros(cond_shape).to(dtype).to(self.device)
|
| 322 |
+
cond_mask[local_idxs] = 1.0
|
| 323 |
+
# combine cond_hint and cond_mask into (b, c+1, h, w)
|
| 324 |
+
if not self.sparse_settings.merged:
|
| 325 |
+
self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
|
| 326 |
+
del sub_cond_hint
|
| 327 |
+
del cond_mask
|
| 328 |
+
# make cond_hint match x_noisy batch
|
| 329 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
| 330 |
+
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
| 331 |
+
|
| 332 |
+
# prepare mask_cond_hint
|
| 333 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
|
| 334 |
+
|
| 335 |
+
context = cond['c_crossattn']
|
| 336 |
+
y = cond.get('y', None)
|
| 337 |
+
if y is not None:
|
| 338 |
+
y = y.to(dtype)
|
| 339 |
+
timestep = self.model_sampling_current.timestep(t)
|
| 340 |
+
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
|
| 341 |
+
|
| 342 |
+
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
|
| 343 |
+
return self.control_merge(None, control, control_prev, output_dtype)
|
| 344 |
+
|
| 345 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
| 346 |
+
super().pre_run_advanced(model, percent_to_timestep_function)
|
| 347 |
+
if type(self.cond_hint_original) == PreprocSparseRGBWrapper:
|
| 348 |
+
if not self.control_model.use_simplified_conditioning_embedding:
|
| 349 |
+
raise ValueError("Any model besides RGB SparseCtrl should NOT have its images go through the RGB SparseCtrl preprocessor.")
|
| 350 |
+
self.cond_hint_original = self.cond_hint_original.condhint
|
| 351 |
+
self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint
|
| 352 |
+
if self.control_model.motion_wrapper is not None:
|
| 353 |
+
self.control_model.motion_wrapper.reset()
|
| 354 |
+
self.control_model.motion_wrapper.set_strength(self.sparse_settings.motion_strength)
|
| 355 |
+
self.control_model.motion_wrapper.set_scale_multiplier(self.sparse_settings.motion_scale)
|
| 356 |
+
|
| 357 |
+
def cleanup_advanced(self):
|
| 358 |
+
super().cleanup_advanced()
|
| 359 |
+
if self.latent_format is not None:
|
| 360 |
+
del self.latent_format
|
| 361 |
+
self.latent_format = None
|
| 362 |
+
|
| 363 |
+
def copy(self):
|
| 364 |
+
c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.device, self.load_device, self.manual_cast_dtype)
|
| 365 |
+
self.copy_to(c)
|
| 366 |
+
self.copy_to_advanced(c)
|
| 367 |
+
return c
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
|
| 371 |
+
# This ControlNet is more of an attention patch than a traditional controlnet
|
| 372 |
+
def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device=None):
|
| 373 |
+
super().__init__(device)
|
| 374 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), require_model=True)
|
| 375 |
+
self.patch_attn1 = patch_attn1.set_control(self)
|
| 376 |
+
self.patch_attn2 = patch_attn2.set_control(self)
|
| 377 |
+
self.latent_dims_div2 = None
|
| 378 |
+
self.latent_dims_div4 = None
|
| 379 |
+
|
| 380 |
+
def patch_model(self, model: ModelPatcher):
|
| 381 |
+
model.set_model_attn1_patch(self.patch_attn1)
|
| 382 |
+
model.set_model_attn2_patch(self.patch_attn2)
|
| 383 |
+
|
| 384 |
+
def set_cond_hint(self, *args, **kwargs):
|
| 385 |
+
to_return = super().set_cond_hint(*args, **kwargs)
|
| 386 |
+
# cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
|
| 387 |
+
self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
|
| 388 |
+
return to_return
|
| 389 |
+
|
| 390 |
+
def pre_run_advanced(self, *args, **kwargs):
|
| 391 |
+
AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
|
| 392 |
+
#logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
|
| 393 |
+
self.patch_attn1.set_control(self)
|
| 394 |
+
self.patch_attn2.set_control(self)
|
| 395 |
+
#logger.warn(f"in pre_run_advanced: {id(self)}")
|
| 396 |
+
|
| 397 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
| 398 |
+
# normal ControlNet stuff
|
| 399 |
+
control_prev = None
|
| 400 |
+
if self.previous_controlnet is not None:
|
| 401 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 402 |
+
|
| 403 |
+
if self.timestep_range is not None:
|
| 404 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
| 405 |
+
return control_prev
|
| 406 |
+
|
| 407 |
+
dtype = x_noisy.dtype
|
| 408 |
+
# prepare cond_hint
|
| 409 |
+
if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
|
| 410 |
+
if self.cond_hint is not None:
|
| 411 |
+
del self.cond_hint
|
| 412 |
+
self.cond_hint = None
|
| 413 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
| 414 |
+
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
|
| 415 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 416 |
+
else:
|
| 417 |
+
self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
|
| 418 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
| 419 |
+
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
| 420 |
+
# some special logic here compared to other controlnets:
|
| 421 |
+
# * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
|
| 422 |
+
# * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
|
| 423 |
+
divisible_by_2_h = x_noisy.shape[2]%2==0
|
| 424 |
+
divisible_by_2_w = x_noisy.shape[3]%2==0
|
| 425 |
+
if not (divisible_by_2_h and divisible_by_2_w):
|
| 426 |
+
#logger.warn(f"{x_noisy.shape} not divisible by 2!")
|
| 427 |
+
new_h = (x_noisy.shape[2]//2)*2
|
| 428 |
+
new_w = (x_noisy.shape[3]//2)*2
|
| 429 |
+
if not divisible_by_2_h:
|
| 430 |
+
new_h += 2
|
| 431 |
+
if not divisible_by_2_w:
|
| 432 |
+
new_w += 2
|
| 433 |
+
self.latent_dims_div2 = (new_h, new_w)
|
| 434 |
+
divisible_by_4_h = x_noisy.shape[2]%4==0
|
| 435 |
+
divisible_by_4_w = x_noisy.shape[3]%4==0
|
| 436 |
+
if not (divisible_by_4_h and divisible_by_4_w):
|
| 437 |
+
#logger.warn(f"{x_noisy.shape} not divisible by 4!")
|
| 438 |
+
new_h = (x_noisy.shape[2]//4)*4
|
| 439 |
+
new_w = (x_noisy.shape[3]//4)*4
|
| 440 |
+
if not divisible_by_4_h:
|
| 441 |
+
new_h += 4
|
| 442 |
+
if not divisible_by_4_w:
|
| 443 |
+
new_w += 4
|
| 444 |
+
self.latent_dims_div4 = (new_h, new_w)
|
| 445 |
+
# prepare mask
|
| 446 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
|
| 447 |
+
# done preparing; model patches will take care of everything now.
|
| 448 |
+
# return normal controlnet stuff
|
| 449 |
+
return control_prev
|
| 450 |
+
|
| 451 |
+
def cleanup_advanced(self):
|
| 452 |
+
super().cleanup_advanced()
|
| 453 |
+
self.patch_attn1.cleanup()
|
| 454 |
+
self.patch_attn2.cleanup()
|
| 455 |
+
self.latent_dims_div2 = None
|
| 456 |
+
self.latent_dims_div4 = None
|
| 457 |
+
|
| 458 |
+
def copy(self):
|
| 459 |
+
c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes)
|
| 460 |
+
self.copy_to(c)
|
| 461 |
+
self.copy_to_advanced(c)
|
| 462 |
+
return c
|
| 463 |
+
|
| 464 |
+
# deepcopy needs to properly keep track of objects to work between model.clone calls!
|
| 465 |
+
# def __deepcopy__(self, *args, **kwargs):
|
| 466 |
+
# self.cleanup_advanced()
|
| 467 |
+
# return self
|
| 468 |
+
|
| 469 |
+
# def get_models(self):
|
| 470 |
+
# # get_models is called once at the start of every KSampler run - use to reset already_patched status
|
| 471 |
+
# out = super().get_models()
|
| 472 |
+
# logger.error(f"in get_models! {id(self)}")
|
| 473 |
+
# return out
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
|
| 477 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
| 478 |
+
control = None
|
| 479 |
+
# check if a non-vanilla ControlNet
|
| 480 |
+
controlnet_type = ControlWeightType.DEFAULT
|
| 481 |
+
has_controlnet_key = False
|
| 482 |
+
has_motion_modules_key = False
|
| 483 |
+
has_temporal_res_block_key = False
|
| 484 |
+
for key in controlnet_data:
|
| 485 |
+
# LLLite check
|
| 486 |
+
if "lllite" in key:
|
| 487 |
+
controlnet_type = ControlWeightType.CONTROLLLLITE
|
| 488 |
+
break
|
| 489 |
+
# SparseCtrl check
|
| 490 |
+
elif "motion_modules" in key:
|
| 491 |
+
has_motion_modules_key = True
|
| 492 |
+
elif "controlnet" in key:
|
| 493 |
+
has_controlnet_key = True
|
| 494 |
+
# SVD-ControlNet check
|
| 495 |
+
elif "temporal_res_block" in key:
|
| 496 |
+
has_temporal_res_block_key = True
|
| 497 |
+
if has_controlnet_key and has_motion_modules_key:
|
| 498 |
+
controlnet_type = ControlWeightType.SPARSECTRL
|
| 499 |
+
elif has_controlnet_key and has_temporal_res_block_key:
|
| 500 |
+
controlnet_type = ControlWeightType.SVD_CONTROLNET
|
| 501 |
+
|
| 502 |
+
if controlnet_type != ControlWeightType.DEFAULT:
|
| 503 |
+
if controlnet_type == ControlWeightType.CONTROLLLLITE:
|
| 504 |
+
control = load_controllllite(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
|
| 505 |
+
elif controlnet_type == ControlWeightType.SPARSECTRL:
|
| 506 |
+
control = load_sparsectrl(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe, model=model)
|
| 507 |
+
elif controlnet_type == ControlWeightType.SVD_CONTROLNET:
|
| 508 |
+
control = load_svdcontrolnet(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
|
| 509 |
+
#raise Exception(f"SVD-ControlNet is not supported yet!")
|
| 510 |
+
#control = comfy_cn.load_controlnet(ckpt_path, model=model)
|
| 511 |
+
# otherwise, load vanilla ControlNet
|
| 512 |
+
else:
|
| 513 |
+
try:
|
| 514 |
+
# hacky way of getting load_torch_file in load_controlnet to use already-present controlnet_data and not redo loading
|
| 515 |
+
orig_load_torch_file = comfy.utils.load_torch_file
|
| 516 |
+
comfy.utils.load_torch_file = load_torch_file_with_dict_factory(controlnet_data, orig_load_torch_file)
|
| 517 |
+
control = comfy_cn.load_controlnet(ckpt_path, model=model)
|
| 518 |
+
finally:
|
| 519 |
+
comfy.utils.load_torch_file = orig_load_torch_file
|
| 520 |
+
return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
|
| 524 |
+
# if already advanced, leave it be
|
| 525 |
+
if is_advanced_controlnet(control):
|
| 526 |
+
return control
|
| 527 |
+
# if exactly ControlNet returned, transform it into ControlNetAdvanced
|
| 528 |
+
if type(control) == ControlNet:
|
| 529 |
+
return ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
| 530 |
+
# if exactly ControlLora returned, transform it into ControlLoraAdvanced
|
| 531 |
+
elif type(control) == ControlLora:
|
| 532 |
+
return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
| 533 |
+
# if T2IAdapter returned, transform it into T2IAdapterAdvanced
|
| 534 |
+
elif isinstance(control, T2IAdapter):
|
| 535 |
+
return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
|
| 536 |
+
# otherwise, leave it be - might be something I am not supporting yet
|
| 537 |
+
return control
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def is_advanced_controlnet(input_object):
|
| 541 |
+
return hasattr(input_object, "sub_idxs")
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
|
| 545 |
+
if controlnet_data is None:
|
| 546 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
| 547 |
+
# first, separate out motion part from normal controlnet part and attempt to load that portion
|
| 548 |
+
motion_data = {}
|
| 549 |
+
for key in list(controlnet_data.keys()):
|
| 550 |
+
if "temporal" in key:
|
| 551 |
+
motion_data[key] = controlnet_data.pop(key)
|
| 552 |
+
if len(motion_data) == 0:
|
| 553 |
+
raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
|
| 554 |
+
motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data).to(comfy.model_management.unet_dtype())
|
| 555 |
+
missing, unexpected = motion_wrapper.load_state_dict(motion_data)
|
| 556 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
| 557 |
+
logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
|
| 558 |
+
|
| 559 |
+
# now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
|
| 560 |
+
controlnet_config = None
|
| 561 |
+
is_diffusers = False
|
| 562 |
+
use_simplified_conditioning_embedding = False
|
| 563 |
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
|
| 564 |
+
is_diffusers = True
|
| 565 |
+
if "controlnet_cond_embedding.weight" in controlnet_data:
|
| 566 |
+
is_diffusers = True
|
| 567 |
+
use_simplified_conditioning_embedding = True
|
| 568 |
+
if is_diffusers: #diffusers format
|
| 569 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
| 570 |
+
controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
| 571 |
+
diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
|
| 572 |
+
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
| 573 |
+
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
| 574 |
+
|
| 575 |
+
count = 0
|
| 576 |
+
loop = True
|
| 577 |
+
while loop:
|
| 578 |
+
suffix = [".weight", ".bias"]
|
| 579 |
+
for s in suffix:
|
| 580 |
+
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
| 581 |
+
k_out = "zero_convs.{}.0{}".format(count, s)
|
| 582 |
+
if k_in not in controlnet_data:
|
| 583 |
+
loop = False
|
| 584 |
+
break
|
| 585 |
+
diffusers_keys[k_in] = k_out
|
| 586 |
+
count += 1
|
| 587 |
+
# normal conditioning embedding
|
| 588 |
+
if not use_simplified_conditioning_embedding:
|
| 589 |
+
count = 0
|
| 590 |
+
loop = True
|
| 591 |
+
while loop:
|
| 592 |
+
suffix = [".weight", ".bias"]
|
| 593 |
+
for s in suffix:
|
| 594 |
+
if count == 0:
|
| 595 |
+
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
| 596 |
+
else:
|
| 597 |
+
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
| 598 |
+
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
| 599 |
+
if k_in not in controlnet_data:
|
| 600 |
+
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
| 601 |
+
loop = False
|
| 602 |
+
diffusers_keys[k_in] = k_out
|
| 603 |
+
count += 1
|
| 604 |
+
# simplified conditioning embedding
|
| 605 |
+
else:
|
| 606 |
+
count = 0
|
| 607 |
+
suffix = [".weight", ".bias"]
|
| 608 |
+
for s in suffix:
|
| 609 |
+
k_in = "controlnet_cond_embedding{}".format(s)
|
| 610 |
+
k_out = "input_hint_block.{}{}".format(count, s)
|
| 611 |
+
diffusers_keys[k_in] = k_out
|
| 612 |
+
|
| 613 |
+
new_sd = {}
|
| 614 |
+
for k in diffusers_keys:
|
| 615 |
+
if k in controlnet_data:
|
| 616 |
+
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
| 617 |
+
|
| 618 |
+
leftover_keys = controlnet_data.keys()
|
| 619 |
+
if len(leftover_keys) > 0:
|
| 620 |
+
logger.info("leftover keys:", leftover_keys)
|
| 621 |
+
controlnet_data = new_sd
|
| 622 |
+
|
| 623 |
+
pth_key = 'control_model.zero_convs.0.0.weight'
|
| 624 |
+
pth = False
|
| 625 |
+
key = 'zero_convs.0.0.weight'
|
| 626 |
+
if pth_key in controlnet_data:
|
| 627 |
+
pth = True
|
| 628 |
+
key = pth_key
|
| 629 |
+
prefix = "control_model."
|
| 630 |
+
elif key in controlnet_data:
|
| 631 |
+
prefix = ""
|
| 632 |
+
else:
|
| 633 |
+
raise ValueError("The provided model is not a valid SparseCtrl model! [ErrorCode: HORSERADISH]")
|
| 634 |
+
|
| 635 |
+
if controlnet_config is None:
|
| 636 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
| 637 |
+
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
| 638 |
+
load_device = comfy.model_management.get_torch_device()
|
| 639 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
| 640 |
+
if manual_cast_dtype is not None:
|
| 641 |
+
controlnet_config["operations"] = manual_cast_clean_groupnorm
|
| 642 |
+
else:
|
| 643 |
+
controlnet_config["operations"] = disable_weight_init_clean_groupnorm
|
| 644 |
+
controlnet_config.pop("out_channels")
|
| 645 |
+
# get proper hint channels
|
| 646 |
+
if use_simplified_conditioning_embedding:
|
| 647 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
| 648 |
+
controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
|
| 649 |
+
else:
|
| 650 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
| 651 |
+
controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
|
| 652 |
+
control_model = SparseControlNet(**controlnet_config)
|
| 653 |
+
|
| 654 |
+
if pth:
|
| 655 |
+
if 'difference' in controlnet_data:
|
| 656 |
+
if model is not None:
|
| 657 |
+
comfy.model_management.load_models_gpu([model])
|
| 658 |
+
model_sd = model.model_state_dict()
|
| 659 |
+
for x in controlnet_data:
|
| 660 |
+
c_m = "control_model."
|
| 661 |
+
if x.startswith(c_m):
|
| 662 |
+
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
| 663 |
+
if sd_key in model_sd:
|
| 664 |
+
cd = controlnet_data[x]
|
| 665 |
+
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
| 666 |
+
else:
|
| 667 |
+
logger.warning("WARNING: Loaded a diff SparseCtrl without a model. It will very likely not work.")
|
| 668 |
+
|
| 669 |
+
class WeightsLoader(torch.nn.Module):
|
| 670 |
+
pass
|
| 671 |
+
w = WeightsLoader()
|
| 672 |
+
w.control_model = control_model
|
| 673 |
+
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
| 674 |
+
else:
|
| 675 |
+
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
| 676 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
| 677 |
+
logger.info(f"SparseCtrl ControlNet: {missing}, {unexpected}")
|
| 678 |
+
|
| 679 |
+
global_average_pooling = False
|
| 680 |
+
filename = os.path.splitext(ckpt_path)[0]
|
| 681 |
+
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
| 682 |
+
global_average_pooling = True
|
| 683 |
+
|
| 684 |
+
# both motion portion and controlnet portions are loaded; bring them together if using motion model
|
| 685 |
+
if sparse_settings.use_motion:
|
| 686 |
+
motion_wrapper.inject(control_model)
|
| 687 |
+
|
| 688 |
+
control = SparseCtrlAdvanced(control_model, timestep_keyframes=timestep_keyframe, sparse_settings=sparse_settings, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
| 689 |
+
return control
|
| 690 |
+
|
| 691 |
+
|
| 692 |
+
def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
|
| 693 |
+
if controlnet_data is None:
|
| 694 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
| 695 |
+
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
| 696 |
+
# first, split weights for each module
|
| 697 |
+
module_weights = {}
|
| 698 |
+
for key, value in controlnet_data.items():
|
| 699 |
+
fragments = key.split(".")
|
| 700 |
+
module_name = fragments[0]
|
| 701 |
+
weight_name = ".".join(fragments[1:])
|
| 702 |
+
|
| 703 |
+
if module_name not in module_weights:
|
| 704 |
+
module_weights[module_name] = {}
|
| 705 |
+
module_weights[module_name][weight_name] = value
|
| 706 |
+
|
| 707 |
+
# next, load each module
|
| 708 |
+
modules = {}
|
| 709 |
+
for module_name, weights in module_weights.items():
|
| 710 |
+
# kohya planned to do something about how these should be chosen, so I'm not touching this
|
| 711 |
+
# since I am not familiar with the logic for this
|
| 712 |
+
if "conditioning1.4.weight" in weights:
|
| 713 |
+
depth = 3
|
| 714 |
+
elif weights["conditioning1.2.weight"].shape[-1] == 4:
|
| 715 |
+
depth = 2
|
| 716 |
+
else:
|
| 717 |
+
depth = 1
|
| 718 |
+
|
| 719 |
+
module = LLLiteModule(
|
| 720 |
+
name=module_name,
|
| 721 |
+
is_conv2d=weights["down.0.weight"].ndim == 4,
|
| 722 |
+
in_dim=weights["down.0.weight"].shape[1],
|
| 723 |
+
depth=depth,
|
| 724 |
+
cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
|
| 725 |
+
mlp_dim=weights["down.0.weight"].shape[0],
|
| 726 |
+
)
|
| 727 |
+
# load weights into module
|
| 728 |
+
module.load_state_dict(weights)
|
| 729 |
+
modules[module_name] = module
|
| 730 |
+
if len(modules) == 1:
|
| 731 |
+
module.is_first = True
|
| 732 |
+
|
| 733 |
+
#logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
|
| 734 |
+
|
| 735 |
+
patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
|
| 736 |
+
patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
|
| 737 |
+
control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe)
|
| 738 |
+
return control
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
|
| 742 |
+
if controlnet_data is None:
|
| 743 |
+
controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
|
| 744 |
+
|
| 745 |
+
controlnet_config = None
|
| 746 |
+
if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
|
| 747 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
| 748 |
+
controlnet_config = svd_unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
|
| 749 |
+
diffusers_keys = svd_unet_to_diffusers(controlnet_config)
|
| 750 |
+
diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
|
| 751 |
+
diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
|
| 752 |
+
|
| 753 |
+
count = 0
|
| 754 |
+
loop = True
|
| 755 |
+
while loop:
|
| 756 |
+
suffix = [".weight", ".bias"]
|
| 757 |
+
for s in suffix:
|
| 758 |
+
k_in = "controlnet_down_blocks.{}{}".format(count, s)
|
| 759 |
+
k_out = "zero_convs.{}.0{}".format(count, s)
|
| 760 |
+
if k_in not in controlnet_data:
|
| 761 |
+
loop = False
|
| 762 |
+
break
|
| 763 |
+
diffusers_keys[k_in] = k_out
|
| 764 |
+
count += 1
|
| 765 |
+
|
| 766 |
+
count = 0
|
| 767 |
+
loop = True
|
| 768 |
+
while loop:
|
| 769 |
+
suffix = [".weight", ".bias"]
|
| 770 |
+
for s in suffix:
|
| 771 |
+
if count == 0:
|
| 772 |
+
k_in = "controlnet_cond_embedding.conv_in{}".format(s)
|
| 773 |
+
else:
|
| 774 |
+
k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
|
| 775 |
+
k_out = "input_hint_block.{}{}".format(count * 2, s)
|
| 776 |
+
if k_in not in controlnet_data:
|
| 777 |
+
k_in = "controlnet_cond_embedding.conv_out{}".format(s)
|
| 778 |
+
loop = False
|
| 779 |
+
diffusers_keys[k_in] = k_out
|
| 780 |
+
count += 1
|
| 781 |
+
|
| 782 |
+
new_sd = {}
|
| 783 |
+
for k in diffusers_keys:
|
| 784 |
+
if k in controlnet_data:
|
| 785 |
+
new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
|
| 786 |
+
|
| 787 |
+
leftover_keys = controlnet_data.keys()
|
| 788 |
+
if len(leftover_keys) > 0:
|
| 789 |
+
spatial_leftover_keys = []
|
| 790 |
+
temporal_leftover_keys = []
|
| 791 |
+
other_leftover_keys = []
|
| 792 |
+
for key in leftover_keys:
|
| 793 |
+
if "spatial" in key:
|
| 794 |
+
spatial_leftover_keys.append(key)
|
| 795 |
+
elif "temporal" in key:
|
| 796 |
+
temporal_leftover_keys.append(key)
|
| 797 |
+
else:
|
| 798 |
+
other_leftover_keys.append(key)
|
| 799 |
+
logger.warn(f"spatial_leftover_keys ({len(spatial_leftover_keys)}): {spatial_leftover_keys}")
|
| 800 |
+
logger.warn(f"temporal_leftover_keys ({len(temporal_leftover_keys)}): {temporal_leftover_keys}")
|
| 801 |
+
logger.warn(f"other_leftover_keys ({len(other_leftover_keys)}): {other_leftover_keys}")
|
| 802 |
+
#print("leftover keys:", leftover_keys)
|
| 803 |
+
controlnet_data = new_sd
|
| 804 |
+
|
| 805 |
+
pth_key = 'control_model.zero_convs.0.0.weight'
|
| 806 |
+
pth = False
|
| 807 |
+
key = 'zero_convs.0.0.weight'
|
| 808 |
+
if pth_key in controlnet_data:
|
| 809 |
+
pth = True
|
| 810 |
+
key = pth_key
|
| 811 |
+
prefix = "control_model."
|
| 812 |
+
elif key in controlnet_data:
|
| 813 |
+
prefix = ""
|
| 814 |
+
else:
|
| 815 |
+
raise ValueError("The provided model is not a valid SVD-ControlNet model! [ErrorCode: MUSTARD]")
|
| 816 |
+
|
| 817 |
+
if controlnet_config is None:
|
| 818 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
| 819 |
+
controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
|
| 820 |
+
load_device = comfy.model_management.get_torch_device()
|
| 821 |
+
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
| 822 |
+
if manual_cast_dtype is not None:
|
| 823 |
+
controlnet_config["operations"] = comfy.ops.manual_cast
|
| 824 |
+
controlnet_config.pop("out_channels")
|
| 825 |
+
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
| 826 |
+
control_model = SVDControlNet(**controlnet_config)
|
| 827 |
+
|
| 828 |
+
if pth:
|
| 829 |
+
if 'difference' in controlnet_data:
|
| 830 |
+
if model is not None:
|
| 831 |
+
comfy.model_management.load_models_gpu([model])
|
| 832 |
+
model_sd = model.model_state_dict()
|
| 833 |
+
for x in controlnet_data:
|
| 834 |
+
c_m = "control_model."
|
| 835 |
+
if x.startswith(c_m):
|
| 836 |
+
sd_key = "diffusion_model.{}".format(x[len(c_m):])
|
| 837 |
+
if sd_key in model_sd:
|
| 838 |
+
cd = controlnet_data[x]
|
| 839 |
+
cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
|
| 840 |
+
else:
|
| 841 |
+
print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
|
| 842 |
+
|
| 843 |
+
class WeightsLoader(torch.nn.Module):
|
| 844 |
+
pass
|
| 845 |
+
w = WeightsLoader()
|
| 846 |
+
w.control_model = control_model
|
| 847 |
+
missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
|
| 848 |
+
else:
|
| 849 |
+
missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
|
| 850 |
+
if len(missing) > 0 or len(unexpected) > 0:
|
| 851 |
+
logger.info(f"SVD-ControlNet: {missing}, {unexpected}")
|
| 852 |
+
|
| 853 |
+
global_average_pooling = False
|
| 854 |
+
filename = os.path.splitext(ckpt_path)[0]
|
| 855 |
+
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
|
| 856 |
+
global_average_pooling = True
|
| 857 |
+
|
| 858 |
+
control = SVDControlNetAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
|
| 859 |
+
return control
|
| 860 |
+
|
ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
|
| 2 |
+
# basically, all the LLLite core code is from there, which I then combined with
|
| 3 |
+
# Advanced-ControlNet features and QoL
|
| 4 |
+
import math
|
| 5 |
+
from typing import Union
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import torch
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
import comfy.utils
|
| 11 |
+
from comfy.controlnet import ControlBase
|
| 12 |
+
|
| 13 |
+
from .logger import logger
|
| 14 |
+
from .utils import AdvancedControlBase, deepcopy_with_sharing, prepare_mask_batch
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def extra_options_to_module_prefix(extra_options):
|
| 18 |
+
# extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
|
| 19 |
+
|
| 20 |
+
# block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
|
| 21 |
+
# ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
|
| 22 |
+
# transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
|
| 23 |
+
# block_index is: 0-1 or 0-9, depends on the block
|
| 24 |
+
# input 7 and 8, middle has 10 blocks
|
| 25 |
+
|
| 26 |
+
# make module name from extra_options
|
| 27 |
+
block = extra_options["block"]
|
| 28 |
+
block_index = extra_options["block_index"]
|
| 29 |
+
if block[0] == "input":
|
| 30 |
+
module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
| 31 |
+
elif block[0] == "middle":
|
| 32 |
+
module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
|
| 33 |
+
elif block[0] == "output":
|
| 34 |
+
module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
|
| 35 |
+
else:
|
| 36 |
+
raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
|
| 37 |
+
return module_pfx
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class LLLitePatch:
|
| 41 |
+
ATTN1 = "attn1"
|
| 42 |
+
ATTN2 = "attn2"
|
| 43 |
+
def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None):
|
| 44 |
+
self.modules = modules
|
| 45 |
+
self.control = control
|
| 46 |
+
self.patch_type = patch_type
|
| 47 |
+
#logger.error(f"create LLLitePatch: {id(self)},{control}")
|
| 48 |
+
|
| 49 |
+
def __call__(self, q, k, v, extra_options):
|
| 50 |
+
#logger.error(f"in __call__: {id(self)}")
|
| 51 |
+
# determine if have anything to run
|
| 52 |
+
if self.control.timestep_range is not None:
|
| 53 |
+
# it turns out comparing single-value tensors to floats is extremely slow
|
| 54 |
+
# a: Tensor = extra_options["sigmas"][0]
|
| 55 |
+
if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
|
| 56 |
+
return q, k, v
|
| 57 |
+
|
| 58 |
+
module_pfx = extra_options_to_module_prefix(extra_options)
|
| 59 |
+
|
| 60 |
+
is_attn1 = q.shape[-1] == k.shape[-1] # self attention
|
| 61 |
+
if is_attn1:
|
| 62 |
+
module_pfx = module_pfx + "_attn1"
|
| 63 |
+
else:
|
| 64 |
+
module_pfx = module_pfx + "_attn2"
|
| 65 |
+
|
| 66 |
+
module_pfx_to_q = module_pfx + "_to_q"
|
| 67 |
+
module_pfx_to_k = module_pfx + "_to_k"
|
| 68 |
+
module_pfx_to_v = module_pfx + "_to_v"
|
| 69 |
+
|
| 70 |
+
if module_pfx_to_q in self.modules:
|
| 71 |
+
q = q + self.modules[module_pfx_to_q](q, self.control)
|
| 72 |
+
if module_pfx_to_k in self.modules:
|
| 73 |
+
k = k + self.modules[module_pfx_to_k](k, self.control)
|
| 74 |
+
if module_pfx_to_v in self.modules:
|
| 75 |
+
v = v + self.modules[module_pfx_to_v](v, self.control)
|
| 76 |
+
|
| 77 |
+
return q, k, v
|
| 78 |
+
|
| 79 |
+
def to(self, device):
|
| 80 |
+
#logger.info(f"to... has control? {self.control}")
|
| 81 |
+
for d in self.modules.keys():
|
| 82 |
+
self.modules[d] = self.modules[d].to(device)
|
| 83 |
+
return self
|
| 84 |
+
|
| 85 |
+
def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
|
| 86 |
+
self.control = control
|
| 87 |
+
return self
|
| 88 |
+
#logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}")
|
| 89 |
+
|
| 90 |
+
def clone_with_control(self, control: AdvancedControlBase):
|
| 91 |
+
#logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}")
|
| 92 |
+
return LLLitePatch(self.modules, self.patch_type, control)
|
| 93 |
+
|
| 94 |
+
def cleanup(self):
|
| 95 |
+
#total_cleaned = 0
|
| 96 |
+
for module in self.modules.values():
|
| 97 |
+
module.cleanup()
|
| 98 |
+
# total_cleaned += 1
|
| 99 |
+
#logger.info(f"cleaned modules: {total_cleaned}, {id(self)}")
|
| 100 |
+
#logger.error(f"cleanup LLLitePatch: {id(self)}")
|
| 101 |
+
|
| 102 |
+
# make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
|
| 103 |
+
def __deepcopy__(self, memo):
|
| 104 |
+
self.cleanup()
|
| 105 |
+
to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
|
| 106 |
+
#logger.warn(f"patch {id(self)} turned into {id(to_return)}")
|
| 107 |
+
try:
|
| 108 |
+
if self.patch_type == self.ATTN1:
|
| 109 |
+
to_return.control.patch_attn1 = to_return
|
| 110 |
+
elif self.patch_type == self.ATTN2:
|
| 111 |
+
to_return.control.patch_attn2 = to_return
|
| 112 |
+
except Exception:
|
| 113 |
+
pass
|
| 114 |
+
return to_return
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# TODO: use comfy.ops to support fp8 properly
|
| 118 |
+
class LLLiteModule(torch.nn.Module):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
name: str,
|
| 122 |
+
is_conv2d: bool,
|
| 123 |
+
in_dim: int,
|
| 124 |
+
depth: int,
|
| 125 |
+
cond_emb_dim: int,
|
| 126 |
+
mlp_dim: int,
|
| 127 |
+
):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.name = name
|
| 130 |
+
self.is_conv2d = is_conv2d
|
| 131 |
+
self.is_first = False
|
| 132 |
+
|
| 133 |
+
modules = []
|
| 134 |
+
modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
|
| 135 |
+
if depth == 1:
|
| 136 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 137 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 138 |
+
elif depth == 2:
|
| 139 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 140 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
|
| 141 |
+
elif depth == 3:
|
| 142 |
+
# kernel size 8 is too large, so set it to 4
|
| 143 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 144 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
|
| 145 |
+
modules.append(torch.nn.ReLU(inplace=True))
|
| 146 |
+
modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
|
| 147 |
+
|
| 148 |
+
self.conditioning1 = torch.nn.Sequential(*modules)
|
| 149 |
+
|
| 150 |
+
if self.is_conv2d:
|
| 151 |
+
self.down = torch.nn.Sequential(
|
| 152 |
+
torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
| 153 |
+
torch.nn.ReLU(inplace=True),
|
| 154 |
+
)
|
| 155 |
+
self.mid = torch.nn.Sequential(
|
| 156 |
+
torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
|
| 157 |
+
torch.nn.ReLU(inplace=True),
|
| 158 |
+
)
|
| 159 |
+
self.up = torch.nn.Sequential(
|
| 160 |
+
torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
self.down = torch.nn.Sequential(
|
| 164 |
+
torch.nn.Linear(in_dim, mlp_dim),
|
| 165 |
+
torch.nn.ReLU(inplace=True),
|
| 166 |
+
)
|
| 167 |
+
self.mid = torch.nn.Sequential(
|
| 168 |
+
torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
|
| 169 |
+
torch.nn.ReLU(inplace=True),
|
| 170 |
+
)
|
| 171 |
+
self.up = torch.nn.Sequential(
|
| 172 |
+
torch.nn.Linear(mlp_dim, in_dim),
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
self.depth = depth
|
| 176 |
+
self.cond_emb = None
|
| 177 |
+
self.cx_shape = None
|
| 178 |
+
self.prev_batch = 0
|
| 179 |
+
self.prev_sub_idxs = None
|
| 180 |
+
|
| 181 |
+
def cleanup(self):
|
| 182 |
+
del self.cond_emb
|
| 183 |
+
self.cond_emb = None
|
| 184 |
+
self.cx_shape = None
|
| 185 |
+
self.prev_batch = 0
|
| 186 |
+
self.prev_sub_idxs = None
|
| 187 |
+
|
| 188 |
+
def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
|
| 189 |
+
mask = None
|
| 190 |
+
mask_tk = None
|
| 191 |
+
#logger.info(x.shape)
|
| 192 |
+
if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
|
| 193 |
+
# print(f"cond_emb is None, {self.name}")
|
| 194 |
+
cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
|
| 195 |
+
if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
|
| 196 |
+
cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
|
| 197 |
+
elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
|
| 198 |
+
cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
|
| 199 |
+
cx = self.conditioning1(cond_hint)
|
| 200 |
+
self.cx_shape = cx.shape
|
| 201 |
+
if not self.is_conv2d:
|
| 202 |
+
# reshape / b,c,h,w -> b,h*w,c
|
| 203 |
+
n, c, h, w = cx.shape
|
| 204 |
+
cx = cx.view(n, c, h * w).permute(0, 2, 1)
|
| 205 |
+
self.cond_emb = cx
|
| 206 |
+
# save prev values
|
| 207 |
+
self.prev_batch = x.shape[0]
|
| 208 |
+
self.prev_sub_idxs = control.sub_idxs
|
| 209 |
+
|
| 210 |
+
cx: torch.Tensor = self.cond_emb
|
| 211 |
+
# print(f"forward {self.name}, {cx.shape}, {x.shape}")
|
| 212 |
+
|
| 213 |
+
# TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
|
| 214 |
+
# create masks
|
| 215 |
+
if not self.is_conv2d:
|
| 216 |
+
n, c, h, w = self.cx_shape
|
| 217 |
+
if control.mask_cond_hint is not None:
|
| 218 |
+
mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
|
| 219 |
+
mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
|
| 220 |
+
if control.tk_mask_cond_hint is not None:
|
| 221 |
+
mask_tk = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
|
| 222 |
+
mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
|
| 223 |
+
|
| 224 |
+
# x in uncond/cond doubles batch size
|
| 225 |
+
if x.shape[0] != cx.shape[0]:
|
| 226 |
+
if self.is_conv2d:
|
| 227 |
+
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
|
| 228 |
+
else:
|
| 229 |
+
# print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
|
| 230 |
+
cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
|
| 231 |
+
if mask is not None:
|
| 232 |
+
mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
|
| 233 |
+
if mask_tk is not None:
|
| 234 |
+
mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)
|
| 235 |
+
|
| 236 |
+
if mask is None:
|
| 237 |
+
mask = 1.0
|
| 238 |
+
elif mask_tk is not None:
|
| 239 |
+
mask = mask * mask_tk
|
| 240 |
+
|
| 241 |
+
#logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}")
|
| 242 |
+
cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
|
| 243 |
+
cx = self.mid(cx)
|
| 244 |
+
cx = self.up(cx)
|
| 245 |
+
if control.latent_keyframes is not None:
|
| 246 |
+
cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
|
| 247 |
+
if control.weights is not None and control.weights.has_uncond_multiplier:
|
| 248 |
+
cond_or_uncond = control.batched_number.cond_or_uncond
|
| 249 |
+
actual_length = cx.size(0) // control.batched_number
|
| 250 |
+
for idx, cond_type in enumerate(cond_or_uncond):
|
| 251 |
+
# if uncond, set to weight's uncond_multiplier
|
| 252 |
+
if cond_type == 1:
|
| 253 |
+
cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
|
| 254 |
+
return cx * mask * control.strength * control._current_timestep_keyframe.strength
|
ComfyUI-Advanced-ControlNet/adv_control/control_reference.py
ADDED
|
@@ -0,0 +1,833 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Union
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
|
| 7 |
+
import comfy.sample
|
| 8 |
+
import comfy.model_patcher
|
| 9 |
+
import comfy.utils
|
| 10 |
+
from comfy.controlnet import ControlBase
|
| 11 |
+
from comfy.model_patcher import ModelPatcher
|
| 12 |
+
from comfy.ldm.modules.attention import BasicTransformerBlock
|
| 13 |
+
from comfy.ldm.modules.diffusionmodules import openaimodel
|
| 14 |
+
|
| 15 |
+
from .logger import logger
|
| 16 |
+
from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, AbstractPreprocWrapper,
|
| 17 |
+
deepcopy_with_sharing, prepare_mask_batch, broadcast_image_to_full)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def refcn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
|
| 21 |
+
def get_refcn(control: ControlBase, order: int=-1):
|
| 22 |
+
ref_set: set[ReferenceAdvanced] = set()
|
| 23 |
+
if control is None:
|
| 24 |
+
return ref_set
|
| 25 |
+
if type(control) == ReferenceAdvanced:
|
| 26 |
+
control.order = order
|
| 27 |
+
order -= 1
|
| 28 |
+
ref_set.add(control)
|
| 29 |
+
ref_set.update(get_refcn(control.previous_controlnet, order=order))
|
| 30 |
+
return ref_set
|
| 31 |
+
|
| 32 |
+
def refcn_sample(model: ModelPatcher, *args, **kwargs):
|
| 33 |
+
# check if positive or negative conds contain ref cn
|
| 34 |
+
positive = args[-3]
|
| 35 |
+
negative = args[-2]
|
| 36 |
+
ref_set = set()
|
| 37 |
+
if positive is not None:
|
| 38 |
+
for cond in positive:
|
| 39 |
+
if "control" in cond[1]:
|
| 40 |
+
ref_set.update(get_refcn(cond[1]["control"]))
|
| 41 |
+
if negative is not None:
|
| 42 |
+
for cond in negative:
|
| 43 |
+
if "control" in cond[1]:
|
| 44 |
+
ref_set.update(get_refcn(cond[1]["control"]))
|
| 45 |
+
# if no ref cn found, do original function immediately
|
| 46 |
+
if len(ref_set) == 0:
|
| 47 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
| 48 |
+
# otherwise, injection time
|
| 49 |
+
try:
|
| 50 |
+
# inject
|
| 51 |
+
# storage for all Reference-related injections
|
| 52 |
+
reference_injections = ReferenceInjections()
|
| 53 |
+
|
| 54 |
+
# first, handle attn module injection
|
| 55 |
+
all_modules = torch_dfs(model.model)
|
| 56 |
+
attn_modules: list[RefBasicTransformerBlock] = []
|
| 57 |
+
for module in all_modules:
|
| 58 |
+
if isinstance(module, BasicTransformerBlock):
|
| 59 |
+
attn_modules.append(module)
|
| 60 |
+
attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
|
| 61 |
+
attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
|
| 62 |
+
for i, module in enumerate(attn_modules):
|
| 63 |
+
injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
|
| 64 |
+
injection_holder.attn_weight = float(i) / float(len(attn_modules))
|
| 65 |
+
if hasattr(module, "_forward"): # backward compatibility
|
| 66 |
+
module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
|
| 67 |
+
else:
|
| 68 |
+
module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
|
| 69 |
+
module.injection_holder = injection_holder
|
| 70 |
+
reference_injections.attn_modules.append(module)
|
| 71 |
+
# figure out which module is middle block
|
| 72 |
+
if hasattr(model.model.diffusion_model, "middle_block"):
|
| 73 |
+
mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
|
| 74 |
+
mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
|
| 75 |
+
for module in mid_attn_modules:
|
| 76 |
+
module.injection_holder.is_middle = True
|
| 77 |
+
|
| 78 |
+
# next, handle gn module injection (TimestepEmbedSequential)
|
| 79 |
+
# TODO: figure out the logic behind these hardcoded indexes
|
| 80 |
+
if type(model.model).__name__ == "SDXL":
|
| 81 |
+
input_block_indices = [4, 5, 7, 8]
|
| 82 |
+
output_block_indices = [0, 1, 2, 3, 4, 5]
|
| 83 |
+
else:
|
| 84 |
+
input_block_indices = [4, 5, 7, 8, 10, 11]
|
| 85 |
+
output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
|
| 86 |
+
if hasattr(model.model.diffusion_model, "middle_block"):
|
| 87 |
+
module = model.model.diffusion_model.middle_block
|
| 88 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
|
| 89 |
+
injection_holder.gn_weight = 0.0
|
| 90 |
+
module.injection_holder = injection_holder
|
| 91 |
+
reference_injections.gn_modules.append(module)
|
| 92 |
+
for w, i in enumerate(input_block_indices):
|
| 93 |
+
module = model.model.diffusion_model.input_blocks[i]
|
| 94 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
|
| 95 |
+
injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
|
| 96 |
+
module.injection_holder = injection_holder
|
| 97 |
+
reference_injections.gn_modules.append(module)
|
| 98 |
+
for w, i in enumerate(output_block_indices):
|
| 99 |
+
module = model.model.diffusion_model.output_blocks[i]
|
| 100 |
+
injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
|
| 101 |
+
injection_holder.gn_weight = float(w) / float(len(output_block_indices))
|
| 102 |
+
module.injection_holder = injection_holder
|
| 103 |
+
reference_injections.gn_modules.append(module)
|
| 104 |
+
# hack gn_module forwards and update weights
|
| 105 |
+
for i, module in enumerate(reference_injections.gn_modules):
|
| 106 |
+
module.injection_holder.gn_weight *= 2
|
| 107 |
+
|
| 108 |
+
# handle diffusion_model forward injection
|
| 109 |
+
reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
|
| 110 |
+
model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
|
| 111 |
+
# store ordered ref cns in model's transformer options
|
| 112 |
+
orig_model_options = model.model_options
|
| 113 |
+
new_model_options = model.model_options.copy()
|
| 114 |
+
new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
| 115 |
+
ref_list: list[ReferenceAdvanced] = list(ref_set)
|
| 116 |
+
new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
|
| 117 |
+
model.model_options = new_model_options
|
| 118 |
+
# continue with original function
|
| 119 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
| 120 |
+
finally:
|
| 121 |
+
# cleanup injections
|
| 122 |
+
# restore attn modules
|
| 123 |
+
attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
|
| 124 |
+
for module in attn_modules:
|
| 125 |
+
module.injection_holder.restore(module)
|
| 126 |
+
module.injection_holder.clean()
|
| 127 |
+
del module.injection_holder
|
| 128 |
+
del attn_modules
|
| 129 |
+
# restore gn modules
|
| 130 |
+
gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
|
| 131 |
+
for module in gn_modules:
|
| 132 |
+
module.injection_holder.restore(module)
|
| 133 |
+
module.injection_holder.clean()
|
| 134 |
+
del module.injection_holder
|
| 135 |
+
del gn_modules
|
| 136 |
+
# restore diffusion_model forward function
|
| 137 |
+
model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
|
| 138 |
+
# restore model_options
|
| 139 |
+
model.model_options = orig_model_options
|
| 140 |
+
# cleanup
|
| 141 |
+
reference_injections.cleanup()
|
| 142 |
+
return refcn_sample
|
| 143 |
+
# inject sample functions
|
| 144 |
+
comfy.sample.sample = refcn_sample_factory(comfy.sample.sample)
|
| 145 |
+
comfy.sample.sample_custom = refcn_sample_factory(comfy.sample.sample_custom, is_custom=True)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
|
| 149 |
+
REF_ADAIN_CONTROL_LIST = "ref_adain_control_list"
|
| 150 |
+
REF_CONTROL_LIST_ALL = "ref_control_list_all"
|
| 151 |
+
REF_CONTROL_INFO = "ref_control_info"
|
| 152 |
+
REF_ATTN_MACHINE_STATE = "ref_attn_machine_state"
|
| 153 |
+
REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state"
|
| 154 |
+
REF_COND_IDXS = "ref_cond_idxs"
|
| 155 |
+
REF_UNCOND_IDXS = "ref_uncond_idxs"
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class MachineState:
|
| 159 |
+
WRITE = "write"
|
| 160 |
+
READ = "read"
|
| 161 |
+
STYLEALIGN = "stylealign"
|
| 162 |
+
OFF = "off"
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
class ReferenceType:
|
| 166 |
+
ATTN = "reference_attn"
|
| 167 |
+
ADAIN = "reference_adain"
|
| 168 |
+
ATTN_ADAIN = "reference_attn+adain"
|
| 169 |
+
STYLE_ALIGN = "StyleAlign"
|
| 170 |
+
|
| 171 |
+
_LIST = [ATTN, ADAIN, ATTN_ADAIN]
|
| 172 |
+
_LIST_ATTN = [ATTN, ATTN_ADAIN]
|
| 173 |
+
_LIST_ADAIN = [ADAIN, ATTN_ADAIN]
|
| 174 |
+
|
| 175 |
+
@classmethod
|
| 176 |
+
def is_attn(cls, ref_type: str):
|
| 177 |
+
return ref_type in cls._LIST_ATTN
|
| 178 |
+
|
| 179 |
+
@classmethod
|
| 180 |
+
def is_adain(cls, ref_type: str):
|
| 181 |
+
return ref_type in cls._LIST_ADAIN
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ReferenceOptions:
|
| 185 |
+
def __init__(self, reference_type: str,
|
| 186 |
+
attn_style_fidelity: float, adain_style_fidelity: float,
|
| 187 |
+
attn_ref_weight: float, adain_ref_weight: float,
|
| 188 |
+
attn_strength: float=1.0, adain_strength: float=1.0,
|
| 189 |
+
ref_with_other_cns: bool=False):
|
| 190 |
+
self.reference_type = reference_type
|
| 191 |
+
# attn
|
| 192 |
+
self.original_attn_style_fidelity = attn_style_fidelity
|
| 193 |
+
self.attn_style_fidelity = attn_style_fidelity
|
| 194 |
+
self.attn_ref_weight = attn_ref_weight
|
| 195 |
+
self.attn_strength = attn_strength
|
| 196 |
+
# adain
|
| 197 |
+
self.original_adain_style_fidelity = adain_style_fidelity
|
| 198 |
+
self.adain_style_fidelity = adain_style_fidelity
|
| 199 |
+
self.adain_ref_weight = adain_ref_weight
|
| 200 |
+
self.adain_strength = adain_strength
|
| 201 |
+
# other
|
| 202 |
+
self.ref_with_other_cns = ref_with_other_cns
|
| 203 |
+
|
| 204 |
+
def clone(self):
|
| 205 |
+
return ReferenceOptions(reference_type=self.reference_type,
|
| 206 |
+
attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity,
|
| 207 |
+
attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight,
|
| 208 |
+
attn_strength=self.attn_strength, adain_strength=self.adain_strength,
|
| 209 |
+
ref_with_other_cns=self.ref_with_other_cns)
|
| 210 |
+
|
| 211 |
+
@staticmethod
|
| 212 |
+
def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False):
|
| 213 |
+
return ReferenceOptions(reference_type=reference_type,
|
| 214 |
+
attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity,
|
| 215 |
+
attn_ref_weight=ref_weight, adain_ref_weight=ref_weight,
|
| 216 |
+
ref_with_other_cns=ref_with_other_cns)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class ReferencePreprocWrapper(AbstractPreprocWrapper):
|
| 221 |
+
error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
|
| 222 |
+
def __init__(self, condhint: Tensor):
|
| 223 |
+
super().__init__(condhint)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class ReferenceAdvanced(ControlBase, AdvancedControlBase):
|
| 227 |
+
CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
|
| 228 |
+
|
| 229 |
+
def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
|
| 230 |
+
super().__init__(device)
|
| 231 |
+
AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
|
| 232 |
+
self.ref_opts = ref_opts
|
| 233 |
+
self.order = 0
|
| 234 |
+
self.latent_format = None
|
| 235 |
+
self.model_sampling_current = None
|
| 236 |
+
self.should_apply_attn_effective_strength = False
|
| 237 |
+
self.should_apply_adain_effective_strength = False
|
| 238 |
+
self.should_apply_effective_masks = False
|
| 239 |
+
self.latent_shape = None
|
| 240 |
+
|
| 241 |
+
def any_attn_strength_to_apply(self):
|
| 242 |
+
return self.should_apply_attn_effective_strength or self.should_apply_effective_masks
|
| 243 |
+
|
| 244 |
+
def any_adain_strength_to_apply(self):
|
| 245 |
+
return self.should_apply_adain_effective_strength or self.should_apply_effective_masks
|
| 246 |
+
|
| 247 |
+
def get_effective_strength(self):
|
| 248 |
+
effective_strength = self.strength
|
| 249 |
+
if self._current_timestep_keyframe is not None:
|
| 250 |
+
effective_strength = effective_strength * self._current_timestep_keyframe.strength
|
| 251 |
+
return effective_strength
|
| 252 |
+
|
| 253 |
+
def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
|
| 254 |
+
if not self.should_apply_effective_masks:
|
| 255 |
+
return self.get_effective_strength() * self.ref_opts.attn_strength
|
| 256 |
+
if is_mid:
|
| 257 |
+
div = 8
|
| 258 |
+
else:
|
| 259 |
+
div = self.CHANNEL_TO_MULT[channels]
|
| 260 |
+
real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
|
| 261 |
+
self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
|
| 262 |
+
# mask is now shape [b, 1, h ,w]; need to turn into [b, h*w, 1]
|
| 263 |
+
b, c, h, w = real_mask.shape
|
| 264 |
+
real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
|
| 265 |
+
return real_mask
|
| 266 |
+
|
| 267 |
+
def get_effective_adain_mask_or_float(self, x: Tensor):
|
| 268 |
+
if not self.should_apply_effective_masks:
|
| 269 |
+
return self.get_effective_strength() * self.ref_opts.adain_strength
|
| 270 |
+
b, c, h, w = x.shape
|
| 271 |
+
real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
|
| 272 |
+
self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
|
| 273 |
+
return real_mask
|
| 274 |
+
|
| 275 |
+
def should_run(self):
|
| 276 |
+
running = super().should_run()
|
| 277 |
+
if not running:
|
| 278 |
+
return running
|
| 279 |
+
attn_run = False
|
| 280 |
+
adain_run = False
|
| 281 |
+
if ReferenceType.is_attn(self.ref_opts.reference_type):
|
| 282 |
+
# attn will run as long as neither weight or strength is zero
|
| 283 |
+
attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
|
| 284 |
+
if ReferenceType.is_adain(self.ref_opts.reference_type):
|
| 285 |
+
# adain will run as long as neither weight or strength is zero
|
| 286 |
+
adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
|
| 287 |
+
return attn_run or adain_run
|
| 288 |
+
|
| 289 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
| 290 |
+
AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
|
| 291 |
+
if type(self.cond_hint_original) == ReferencePreprocWrapper:
|
| 292 |
+
self.cond_hint_original = self.cond_hint_original.condhint
|
| 293 |
+
self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
|
| 294 |
+
self.model_sampling_current = model.model_sampling
|
| 295 |
+
# SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments
|
| 296 |
+
if type(model).__name__ == "SDXL":
|
| 297 |
+
self.ref_opts.attn_style_fidelity = self.ref_opts.original_attn_style_fidelity ** 3.0
|
| 298 |
+
self.ref_opts.adain_style_fidelity = self.ref_opts.original_adain_style_fidelity ** 3.0
|
| 299 |
+
else:
|
| 300 |
+
self.ref_opts.attn_style_fidelity = self.ref_opts.original_attn_style_fidelity
|
| 301 |
+
self.ref_opts.adain_style_fidelity = self.ref_opts.original_adain_style_fidelity
|
| 302 |
+
|
| 303 |
+
def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
|
| 304 |
+
# normal ControlNet stuff
|
| 305 |
+
control_prev = None
|
| 306 |
+
if self.previous_controlnet is not None:
|
| 307 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 308 |
+
|
| 309 |
+
if self.timestep_range is not None:
|
| 310 |
+
if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
|
| 311 |
+
return control_prev
|
| 312 |
+
|
| 313 |
+
dtype = x_noisy.dtype
|
| 314 |
+
# prepare cond_hint - it is a latent, NOT an image
|
| 315 |
+
#if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
|
| 316 |
+
if self.cond_hint is not None:
|
| 317 |
+
del self.cond_hint
|
| 318 |
+
self.cond_hint = None
|
| 319 |
+
# if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
|
| 320 |
+
if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
|
| 321 |
+
self.cond_hint = comfy.utils.common_upscale(
|
| 322 |
+
self.cond_hint_original[self.sub_idxs],
|
| 323 |
+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
|
| 324 |
+
else:
|
| 325 |
+
self.cond_hint = comfy.utils.common_upscale(
|
| 326 |
+
self.cond_hint_original,
|
| 327 |
+
x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
|
| 328 |
+
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
| 329 |
+
self.cond_hint = broadcast_image_to_full(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
|
| 330 |
+
# noise cond_hint based on sigma (current step)
|
| 331 |
+
self.cond_hint = self.latent_format.process_in(self.cond_hint)
|
| 332 |
+
self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
|
| 333 |
+
timestep = self.model_sampling_current.timestep(t)
|
| 334 |
+
self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
|
| 335 |
+
self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0))
|
| 336 |
+
# prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller)
|
| 337 |
+
self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True)
|
| 338 |
+
self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None
|
| 339 |
+
self.latent_shape = list(x_noisy.shape)
|
| 340 |
+
# done preparing; model patches will take care of everything now.
|
| 341 |
+
# return normal controlnet stuff
|
| 342 |
+
return control_prev
|
| 343 |
+
|
| 344 |
+
def cleanup_advanced(self):
|
| 345 |
+
super().cleanup_advanced()
|
| 346 |
+
del self.latent_format
|
| 347 |
+
self.latent_format = None
|
| 348 |
+
del self.model_sampling_current
|
| 349 |
+
self.model_sampling_current = None
|
| 350 |
+
self.should_apply_attn_effective_strength = False
|
| 351 |
+
self.should_apply_adain_effective_strength = False
|
| 352 |
+
self.should_apply_effective_masks = False
|
| 353 |
+
|
| 354 |
+
def copy(self):
|
| 355 |
+
c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes)
|
| 356 |
+
c.order = self.order
|
| 357 |
+
self.copy_to(c)
|
| 358 |
+
self.copy_to_advanced(c)
|
| 359 |
+
return c
|
| 360 |
+
|
| 361 |
+
# avoid deepcopy shenanigans by making deepcopy not do anything to the reference
|
| 362 |
+
# TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets
|
| 363 |
+
def __deepcopy__(self, memo):
|
| 364 |
+
return self
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None):
|
| 368 |
+
sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
|
| 369 |
+
alpha_cumprod = 1 / ((sigma * sigma) + 1)
|
| 370 |
+
sqrt_alpha_prod = alpha_cumprod ** 0.5
|
| 371 |
+
sqrt_one_minus_alpha_prod = (1. - alpha_cumprod) ** 0.5
|
| 372 |
+
if noise is None:
|
| 373 |
+
# generator = torch.Generator(device="cuda")
|
| 374 |
+
# generator.manual_seed(0)
|
| 375 |
+
# noise = torch.empty_like(latents).normal_(generator=generator)
|
| 376 |
+
# generator = torch.Generator()
|
| 377 |
+
# generator.manual_seed(0)
|
| 378 |
+
# noise = torch.randn(latents.size(), generator=generator).to(latents.device)
|
| 379 |
+
noise = torch.randn_like(latents).to(latents.device)
|
| 380 |
+
return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None):
|
| 384 |
+
if noise is None:
|
| 385 |
+
noise = torch.rand_like(latents)
|
| 386 |
+
return latents + noise * sigma
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
class BankStylesBasicTransformerBlock:
|
| 390 |
+
def __init__(self):
|
| 391 |
+
self.bank = []
|
| 392 |
+
self.style_cfgs = []
|
| 393 |
+
self.cn_idx: list[int] = []
|
| 394 |
+
|
| 395 |
+
def get_avg_style_fidelity(self):
|
| 396 |
+
return sum(self.style_cfgs) / float(len(self.style_cfgs))
|
| 397 |
+
|
| 398 |
+
def clean(self):
|
| 399 |
+
del self.bank
|
| 400 |
+
self.bank = []
|
| 401 |
+
del self.style_cfgs
|
| 402 |
+
self.style_cfgs = []
|
| 403 |
+
del self.cn_idx
|
| 404 |
+
self.cn_idx = []
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class BankStylesTimestepEmbedSequential:
|
| 408 |
+
def __init__(self):
|
| 409 |
+
self.var_bank = []
|
| 410 |
+
self.mean_bank = []
|
| 411 |
+
self.style_cfgs = []
|
| 412 |
+
self.cn_idx: list[int] = []
|
| 413 |
+
|
| 414 |
+
def get_avg_var_bank(self):
|
| 415 |
+
return sum(self.var_bank) / float(len(self.var_bank))
|
| 416 |
+
|
| 417 |
+
def get_avg_mean_bank(self):
|
| 418 |
+
return sum(self.mean_bank) / float(len(self.mean_bank))
|
| 419 |
+
|
| 420 |
+
def get_avg_style_fidelity(self):
|
| 421 |
+
return sum(self.style_cfgs) / float(len(self.style_cfgs))
|
| 422 |
+
|
| 423 |
+
def clean(self):
|
| 424 |
+
del self.mean_bank
|
| 425 |
+
self.mean_bank = []
|
| 426 |
+
del self.var_bank
|
| 427 |
+
self.var_bank = []
|
| 428 |
+
del self.style_cfgs
|
| 429 |
+
self.style_cfgs = []
|
| 430 |
+
del self.cn_idx
|
| 431 |
+
self.cn_idx = []
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class InjectionBasicTransformerBlockHolder:
|
| 435 |
+
def __init__(self, block: BasicTransformerBlock, idx=None):
|
| 436 |
+
if hasattr(block, "_forward"): # backward compatibility
|
| 437 |
+
self.original_forward = block._forward
|
| 438 |
+
else:
|
| 439 |
+
self.original_forward = block.forward
|
| 440 |
+
self.idx = idx
|
| 441 |
+
self.attn_weight = 1.0
|
| 442 |
+
self.is_middle = False
|
| 443 |
+
self.bank_styles = BankStylesBasicTransformerBlock()
|
| 444 |
+
|
| 445 |
+
def restore(self, block: BasicTransformerBlock):
|
| 446 |
+
if hasattr(block, "_forward"): # backward compatibility
|
| 447 |
+
block._forward = self.original_forward
|
| 448 |
+
else:
|
| 449 |
+
block.forward = self.original_forward
|
| 450 |
+
|
| 451 |
+
def clean(self):
|
| 452 |
+
self.bank_styles.clean()
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class InjectionTimestepEmbedSequentialHolder:
|
| 456 |
+
def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False):
|
| 457 |
+
self.original_forward = block.forward
|
| 458 |
+
self.idx = idx
|
| 459 |
+
self.gn_weight = 1.0
|
| 460 |
+
self.is_middle = is_middle
|
| 461 |
+
self.is_input = is_input
|
| 462 |
+
self.is_output = is_output
|
| 463 |
+
self.bank_styles = BankStylesTimestepEmbedSequential()
|
| 464 |
+
|
| 465 |
+
def restore(self, block: openaimodel.TimestepEmbedSequential):
|
| 466 |
+
block.forward = self.original_forward
|
| 467 |
+
|
| 468 |
+
def clean(self):
|
| 469 |
+
self.bank_styles.clean()
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
class ReferenceInjections:
|
| 473 |
+
def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None):
|
| 474 |
+
self.attn_modules = attn_modules if attn_modules else []
|
| 475 |
+
self.gn_modules = gn_modules if gn_modules else []
|
| 476 |
+
self.diffusion_model_orig_forward: Callable = None
|
| 477 |
+
|
| 478 |
+
def clean_module_mem(self):
|
| 479 |
+
for attn_module in self.attn_modules:
|
| 480 |
+
try:
|
| 481 |
+
attn_module.injection_holder.clean()
|
| 482 |
+
except Exception:
|
| 483 |
+
pass
|
| 484 |
+
for gn_module in self.gn_modules:
|
| 485 |
+
try:
|
| 486 |
+
gn_module.injection_holder.clean()
|
| 487 |
+
except Exception:
|
| 488 |
+
pass
|
| 489 |
+
|
| 490 |
+
def cleanup(self):
|
| 491 |
+
self.clean_module_mem()
|
| 492 |
+
del self.attn_modules
|
| 493 |
+
self.attn_modules = []
|
| 494 |
+
del self.gn_modules
|
| 495 |
+
self.gn_modules = []
|
| 496 |
+
self.diffusion_model_orig_forward = None
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections):
|
| 500 |
+
def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs):
|
| 501 |
+
# get control and transformer_options from kwargs
|
| 502 |
+
real_args = list(args)
|
| 503 |
+
real_kwargs = list(kwargs.keys())
|
| 504 |
+
control = kwargs.get("control", None)
|
| 505 |
+
transformer_options = kwargs.get("transformer_options", None)
|
| 506 |
+
# look for ReferenceAttnPatch objects to get ReferenceAdvanced objects
|
| 507 |
+
ref_controlnets: list[ReferenceAdvanced] = transformer_options[REF_CONTROL_LIST_ALL]
|
| 508 |
+
# discard any controlnets that should not run
|
| 509 |
+
ref_controlnets = [x for x in ref_controlnets if x.should_run()]
|
| 510 |
+
# if nothing related to reference controlnets, do nothing special
|
| 511 |
+
if len(ref_controlnets) == 0:
|
| 512 |
+
return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
|
| 513 |
+
try:
|
| 514 |
+
# assign cond and uncond idxs
|
| 515 |
+
batched_number = len(transformer_options["cond_or_uncond"])
|
| 516 |
+
per_batch = x.shape[0] // batched_number
|
| 517 |
+
indiv_conds = []
|
| 518 |
+
for cond_type in transformer_options["cond_or_uncond"]:
|
| 519 |
+
indiv_conds.extend([cond_type] * per_batch)
|
| 520 |
+
transformer_options[REF_UNCOND_IDXS] = [i for i, x in enumerate(indiv_conds) if x == 1]
|
| 521 |
+
transformer_options[REF_COND_IDXS] = [i for i, x in enumerate(indiv_conds) if x == 0]
|
| 522 |
+
# check which controlnets do which thing
|
| 523 |
+
attn_controlnets = []
|
| 524 |
+
adain_controlnets = []
|
| 525 |
+
for control in ref_controlnets:
|
| 526 |
+
if ReferenceType.is_attn(control.ref_opts.reference_type):
|
| 527 |
+
attn_controlnets.append(control)
|
| 528 |
+
if ReferenceType.is_adain(control.ref_opts.reference_type):
|
| 529 |
+
adain_controlnets.append(control)
|
| 530 |
+
if len(adain_controlnets) > 0:
|
| 531 |
+
# ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it
|
| 532 |
+
orig_forward_timestep_embed = openaimodel.forward_timestep_embed
|
| 533 |
+
openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed)
|
| 534 |
+
# handle running diffusion with ref cond hints
|
| 535 |
+
for control in ref_controlnets:
|
| 536 |
+
if ReferenceType.is_attn(control.ref_opts.reference_type):
|
| 537 |
+
transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.WRITE
|
| 538 |
+
else:
|
| 539 |
+
transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.OFF
|
| 540 |
+
if ReferenceType.is_adain(control.ref_opts.reference_type):
|
| 541 |
+
transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.WRITE
|
| 542 |
+
else:
|
| 543 |
+
transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.OFF
|
| 544 |
+
transformer_options[REF_ATTN_CONTROL_LIST] = [control]
|
| 545 |
+
transformer_options[REF_ADAIN_CONTROL_LIST] = [control]
|
| 546 |
+
|
| 547 |
+
orig_kwargs = kwargs
|
| 548 |
+
if not control.ref_opts.ref_with_other_cns:
|
| 549 |
+
kwargs = kwargs.copy()
|
| 550 |
+
kwargs["control"] = None
|
| 551 |
+
reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs)
|
| 552 |
+
kwargs = orig_kwargs
|
| 553 |
+
# run diffusion for real now
|
| 554 |
+
transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.READ
|
| 555 |
+
transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.READ
|
| 556 |
+
transformer_options[REF_ATTN_CONTROL_LIST] = attn_controlnets
|
| 557 |
+
transformer_options[REF_ADAIN_CONTROL_LIST] = adain_controlnets
|
| 558 |
+
return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
|
| 559 |
+
finally:
|
| 560 |
+
# make sure banks are cleared no matter what happens - otherwise, RIP VRAM
|
| 561 |
+
reference_injections.clean_module_mem()
|
| 562 |
+
if len(adain_controlnets) > 0:
|
| 563 |
+
openaimodel.forward_timestep_embed = orig_forward_timestep_embed
|
| 564 |
+
|
| 565 |
+
return forward_inject_UNetModel
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
# dummy class just to help IDE keep track of injected variables
|
| 569 |
+
class RefBasicTransformerBlock(BasicTransformerBlock):
|
| 570 |
+
injection_holder: InjectionBasicTransformerBlockHolder = None
|
| 571 |
+
|
| 572 |
+
def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}):
|
| 573 |
+
extra_options = {}
|
| 574 |
+
block = transformer_options.get("block", None)
|
| 575 |
+
block_index = transformer_options.get("block_index", 0)
|
| 576 |
+
transformer_patches = {}
|
| 577 |
+
transformer_patches_replace = {}
|
| 578 |
+
|
| 579 |
+
for k in transformer_options:
|
| 580 |
+
if k == "patches":
|
| 581 |
+
transformer_patches = transformer_options[k]
|
| 582 |
+
elif k == "patches_replace":
|
| 583 |
+
transformer_patches_replace = transformer_options[k]
|
| 584 |
+
else:
|
| 585 |
+
extra_options[k] = transformer_options[k]
|
| 586 |
+
|
| 587 |
+
extra_options["n_heads"] = self.n_heads
|
| 588 |
+
extra_options["dim_head"] = self.d_head
|
| 589 |
+
|
| 590 |
+
if self.ff_in:
|
| 591 |
+
x_skip = x
|
| 592 |
+
x = self.ff_in(self.norm_in(x))
|
| 593 |
+
if self.is_res:
|
| 594 |
+
x += x_skip
|
| 595 |
+
|
| 596 |
+
n: Tensor = self.norm1(x)
|
| 597 |
+
if self.disable_self_attn:
|
| 598 |
+
context_attn1 = context
|
| 599 |
+
else:
|
| 600 |
+
context_attn1 = None
|
| 601 |
+
value_attn1 = None
|
| 602 |
+
|
| 603 |
+
# Reference CN stuff
|
| 604 |
+
uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
|
| 605 |
+
c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
|
| 606 |
+
# WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
|
| 607 |
+
ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_ATTN_CONTROL_LIST, None)
|
| 608 |
+
ref_machine_state: str = transformer_options.get(REF_ATTN_MACHINE_STATE, None)
|
| 609 |
+
# if in WRITE mode, save n and style_fidelity
|
| 610 |
+
if ref_controlnets and ref_machine_state == MachineState.WRITE:
|
| 611 |
+
if ref_controlnets[0].ref_opts.attn_ref_weight > self.injection_holder.attn_weight:
|
| 612 |
+
self.injection_holder.bank_styles.bank.append(n.detach().clone())
|
| 613 |
+
self.injection_holder.bank_styles.style_cfgs.append(ref_controlnets[0].ref_opts.attn_style_fidelity)
|
| 614 |
+
self.injection_holder.bank_styles.cn_idx.append(ref_controlnets[0].order)
|
| 615 |
+
|
| 616 |
+
if "attn1_patch" in transformer_patches:
|
| 617 |
+
patch = transformer_patches["attn1_patch"]
|
| 618 |
+
if context_attn1 is None:
|
| 619 |
+
context_attn1 = n
|
| 620 |
+
value_attn1 = context_attn1
|
| 621 |
+
for p in patch:
|
| 622 |
+
n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
|
| 623 |
+
|
| 624 |
+
if block is not None:
|
| 625 |
+
transformer_block = (block[0], block[1], block_index)
|
| 626 |
+
else:
|
| 627 |
+
transformer_block = None
|
| 628 |
+
attn1_replace_patch = transformer_patches_replace.get("attn1", {})
|
| 629 |
+
block_attn1 = transformer_block
|
| 630 |
+
if block_attn1 not in attn1_replace_patch:
|
| 631 |
+
block_attn1 = block
|
| 632 |
+
|
| 633 |
+
if block_attn1 in attn1_replace_patch:
|
| 634 |
+
if context_attn1 is None:
|
| 635 |
+
context_attn1 = n
|
| 636 |
+
value_attn1 = n
|
| 637 |
+
n = self.attn1.to_q(n)
|
| 638 |
+
# Reference CN READ - use attn1_replace_patch appropriately
|
| 639 |
+
if ref_machine_state == MachineState.READ and len(self.injection_holder.bank_styles.bank) > 0:
|
| 640 |
+
bank_styles = self.injection_holder.bank_styles
|
| 641 |
+
style_fidelity = bank_styles.get_avg_style_fidelity()
|
| 642 |
+
real_bank = bank_styles.bank.copy()
|
| 643 |
+
cn_idx = 0
|
| 644 |
+
for idx, order in enumerate(bank_styles.cn_idx):
|
| 645 |
+
# make sure matching ref cn is selected
|
| 646 |
+
for i in range(cn_idx, len(ref_controlnets)):
|
| 647 |
+
if ref_controlnets[i].order == order:
|
| 648 |
+
cn_idx = i
|
| 649 |
+
break
|
| 650 |
+
assert order == ref_controlnets[cn_idx].order
|
| 651 |
+
if ref_controlnets[cn_idx].any_attn_strength_to_apply():
|
| 652 |
+
effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
|
| 653 |
+
real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
|
| 654 |
+
n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1](
|
| 655 |
+
n,
|
| 656 |
+
self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)),
|
| 657 |
+
self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)),
|
| 658 |
+
extra_options))
|
| 659 |
+
n_c = n_uc.clone()
|
| 660 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
| 661 |
+
n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1](
|
| 662 |
+
n[uc_idx_mask],
|
| 663 |
+
self.attn1.to_k(context_attn1[uc_idx_mask]),
|
| 664 |
+
self.attn1.to_v(value_attn1[uc_idx_mask]),
|
| 665 |
+
extra_options))
|
| 666 |
+
n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
|
| 667 |
+
bank_styles.clean()
|
| 668 |
+
else:
|
| 669 |
+
context_attn1 = self.attn1.to_k(context_attn1)
|
| 670 |
+
value_attn1 = self.attn1.to_v(value_attn1)
|
| 671 |
+
n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
|
| 672 |
+
n = self.attn1.to_out(n)
|
| 673 |
+
else:
|
| 674 |
+
# Reference CN READ - no attn1_replace_patch
|
| 675 |
+
if ref_machine_state == MachineState.READ and len(self.injection_holder.bank_styles.bank) > 0:
|
| 676 |
+
if context_attn1 is None:
|
| 677 |
+
context_attn1 = n
|
| 678 |
+
bank_styles = self.injection_holder.bank_styles
|
| 679 |
+
style_fidelity = bank_styles.get_avg_style_fidelity()
|
| 680 |
+
real_bank = bank_styles.bank.copy()
|
| 681 |
+
cn_idx = 0
|
| 682 |
+
for idx, order in enumerate(bank_styles.cn_idx):
|
| 683 |
+
# make sure matching ref cn is selected
|
| 684 |
+
for i in range(cn_idx, len(ref_controlnets)):
|
| 685 |
+
if ref_controlnets[i].order == order:
|
| 686 |
+
cn_idx = i
|
| 687 |
+
break
|
| 688 |
+
assert order == ref_controlnets[cn_idx].order
|
| 689 |
+
if ref_controlnets[cn_idx].any_attn_strength_to_apply():
|
| 690 |
+
effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
|
| 691 |
+
real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
|
| 692 |
+
n_uc: Tensor = self.attn1(
|
| 693 |
+
n,
|
| 694 |
+
context=torch.cat([context_attn1] + real_bank, dim=1),
|
| 695 |
+
value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1)
|
| 696 |
+
n_c = n_uc.clone()
|
| 697 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
| 698 |
+
n_c[uc_idx_mask] = self.attn1(
|
| 699 |
+
n[uc_idx_mask],
|
| 700 |
+
context=context_attn1[uc_idx_mask],
|
| 701 |
+
value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1)
|
| 702 |
+
n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
|
| 703 |
+
bank_styles.clean()
|
| 704 |
+
else:
|
| 705 |
+
n = self.attn1(n, context=context_attn1, value=value_attn1)
|
| 706 |
+
|
| 707 |
+
if "attn1_output_patch" in transformer_patches:
|
| 708 |
+
patch = transformer_patches["attn1_output_patch"]
|
| 709 |
+
for p in patch:
|
| 710 |
+
n = p(n, extra_options)
|
| 711 |
+
|
| 712 |
+
x += n
|
| 713 |
+
if "middle_patch" in transformer_patches:
|
| 714 |
+
patch = transformer_patches["middle_patch"]
|
| 715 |
+
for p in patch:
|
| 716 |
+
x = p(x, extra_options)
|
| 717 |
+
|
| 718 |
+
if self.attn2 is not None:
|
| 719 |
+
n = self.norm2(x)
|
| 720 |
+
if self.switch_temporal_ca_to_sa:
|
| 721 |
+
context_attn2 = n
|
| 722 |
+
else:
|
| 723 |
+
context_attn2 = context
|
| 724 |
+
value_attn2 = None
|
| 725 |
+
if "attn2_patch" in transformer_patches:
|
| 726 |
+
patch = transformer_patches["attn2_patch"]
|
| 727 |
+
value_attn2 = context_attn2
|
| 728 |
+
for p in patch:
|
| 729 |
+
n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
|
| 730 |
+
|
| 731 |
+
attn2_replace_patch = transformer_patches_replace.get("attn2", {})
|
| 732 |
+
block_attn2 = transformer_block
|
| 733 |
+
if block_attn2 not in attn2_replace_patch:
|
| 734 |
+
block_attn2 = block
|
| 735 |
+
|
| 736 |
+
if block_attn2 in attn2_replace_patch:
|
| 737 |
+
if value_attn2 is None:
|
| 738 |
+
value_attn2 = context_attn2
|
| 739 |
+
n = self.attn2.to_q(n)
|
| 740 |
+
context_attn2 = self.attn2.to_k(context_attn2)
|
| 741 |
+
value_attn2 = self.attn2.to_v(value_attn2)
|
| 742 |
+
n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
|
| 743 |
+
n = self.attn2.to_out(n)
|
| 744 |
+
else:
|
| 745 |
+
n = self.attn2(n, context=context_attn2, value=value_attn2)
|
| 746 |
+
|
| 747 |
+
if "attn2_output_patch" in transformer_patches:
|
| 748 |
+
patch = transformer_patches["attn2_output_patch"]
|
| 749 |
+
for p in patch:
|
| 750 |
+
n = p(n, extra_options)
|
| 751 |
+
|
| 752 |
+
x += n
|
| 753 |
+
if self.is_res:
|
| 754 |
+
x_skip = x
|
| 755 |
+
x = self.ff(self.norm3(x))
|
| 756 |
+
if self.is_res:
|
| 757 |
+
x += x_skip
|
| 758 |
+
|
| 759 |
+
return x
|
| 760 |
+
|
| 761 |
+
|
| 762 |
+
class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential):
|
| 763 |
+
injection_holder: InjectionTimestepEmbedSequentialHolder = None
|
| 764 |
+
|
| 765 |
+
def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable):
|
| 766 |
+
def forward_timestep_embed_ref_inject(*args, **kwargs):
|
| 767 |
+
ts: RefTimestepEmbedSequential = args[0]
|
| 768 |
+
if not hasattr(ts, "injection_holder"):
|
| 769 |
+
return orig_timestep_embed_inject_factory(*args, **kwargs)
|
| 770 |
+
eps = 1e-6
|
| 771 |
+
x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs)
|
| 772 |
+
y: Tensor = None
|
| 773 |
+
transformer_options: dict[str] = args[4]
|
| 774 |
+
# Reference CN stuff
|
| 775 |
+
uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
|
| 776 |
+
c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
|
| 777 |
+
# WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
|
| 778 |
+
ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_ADAIN_CONTROL_LIST, None)
|
| 779 |
+
ref_machine_state: str = transformer_options.get(REF_ADAIN_MACHINE_STATE, None)
|
| 780 |
+
|
| 781 |
+
# if in WRITE mode, save var, mean, and style_cfg
|
| 782 |
+
if ref_machine_state == MachineState.WRITE:
|
| 783 |
+
if ref_controlnets[0].ref_opts.adain_ref_weight > ts.injection_holder.gn_weight:
|
| 784 |
+
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
| 785 |
+
ts.injection_holder.bank_styles.var_bank.append(var)
|
| 786 |
+
ts.injection_holder.bank_styles.mean_bank.append(mean)
|
| 787 |
+
ts.injection_holder.bank_styles.style_cfgs.append(ref_controlnets[0].ref_opts.adain_style_fidelity)
|
| 788 |
+
ts.injection_holder.bank_styles.cn_idx.append(ref_controlnets[0].order)
|
| 789 |
+
# if in READ mode, do math with saved var, mean, and style_cfg
|
| 790 |
+
if ref_machine_state == MachineState.READ:
|
| 791 |
+
if len(ts.injection_holder.bank_styles.var_bank) > 0:
|
| 792 |
+
bank_styles = ts.injection_holder.bank_styles
|
| 793 |
+
var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
|
| 794 |
+
std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
|
| 795 |
+
y_uc = torch.zeros_like(x)
|
| 796 |
+
cn_idx = 0
|
| 797 |
+
for idx, order in enumerate(bank_styles.cn_idx):
|
| 798 |
+
# make sure matching ref cn is selected
|
| 799 |
+
for i in range(cn_idx, len(ref_controlnets)):
|
| 800 |
+
if ref_controlnets[i].order == order:
|
| 801 |
+
cn_idx = i
|
| 802 |
+
break
|
| 803 |
+
assert order == ref_controlnets[cn_idx].order
|
| 804 |
+
style_fidelity = bank_styles.style_cfgs[idx]
|
| 805 |
+
var_acc = bank_styles.var_bank[idx]
|
| 806 |
+
mean_acc = bank_styles.mean_bank[idx]
|
| 807 |
+
std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
|
| 808 |
+
sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc
|
| 809 |
+
if ref_controlnets[cn_idx].any_adain_strength_to_apply():
|
| 810 |
+
effective_strength = ref_controlnets[cn_idx].get_effective_adain_mask_or_float(x=x)
|
| 811 |
+
sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength)
|
| 812 |
+
y_uc += sub_y_uc
|
| 813 |
+
# get average, if more than one
|
| 814 |
+
if len(bank_styles.cn_idx) > 1:
|
| 815 |
+
y_uc /= len(bank_styles.cn_idx)
|
| 816 |
+
y_c = y_uc.clone()
|
| 817 |
+
if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
|
| 818 |
+
y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask]
|
| 819 |
+
y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc
|
| 820 |
+
ts.injection_holder.bank_styles.clean()
|
| 821 |
+
|
| 822 |
+
if y is None:
|
| 823 |
+
y = x
|
| 824 |
+
return y.to(x.dtype)
|
| 825 |
+
|
| 826 |
+
return forward_timestep_embed_ref_inject
|
| 827 |
+
|
| 828 |
+
# DFS Search for Torch.nn.Module, Written by Lvmin
|
| 829 |
+
def torch_dfs(model: torch.nn.Module):
|
| 830 |
+
result = [model]
|
| 831 |
+
for child in model.children():
|
| 832 |
+
result += torch_dfs(child)
|
| 833 |
+
return result
|
ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py
ADDED
|
@@ -0,0 +1,949 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#taken from: https://github.com/lllyasviel/ControlNet
|
| 2 |
+
#and modified
|
| 3 |
+
#and then taken from comfy/cldm/cldm.py and modified again
|
| 4 |
+
|
| 5 |
+
from abc import ABC, abstractmethod
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Iterable, Union
|
| 9 |
+
import torch
|
| 10 |
+
import torch as th
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
from einops import rearrange, repeat
|
| 14 |
+
|
| 15 |
+
from comfy.ldm.modules.diffusionmodules.util import (
|
| 16 |
+
zero_module,
|
| 17 |
+
timestep_embedding,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
from comfy.cli_args import args
|
| 21 |
+
from comfy.cldm.cldm import ControlNet as ControlNetCLDM
|
| 22 |
+
from comfy.ldm.modules.attention import SpatialTransformer
|
| 23 |
+
from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
|
| 24 |
+
from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
|
| 25 |
+
from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample
|
| 26 |
+
from comfy.model_patcher import ModelPatcher
|
| 27 |
+
from comfy.controlnet import broadcast_image_to
|
| 28 |
+
from comfy.utils import repeat_to_batch_size
|
| 29 |
+
import comfy.ops
|
| 30 |
+
import comfy.model_management
|
| 31 |
+
|
| 32 |
+
from .utils import TimestepKeyframeGroup, disable_weight_init_clean_groupnorm, prepare_mask_batch
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
|
| 36 |
+
# logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
|
| 37 |
+
optimized_attention_mm = attention_basic
|
| 38 |
+
if comfy.model_management.xformers_enabled():
|
| 39 |
+
pass
|
| 40 |
+
#optimized_attention_mm = attention_xformers
|
| 41 |
+
if comfy.model_management.pytorch_attention_enabled():
|
| 42 |
+
optimized_attention_mm = attention_pytorch
|
| 43 |
+
else:
|
| 44 |
+
if args.use_split_cross_attention:
|
| 45 |
+
optimized_attention_mm = attention_split
|
| 46 |
+
else:
|
| 47 |
+
optimized_attention_mm = attention_sub_quad
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class SparseControlNet(ControlNetCLDM):
|
| 51 |
+
def __init__(self, *args,**kwargs):
|
| 52 |
+
super().__init__(*args, **kwargs)
|
| 53 |
+
hint_channels = kwargs.get("hint_channels")
|
| 54 |
+
operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm)
|
| 55 |
+
device = kwargs.get("device", None)
|
| 56 |
+
self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False)
|
| 57 |
+
if self.use_simplified_conditioning_embedding:
|
| 58 |
+
self.input_hint_block = TimestepEmbedSequential(
|
| 59 |
+
zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)),
|
| 60 |
+
)
|
| 61 |
+
self.motion_wrapper: SparseCtrlMotionWrapper = None
|
| 62 |
+
|
| 63 |
+
def set_actual_length(self, actual_length: int, full_length: int):
|
| 64 |
+
if self.motion_wrapper is not None:
|
| 65 |
+
self.motion_wrapper.set_video_length(video_length=actual_length, full_length=full_length)
|
| 66 |
+
|
| 67 |
+
def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
|
| 68 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
| 69 |
+
emb = self.time_embed(t_emb)
|
| 70 |
+
|
| 71 |
+
# SparseCtrl sets noisy input to zeros
|
| 72 |
+
x = torch.zeros_like(x)
|
| 73 |
+
guided_hint = self.input_hint_block(hint, emb, context)
|
| 74 |
+
|
| 75 |
+
outs = []
|
| 76 |
+
|
| 77 |
+
hs = []
|
| 78 |
+
if self.num_classes is not None:
|
| 79 |
+
assert y.shape[0] == x.shape[0]
|
| 80 |
+
emb = emb + self.label_emb(y)
|
| 81 |
+
|
| 82 |
+
h = x
|
| 83 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
| 84 |
+
if guided_hint is not None:
|
| 85 |
+
h = module(h, emb, context)
|
| 86 |
+
h += guided_hint
|
| 87 |
+
guided_hint = None
|
| 88 |
+
else:
|
| 89 |
+
h = module(h, emb, context)
|
| 90 |
+
outs.append(zero_conv(h, emb, context))
|
| 91 |
+
|
| 92 |
+
h = self.middle_block(h, emb, context)
|
| 93 |
+
outs.append(self.middle_block_out(h, emb, context))
|
| 94 |
+
|
| 95 |
+
return outs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SparseModelPatcher(ModelPatcher):
|
| 99 |
+
def __init__(self, *args, **kwargs):
|
| 100 |
+
self.model: SparseControlNet
|
| 101 |
+
super().__init__(*args, **kwargs)
|
| 102 |
+
|
| 103 |
+
def patch_model(self, device_to=None, patch_weights=True):
|
| 104 |
+
if patch_weights:
|
| 105 |
+
patched_model = super().patch_model(device_to)
|
| 106 |
+
else:
|
| 107 |
+
patched_model = super().patch_model(device_to, patch_weights)
|
| 108 |
+
try:
|
| 109 |
+
if self.model.motion_wrapper is not None:
|
| 110 |
+
self.model.motion_wrapper.to(device=device_to)
|
| 111 |
+
except Exception:
|
| 112 |
+
pass
|
| 113 |
+
return patched_model
|
| 114 |
+
|
| 115 |
+
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
| 116 |
+
try:
|
| 117 |
+
if self.model.motion_wrapper is not None:
|
| 118 |
+
self.model.motion_wrapper.to(device=device_to)
|
| 119 |
+
except Exception:
|
| 120 |
+
pass
|
| 121 |
+
if unpatch_weights:
|
| 122 |
+
return super().unpatch_model(device_to)
|
| 123 |
+
else:
|
| 124 |
+
return super().unpatch_model(device_to, unpatch_weights)
|
| 125 |
+
|
| 126 |
+
def clone(self):
|
| 127 |
+
# normal ModelPatcher clone actions
|
| 128 |
+
n = SparseModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
|
| 129 |
+
n.patches = {}
|
| 130 |
+
for k in self.patches:
|
| 131 |
+
n.patches[k] = self.patches[k][:]
|
| 132 |
+
if hasattr(n, "patches_uuid"):
|
| 133 |
+
self.patches_uuid = n.patches_uuid
|
| 134 |
+
|
| 135 |
+
n.object_patches = self.object_patches.copy()
|
| 136 |
+
n.model_options = copy.deepcopy(self.model_options)
|
| 137 |
+
n.model_keys = self.model_keys
|
| 138 |
+
if hasattr(n, "backup"):
|
| 139 |
+
self.backup = n.backup
|
| 140 |
+
if hasattr(n, "object_patches_backup"):
|
| 141 |
+
self.object_patches_backup = n.object_patches_backup
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class PreprocSparseRGBWrapper:
|
| 145 |
+
error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
|
| 146 |
+
def __init__(self, condhint: Tensor):
|
| 147 |
+
self.condhint = condhint
|
| 148 |
+
|
| 149 |
+
def movedim(self, *args, **kwargs):
|
| 150 |
+
return self
|
| 151 |
+
|
| 152 |
+
def __getattr__(self, *args, **kwargs):
|
| 153 |
+
raise AttributeError(self.error_msg)
|
| 154 |
+
|
| 155 |
+
def __setattr__(self, name, value):
|
| 156 |
+
if name != "condhint":
|
| 157 |
+
raise AttributeError(self.error_msg)
|
| 158 |
+
super().__setattr__(name, value)
|
| 159 |
+
|
| 160 |
+
def __iter__(self, *args, **kwargs):
|
| 161 |
+
raise AttributeError(self.error_msg)
|
| 162 |
+
|
| 163 |
+
def __next__(self, *args, **kwargs):
|
| 164 |
+
raise AttributeError(self.error_msg)
|
| 165 |
+
|
| 166 |
+
def __len__(self, *args, **kwargs):
|
| 167 |
+
raise AttributeError(self.error_msg)
|
| 168 |
+
|
| 169 |
+
def __getitem__(self, *args, **kwargs):
|
| 170 |
+
raise AttributeError(self.error_msg)
|
| 171 |
+
|
| 172 |
+
def __setitem__(self, *args, **kwargs):
|
| 173 |
+
raise AttributeError(self.error_msg)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class SparseSettings:
|
| 177 |
+
def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False):
|
| 178 |
+
self.sparse_method = sparse_method
|
| 179 |
+
self.use_motion = use_motion
|
| 180 |
+
self.motion_strength = motion_strength
|
| 181 |
+
self.motion_scale = motion_scale
|
| 182 |
+
self.merged = merged
|
| 183 |
+
|
| 184 |
+
@classmethod
|
| 185 |
+
def default(cls):
|
| 186 |
+
return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class SparseMethod(ABC):
|
| 190 |
+
SPREAD = "spread"
|
| 191 |
+
INDEX = "index"
|
| 192 |
+
def __init__(self, method: str):
|
| 193 |
+
self.method = method
|
| 194 |
+
|
| 195 |
+
@abstractmethod
|
| 196 |
+
def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
| 197 |
+
pass
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class SparseSpreadMethod(SparseMethod):
|
| 201 |
+
UNIFORM = "uniform"
|
| 202 |
+
STARTING = "starting"
|
| 203 |
+
ENDING = "ending"
|
| 204 |
+
CENTER = "center"
|
| 205 |
+
|
| 206 |
+
LIST = [UNIFORM, STARTING, ENDING, CENTER]
|
| 207 |
+
|
| 208 |
+
def __init__(self, spread=UNIFORM):
|
| 209 |
+
super().__init__(self.SPREAD)
|
| 210 |
+
self.spread = spread
|
| 211 |
+
|
| 212 |
+
def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
| 213 |
+
# if hint_length >= full_length, limit hints to full_length
|
| 214 |
+
if hint_length >= full_length:
|
| 215 |
+
return list(range(full_length))
|
| 216 |
+
# handle special case of 1 hint image
|
| 217 |
+
if hint_length == 1:
|
| 218 |
+
if self.spread in [self.UNIFORM, self.STARTING]:
|
| 219 |
+
return [0]
|
| 220 |
+
elif self.spread == self.ENDING:
|
| 221 |
+
return [full_length-1]
|
| 222 |
+
elif self.spread == self.CENTER:
|
| 223 |
+
# return second (of three) values as the center
|
| 224 |
+
return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
|
| 225 |
+
else:
|
| 226 |
+
raise ValueError(f"Unrecognized spread: {self.spread}")
|
| 227 |
+
# otherwise, handle other cases
|
| 228 |
+
if self.spread == self.UNIFORM:
|
| 229 |
+
return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
|
| 230 |
+
elif self.spread == self.STARTING:
|
| 231 |
+
# make split 1 larger, remove last element
|
| 232 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
|
| 233 |
+
elif self.spread == self.ENDING:
|
| 234 |
+
# make split 1 larger, remove first element
|
| 235 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
|
| 236 |
+
elif self.spread == self.CENTER:
|
| 237 |
+
# if hint length is not 3 greater than full length, do STARTING behavior
|
| 238 |
+
if full_length-hint_length < 3:
|
| 239 |
+
return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
|
| 240 |
+
# otherwise, get linspace of 2 greater than needed, then cut off first and last
|
| 241 |
+
return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
|
| 242 |
+
return ValueError(f"Unrecognized spread: {self.spread}")
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
class SparseIndexMethod(SparseMethod):
|
| 246 |
+
def __init__(self, idxs: list[int]):
|
| 247 |
+
super().__init__(self.INDEX)
|
| 248 |
+
self.idxs = idxs
|
| 249 |
+
|
| 250 |
+
def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
|
| 251 |
+
orig_hint_length = hint_length
|
| 252 |
+
if hint_length > full_length:
|
| 253 |
+
hint_length = full_length
|
| 254 |
+
# if idxs is less than hint_length, throw error
|
| 255 |
+
if len(self.idxs) < hint_length:
|
| 256 |
+
err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
|
| 257 |
+
if orig_hint_length != hint_length:
|
| 258 |
+
err_msg = f"{err_msg} (original input images: {orig_hint_length})"
|
| 259 |
+
raise ValueError(err_msg)
|
| 260 |
+
# cap idxs to hint_length
|
| 261 |
+
idxs = self.idxs[:hint_length]
|
| 262 |
+
new_idxs = []
|
| 263 |
+
real_idxs = set()
|
| 264 |
+
for idx in idxs:
|
| 265 |
+
if idx < 0:
|
| 266 |
+
real_idx = full_length+idx
|
| 267 |
+
if real_idx in real_idxs:
|
| 268 |
+
raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.")
|
| 269 |
+
else:
|
| 270 |
+
real_idx = idx
|
| 271 |
+
if real_idx in real_idxs:
|
| 272 |
+
raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
|
| 273 |
+
real_idxs.add(real_idx)
|
| 274 |
+
new_idxs.append(real_idx)
|
| 275 |
+
return new_idxs
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
#########################################
|
| 279 |
+
# motion-related portion of controlnet
|
| 280 |
+
class BlockType:
|
| 281 |
+
UP = "up"
|
| 282 |
+
DOWN = "down"
|
| 283 |
+
MID = "mid"
|
| 284 |
+
|
| 285 |
+
def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int:
|
| 286 |
+
return get_block_max(mm_state_dict, "down_blocks")
|
| 287 |
+
|
| 288 |
+
def get_up_block_max(mm_state_dict: dict[str, Tensor]) -> int:
|
| 289 |
+
return get_block_max(mm_state_dict, "up_blocks")
|
| 290 |
+
|
| 291 |
+
def get_block_max(mm_state_dict: dict[str, Tensor], block_name: str) -> int:
|
| 292 |
+
# keep track of biggest down_block count in module
|
| 293 |
+
biggest_block = -1
|
| 294 |
+
for key in mm_state_dict.keys():
|
| 295 |
+
if block_name in key:
|
| 296 |
+
try:
|
| 297 |
+
block_int = key.split(".")[1]
|
| 298 |
+
block_num = int(block_int)
|
| 299 |
+
if block_num > biggest_block:
|
| 300 |
+
biggest_block = block_num
|
| 301 |
+
except ValueError:
|
| 302 |
+
pass
|
| 303 |
+
return biggest_block
|
| 304 |
+
|
| 305 |
+
def has_mid_block(mm_state_dict: dict[str, Tensor]):
|
| 306 |
+
# check if keys contain mid_block
|
| 307 |
+
for key in mm_state_dict.keys():
|
| 308 |
+
if key.startswith("mid_block."):
|
| 309 |
+
return True
|
| 310 |
+
return False
|
| 311 |
+
|
| 312 |
+
def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str=None) -> int:
|
| 313 |
+
# use pos_encoder.pe entries to determine max length - [1, {max_length}, {320|640|1280}]
|
| 314 |
+
for key in mm_state_dict.keys():
|
| 315 |
+
if key.endswith("pos_encoder.pe"):
|
| 316 |
+
return mm_state_dict[key].size(1) # get middle dim
|
| 317 |
+
raise ValueError(f"No pos_encoder.pe found in SparseCtrl state_dict - {mm_name} is not a valid SparseCtrl model!")
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
class SparseCtrlMotionWrapper(nn.Module):
|
| 321 |
+
def __init__(self, mm_state_dict: dict[str, Tensor]):
|
| 322 |
+
super().__init__()
|
| 323 |
+
self.down_blocks: Iterable[MotionModule] = None
|
| 324 |
+
self.up_blocks: Iterable[MotionModule] = None
|
| 325 |
+
self.mid_block: MotionModule = None
|
| 326 |
+
self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, "")
|
| 327 |
+
layer_channels = (320, 640, 1280, 1280)
|
| 328 |
+
if get_down_block_max(mm_state_dict) > -1:
|
| 329 |
+
self.down_blocks = nn.ModuleList([])
|
| 330 |
+
for c in layer_channels:
|
| 331 |
+
self.down_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.DOWN))
|
| 332 |
+
if get_up_block_max(mm_state_dict) > -1:
|
| 333 |
+
self.up_blocks = nn.ModuleList([])
|
| 334 |
+
for c in reversed(layer_channels):
|
| 335 |
+
self.up_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.UP))
|
| 336 |
+
if has_mid_block(mm_state_dict):
|
| 337 |
+
self.mid_block = MotionModule(1280, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.MID)
|
| 338 |
+
|
| 339 |
+
def inject(self, unet: SparseControlNet):
|
| 340 |
+
# inject input (down) blocks
|
| 341 |
+
self._inject(unet.input_blocks, self.down_blocks)
|
| 342 |
+
# inject mid block, if present
|
| 343 |
+
if self.mid_block is not None:
|
| 344 |
+
self._inject([unet.middle_block], [self.mid_block])
|
| 345 |
+
unet.motion_wrapper = self
|
| 346 |
+
|
| 347 |
+
def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList):
|
| 348 |
+
# Rules for injection:
|
| 349 |
+
# For each component list in a unet block:
|
| 350 |
+
# if SpatialTransformer exists in list, place next block after last occurrence
|
| 351 |
+
# elif ResBlock exists in list, place next block after first occurrence
|
| 352 |
+
# else don't place block
|
| 353 |
+
injection_count = 0
|
| 354 |
+
unet_idx = 0
|
| 355 |
+
# details about blocks passed in
|
| 356 |
+
per_block = len(mm_blocks[0].motion_modules)
|
| 357 |
+
injection_goal = len(mm_blocks) * per_block
|
| 358 |
+
# only stop injecting when modules exhausted
|
| 359 |
+
while injection_count < injection_goal:
|
| 360 |
+
# figure out which VanillaTemporalModule from mm to inject
|
| 361 |
+
mm_blk_idx, mm_vtm_idx = injection_count // per_block, injection_count % per_block
|
| 362 |
+
# figure out layout of unet block components
|
| 363 |
+
st_idx = -1 # SpatialTransformer index
|
| 364 |
+
res_idx = -1 # first ResBlock index
|
| 365 |
+
# first, figure out indeces of relevant blocks
|
| 366 |
+
for idx, component in enumerate(unet_blocks[unet_idx]):
|
| 367 |
+
if type(component) == SpatialTransformer:
|
| 368 |
+
st_idx = idx
|
| 369 |
+
elif type(component).__name__ == "ResBlock" and res_idx < 0:
|
| 370 |
+
res_idx = idx
|
| 371 |
+
# if SpatialTransformer exists, inject right after
|
| 372 |
+
if st_idx >= 0:
|
| 373 |
+
unet_blocks[unet_idx].insert(st_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
|
| 374 |
+
injection_count += 1
|
| 375 |
+
# otherwise, if only ResBlock exists, inject right after
|
| 376 |
+
elif res_idx >= 0:
|
| 377 |
+
unet_blocks[unet_idx].insert(res_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
|
| 378 |
+
injection_count += 1
|
| 379 |
+
# increment unet_idx
|
| 380 |
+
unet_idx += 1
|
| 381 |
+
|
| 382 |
+
def eject(self, unet: SparseControlNet):
|
| 383 |
+
# remove from input blocks (downblocks)
|
| 384 |
+
self._eject(unet.input_blocks)
|
| 385 |
+
# remove from middle block (encapsulate in list to make compatible)
|
| 386 |
+
self._eject([unet.middle_block])
|
| 387 |
+
del unet.motion_wrapper
|
| 388 |
+
unet.motion_wrapper = None
|
| 389 |
+
|
| 390 |
+
def _eject(self, unet_blocks: nn.ModuleList):
|
| 391 |
+
# eject all VanillaTemporalModule objects from all blocks
|
| 392 |
+
for block in unet_blocks:
|
| 393 |
+
idx_to_pop = []
|
| 394 |
+
for idx, component in enumerate(block):
|
| 395 |
+
if type(component) == VanillaTemporalModule:
|
| 396 |
+
idx_to_pop.append(idx)
|
| 397 |
+
# pop in backwards order, as to not disturb what the indeces refer to
|
| 398 |
+
for idx in sorted(idx_to_pop, reverse=True):
|
| 399 |
+
block.pop(idx)
|
| 400 |
+
|
| 401 |
+
def set_video_length(self, video_length: int, full_length: int):
|
| 402 |
+
self.AD_video_length = video_length
|
| 403 |
+
if self.down_blocks is not None:
|
| 404 |
+
for block in self.down_blocks:
|
| 405 |
+
block.set_video_length(video_length, full_length)
|
| 406 |
+
if self.up_blocks is not None:
|
| 407 |
+
for block in self.up_blocks:
|
| 408 |
+
block.set_video_length(video_length, full_length)
|
| 409 |
+
if self.mid_block is not None:
|
| 410 |
+
self.mid_block.set_video_length(video_length, full_length)
|
| 411 |
+
|
| 412 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 413 |
+
if self.down_blocks is not None:
|
| 414 |
+
for block in self.down_blocks:
|
| 415 |
+
block.set_scale_multiplier(multiplier)
|
| 416 |
+
if self.up_blocks is not None:
|
| 417 |
+
for block in self.up_blocks:
|
| 418 |
+
block.set_scale_multiplier(multiplier)
|
| 419 |
+
if self.mid_block is not None:
|
| 420 |
+
self.mid_block.set_scale_multiplier(multiplier)
|
| 421 |
+
|
| 422 |
+
def set_strength(self, strength: float):
|
| 423 |
+
if self.down_blocks is not None:
|
| 424 |
+
for block in self.down_blocks:
|
| 425 |
+
block.set_strength(strength)
|
| 426 |
+
if self.up_blocks is not None:
|
| 427 |
+
for block in self.up_blocks:
|
| 428 |
+
block.set_strength(strength)
|
| 429 |
+
if self.mid_block is not None:
|
| 430 |
+
self.mid_block.set_strength(strength)
|
| 431 |
+
|
| 432 |
+
def reset_temp_vars(self):
|
| 433 |
+
if self.down_blocks is not None:
|
| 434 |
+
for block in self.down_blocks:
|
| 435 |
+
block.reset_temp_vars()
|
| 436 |
+
if self.up_blocks is not None:
|
| 437 |
+
for block in self.up_blocks:
|
| 438 |
+
block.reset_temp_vars()
|
| 439 |
+
if self.mid_block is not None:
|
| 440 |
+
self.mid_block.reset_temp_vars()
|
| 441 |
+
|
| 442 |
+
def reset_scale_multiplier(self):
|
| 443 |
+
self.set_scale_multiplier(None)
|
| 444 |
+
|
| 445 |
+
def reset(self):
|
| 446 |
+
self.reset_scale_multiplier()
|
| 447 |
+
self.reset_temp_vars()
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class MotionModule(nn.Module):
|
| 451 |
+
def __init__(self, in_channels, temporal_position_encoding_max_len=24, block_type: str=BlockType.DOWN):
|
| 452 |
+
super().__init__()
|
| 453 |
+
if block_type == BlockType.MID:
|
| 454 |
+
# mid blocks contain only a single VanillaTemporalModule
|
| 455 |
+
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, temporal_position_encoding_max_len)])
|
| 456 |
+
else:
|
| 457 |
+
# down blocks contain two VanillaTemporalModules
|
| 458 |
+
self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList(
|
| 459 |
+
[
|
| 460 |
+
get_motion_module(in_channels, temporal_position_encoding_max_len),
|
| 461 |
+
get_motion_module(in_channels, temporal_position_encoding_max_len)
|
| 462 |
+
]
|
| 463 |
+
)
|
| 464 |
+
# up blocks contain one additional VanillaTemporalModule
|
| 465 |
+
if block_type == BlockType.UP:
|
| 466 |
+
self.motion_modules.append(get_motion_module(in_channels, temporal_position_encoding_max_len))
|
| 467 |
+
|
| 468 |
+
def set_video_length(self, video_length: int, full_length: int):
|
| 469 |
+
for motion_module in self.motion_modules:
|
| 470 |
+
motion_module.set_video_length(video_length, full_length)
|
| 471 |
+
|
| 472 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 473 |
+
for motion_module in self.motion_modules:
|
| 474 |
+
motion_module.set_scale_multiplier(multiplier)
|
| 475 |
+
|
| 476 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
| 477 |
+
for motion_module in self.motion_modules:
|
| 478 |
+
motion_module.set_masks(masks, min_val, max_val)
|
| 479 |
+
|
| 480 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 481 |
+
for motion_module in self.motion_modules:
|
| 482 |
+
motion_module.set_sub_idxs(sub_idxs)
|
| 483 |
+
|
| 484 |
+
def set_strength(self, strength: float):
|
| 485 |
+
for motion_module in self.motion_modules:
|
| 486 |
+
motion_module.set_strength(strength)
|
| 487 |
+
|
| 488 |
+
def reset_temp_vars(self):
|
| 489 |
+
for motion_module in self.motion_modules:
|
| 490 |
+
motion_module.reset_temp_vars()
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def get_motion_module(in_channels, temporal_position_encoding_max_len):
|
| 494 |
+
# unlike normal AD, there is only one attention block expected in SparseCtrl models
|
| 495 |
+
return VanillaTemporalModule(in_channels=in_channels, attention_block_types=("Temporal_Self",), temporal_position_encoding_max_len=temporal_position_encoding_max_len)
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
class VanillaTemporalModule(nn.Module):
|
| 499 |
+
def __init__(
|
| 500 |
+
self,
|
| 501 |
+
in_channels,
|
| 502 |
+
num_attention_heads=8,
|
| 503 |
+
num_transformer_block=1,
|
| 504 |
+
attention_block_types=("Temporal_Self", "Temporal_Self"),
|
| 505 |
+
cross_frame_attention_mode=None,
|
| 506 |
+
temporal_position_encoding=True,
|
| 507 |
+
temporal_position_encoding_max_len=24,
|
| 508 |
+
temporal_attention_dim_div=1,
|
| 509 |
+
zero_initialize=True,
|
| 510 |
+
):
|
| 511 |
+
super().__init__()
|
| 512 |
+
self.strength = 1.0
|
| 513 |
+
self.temporal_transformer = TemporalTransformer3DModel(
|
| 514 |
+
in_channels=in_channels,
|
| 515 |
+
num_attention_heads=num_attention_heads,
|
| 516 |
+
attention_head_dim=in_channels
|
| 517 |
+
// num_attention_heads
|
| 518 |
+
// temporal_attention_dim_div,
|
| 519 |
+
num_layers=num_transformer_block,
|
| 520 |
+
attention_block_types=attention_block_types,
|
| 521 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
| 522 |
+
temporal_position_encoding=temporal_position_encoding,
|
| 523 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
| 524 |
+
)
|
| 525 |
+
|
| 526 |
+
if zero_initialize:
|
| 527 |
+
self.temporal_transformer.proj_out = zero_module(
|
| 528 |
+
self.temporal_transformer.proj_out
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
def set_video_length(self, video_length: int, full_length: int):
|
| 532 |
+
self.temporal_transformer.set_video_length(video_length, full_length)
|
| 533 |
+
|
| 534 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 535 |
+
self.temporal_transformer.set_scale_multiplier(multiplier)
|
| 536 |
+
|
| 537 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
| 538 |
+
self.temporal_transformer.set_masks(masks, min_val, max_val)
|
| 539 |
+
|
| 540 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 541 |
+
self.temporal_transformer.set_sub_idxs(sub_idxs)
|
| 542 |
+
|
| 543 |
+
def set_strength(self, strength: float):
|
| 544 |
+
self.strength = strength
|
| 545 |
+
|
| 546 |
+
def reset_temp_vars(self):
|
| 547 |
+
self.set_strength(1.0)
|
| 548 |
+
self.temporal_transformer.reset_temp_vars()
|
| 549 |
+
|
| 550 |
+
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None):
|
| 551 |
+
if math.isclose(self.strength, 1.0):
|
| 552 |
+
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
|
| 553 |
+
elif math.isclose(self.strength, 0.0):
|
| 554 |
+
return input_tensor
|
| 555 |
+
elif self.strength > 1.0:
|
| 556 |
+
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength
|
| 557 |
+
else:
|
| 558 |
+
return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength + input_tensor*(1.0-self.strength)
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class TemporalTransformer3DModel(nn.Module):
|
| 562 |
+
def __init__(
|
| 563 |
+
self,
|
| 564 |
+
in_channels,
|
| 565 |
+
num_attention_heads,
|
| 566 |
+
attention_head_dim,
|
| 567 |
+
num_layers,
|
| 568 |
+
attention_block_types=(
|
| 569 |
+
"Temporal_Self",
|
| 570 |
+
"Temporal_Self",
|
| 571 |
+
),
|
| 572 |
+
dropout=0.0,
|
| 573 |
+
norm_num_groups=32,
|
| 574 |
+
cross_attention_dim=768,
|
| 575 |
+
activation_fn="geglu",
|
| 576 |
+
attention_bias=False,
|
| 577 |
+
upcast_attention=False,
|
| 578 |
+
cross_frame_attention_mode=None,
|
| 579 |
+
temporal_position_encoding=False,
|
| 580 |
+
temporal_position_encoding_max_len=24,
|
| 581 |
+
):
|
| 582 |
+
super().__init__()
|
| 583 |
+
self.video_length = 16
|
| 584 |
+
self.full_length = 16
|
| 585 |
+
self.scale_min = 1.0
|
| 586 |
+
self.scale_max = 1.0
|
| 587 |
+
self.raw_scale_mask: Union[Tensor, None] = None
|
| 588 |
+
self.temp_scale_mask: Union[Tensor, None] = None
|
| 589 |
+
self.sub_idxs: Union[list[int], None] = None
|
| 590 |
+
self.prev_hidden_states_batch = 0
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 594 |
+
|
| 595 |
+
self.norm = disable_weight_init_clean_groupnorm.GroupNorm(
|
| 596 |
+
num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
|
| 597 |
+
)
|
| 598 |
+
self.proj_in = nn.Linear(in_channels, inner_dim)
|
| 599 |
+
|
| 600 |
+
self.transformer_blocks: Iterable[TemporalTransformerBlock] = nn.ModuleList(
|
| 601 |
+
[
|
| 602 |
+
TemporalTransformerBlock(
|
| 603 |
+
dim=inner_dim,
|
| 604 |
+
num_attention_heads=num_attention_heads,
|
| 605 |
+
attention_head_dim=attention_head_dim,
|
| 606 |
+
attention_block_types=attention_block_types,
|
| 607 |
+
dropout=dropout,
|
| 608 |
+
norm_num_groups=norm_num_groups,
|
| 609 |
+
cross_attention_dim=cross_attention_dim,
|
| 610 |
+
activation_fn=activation_fn,
|
| 611 |
+
attention_bias=attention_bias,
|
| 612 |
+
upcast_attention=upcast_attention,
|
| 613 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
| 614 |
+
temporal_position_encoding=temporal_position_encoding,
|
| 615 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
| 616 |
+
)
|
| 617 |
+
for d in range(num_layers)
|
| 618 |
+
]
|
| 619 |
+
)
|
| 620 |
+
self.proj_out = nn.Linear(inner_dim, in_channels)
|
| 621 |
+
|
| 622 |
+
def set_video_length(self, video_length: int, full_length: int):
|
| 623 |
+
self.video_length = video_length
|
| 624 |
+
self.full_length = full_length
|
| 625 |
+
|
| 626 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 627 |
+
for block in self.transformer_blocks:
|
| 628 |
+
block.set_scale_multiplier(multiplier)
|
| 629 |
+
|
| 630 |
+
def set_masks(self, masks: Tensor, min_val: float, max_val: float):
|
| 631 |
+
self.scale_min = min_val
|
| 632 |
+
self.scale_max = max_val
|
| 633 |
+
self.raw_scale_mask = masks
|
| 634 |
+
|
| 635 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 636 |
+
self.sub_idxs = sub_idxs
|
| 637 |
+
for block in self.transformer_blocks:
|
| 638 |
+
block.set_sub_idxs(sub_idxs)
|
| 639 |
+
|
| 640 |
+
def reset_temp_vars(self):
|
| 641 |
+
del self.temp_scale_mask
|
| 642 |
+
self.temp_scale_mask = None
|
| 643 |
+
self.prev_hidden_states_batch = 0
|
| 644 |
+
|
| 645 |
+
def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]:
|
| 646 |
+
# if no raw mask, return None
|
| 647 |
+
if self.raw_scale_mask is None:
|
| 648 |
+
return None
|
| 649 |
+
shape = hidden_states.shape
|
| 650 |
+
batch, channel, height, width = shape
|
| 651 |
+
# if temp mask already calculated, return it
|
| 652 |
+
if self.temp_scale_mask != None:
|
| 653 |
+
# check if hidden_states batch matches
|
| 654 |
+
if batch == self.prev_hidden_states_batch:
|
| 655 |
+
if self.sub_idxs is not None:
|
| 656 |
+
return self.temp_scale_mask[:, self.sub_idxs, :]
|
| 657 |
+
return self.temp_scale_mask
|
| 658 |
+
# if does not match, reset cached temp_scale_mask and recalculate it
|
| 659 |
+
del self.temp_scale_mask
|
| 660 |
+
self.temp_scale_mask = None
|
| 661 |
+
# otherwise, calculate temp mask
|
| 662 |
+
self.prev_hidden_states_batch = batch
|
| 663 |
+
mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
|
| 664 |
+
mask = repeat_to_batch_size(mask, self.full_length)
|
| 665 |
+
# if mask not the same amount length as full length, make it match
|
| 666 |
+
if self.full_length != mask.shape[0]:
|
| 667 |
+
mask = broadcast_image_to(mask, self.full_length, 1)
|
| 668 |
+
# reshape mask to attention K shape (h*w, latent_count, 1)
|
| 669 |
+
batch, channel, height, width = mask.shape
|
| 670 |
+
# first, perform same operations as on hidden_states,
|
| 671 |
+
# turning (b, c, h, w) -> (b, h*w, c)
|
| 672 |
+
mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
|
| 673 |
+
# then, make it the same shape as attention's k, (h*w, b, c)
|
| 674 |
+
mask = mask.permute(1, 0, 2)
|
| 675 |
+
# make masks match the expected length of h*w
|
| 676 |
+
batched_number = shape[0] // self.video_length
|
| 677 |
+
if batched_number > 1:
|
| 678 |
+
mask = torch.cat([mask] * batched_number, dim=0)
|
| 679 |
+
# cache mask and set to proper device
|
| 680 |
+
self.temp_scale_mask = mask
|
| 681 |
+
# move temp_scale_mask to proper dtype + device
|
| 682 |
+
self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device)
|
| 683 |
+
# return subset of masks, if needed
|
| 684 |
+
if self.sub_idxs is not None:
|
| 685 |
+
return self.temp_scale_mask[:, self.sub_idxs, :]
|
| 686 |
+
return self.temp_scale_mask
|
| 687 |
+
|
| 688 |
+
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
| 689 |
+
batch, channel, height, width = hidden_states.shape
|
| 690 |
+
residual = hidden_states
|
| 691 |
+
scale_mask = self.get_scale_mask(hidden_states)
|
| 692 |
+
# add some casts for fp8 purposes - does not affect speed otherwise
|
| 693 |
+
hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
|
| 694 |
+
inner_dim = hidden_states.shape[1]
|
| 695 |
+
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
|
| 696 |
+
batch, height * width, inner_dim
|
| 697 |
+
)
|
| 698 |
+
hidden_states = self.proj_in(hidden_states).to(hidden_states.dtype)
|
| 699 |
+
|
| 700 |
+
# Transformer Blocks
|
| 701 |
+
for block in self.transformer_blocks:
|
| 702 |
+
hidden_states = block(
|
| 703 |
+
hidden_states,
|
| 704 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 705 |
+
attention_mask=attention_mask,
|
| 706 |
+
video_length=self.video_length,
|
| 707 |
+
scale_mask=scale_mask
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# output
|
| 711 |
+
hidden_states = self.proj_out(hidden_states)
|
| 712 |
+
hidden_states = (
|
| 713 |
+
hidden_states.reshape(batch, height, width, inner_dim)
|
| 714 |
+
.permute(0, 3, 1, 2)
|
| 715 |
+
.contiguous()
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
output = hidden_states + residual
|
| 719 |
+
|
| 720 |
+
return output
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
class TemporalTransformerBlock(nn.Module):
|
| 724 |
+
def __init__(
|
| 725 |
+
self,
|
| 726 |
+
dim,
|
| 727 |
+
num_attention_heads,
|
| 728 |
+
attention_head_dim,
|
| 729 |
+
attention_block_types=(
|
| 730 |
+
"Temporal_Self",
|
| 731 |
+
"Temporal_Self",
|
| 732 |
+
),
|
| 733 |
+
dropout=0.0,
|
| 734 |
+
norm_num_groups=32,
|
| 735 |
+
cross_attention_dim=768,
|
| 736 |
+
activation_fn="geglu",
|
| 737 |
+
attention_bias=False,
|
| 738 |
+
upcast_attention=False,
|
| 739 |
+
cross_frame_attention_mode=None,
|
| 740 |
+
temporal_position_encoding=False,
|
| 741 |
+
temporal_position_encoding_max_len=24,
|
| 742 |
+
):
|
| 743 |
+
super().__init__()
|
| 744 |
+
|
| 745 |
+
attention_blocks = []
|
| 746 |
+
norms = []
|
| 747 |
+
|
| 748 |
+
for block_name in attention_block_types:
|
| 749 |
+
attention_blocks.append(
|
| 750 |
+
VersatileAttention(
|
| 751 |
+
attention_mode=block_name.split("_")[0],
|
| 752 |
+
context_dim=cross_attention_dim # called context_dim for ComfyUI impl
|
| 753 |
+
if block_name.endswith("_Cross")
|
| 754 |
+
else None,
|
| 755 |
+
query_dim=dim,
|
| 756 |
+
heads=num_attention_heads,
|
| 757 |
+
dim_head=attention_head_dim,
|
| 758 |
+
dropout=dropout,
|
| 759 |
+
#bias=attention_bias, # remove for Comfy CrossAttention
|
| 760 |
+
#upcast_attention=upcast_attention, # remove for Comfy CrossAttention
|
| 761 |
+
cross_frame_attention_mode=cross_frame_attention_mode,
|
| 762 |
+
temporal_position_encoding=temporal_position_encoding,
|
| 763 |
+
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
|
| 764 |
+
)
|
| 765 |
+
)
|
| 766 |
+
norms.append(nn.LayerNorm(dim))
|
| 767 |
+
|
| 768 |
+
self.attention_blocks: Iterable[VersatileAttention] = nn.ModuleList(attention_blocks)
|
| 769 |
+
self.norms = nn.ModuleList(norms)
|
| 770 |
+
|
| 771 |
+
self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"))
|
| 772 |
+
self.ff_norm = nn.LayerNorm(dim)
|
| 773 |
+
|
| 774 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 775 |
+
for block in self.attention_blocks:
|
| 776 |
+
block.set_scale_multiplier(multiplier)
|
| 777 |
+
|
| 778 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 779 |
+
for block in self.attention_blocks:
|
| 780 |
+
block.set_sub_idxs(sub_idxs)
|
| 781 |
+
|
| 782 |
+
def forward(
|
| 783 |
+
self,
|
| 784 |
+
hidden_states,
|
| 785 |
+
encoder_hidden_states=None,
|
| 786 |
+
attention_mask=None,
|
| 787 |
+
video_length=None,
|
| 788 |
+
scale_mask=None
|
| 789 |
+
):
|
| 790 |
+
for attention_block, norm in zip(self.attention_blocks, self.norms):
|
| 791 |
+
norm_hidden_states = norm(hidden_states).to(hidden_states.dtype)
|
| 792 |
+
hidden_states = (
|
| 793 |
+
attention_block(
|
| 794 |
+
norm_hidden_states,
|
| 795 |
+
encoder_hidden_states=encoder_hidden_states
|
| 796 |
+
if attention_block.is_cross_attention
|
| 797 |
+
else None,
|
| 798 |
+
attention_mask=attention_mask,
|
| 799 |
+
video_length=video_length,
|
| 800 |
+
scale_mask=scale_mask
|
| 801 |
+
)
|
| 802 |
+
+ hidden_states
|
| 803 |
+
)
|
| 804 |
+
|
| 805 |
+
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
|
| 806 |
+
|
| 807 |
+
output = hidden_states
|
| 808 |
+
return output
|
| 809 |
+
|
| 810 |
+
|
| 811 |
+
class PositionalEncoding(nn.Module):
|
| 812 |
+
def __init__(self, d_model, dropout=0.0, max_len=24):
|
| 813 |
+
super().__init__()
|
| 814 |
+
self.dropout = nn.Dropout(p=dropout)
|
| 815 |
+
position = torch.arange(max_len).unsqueeze(1)
|
| 816 |
+
div_term = torch.exp(
|
| 817 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
| 818 |
+
)
|
| 819 |
+
pe = torch.zeros(1, max_len, d_model)
|
| 820 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
| 821 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
| 822 |
+
self.register_buffer("pe", pe)
|
| 823 |
+
self.sub_idxs = None
|
| 824 |
+
|
| 825 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 826 |
+
self.sub_idxs = sub_idxs
|
| 827 |
+
|
| 828 |
+
def forward(self, x):
|
| 829 |
+
#if self.sub_idxs is not None:
|
| 830 |
+
# x = x + self.pe[:, self.sub_idxs]
|
| 831 |
+
#else:
|
| 832 |
+
x = x + self.pe[:, : x.size(1)]
|
| 833 |
+
return self.dropout(x)
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
class CrossAttentionMM(nn.Module):
|
| 837 |
+
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
|
| 838 |
+
operations=comfy.ops.disable_weight_init):
|
| 839 |
+
super().__init__()
|
| 840 |
+
inner_dim = dim_head * heads
|
| 841 |
+
context_dim = default(context_dim, query_dim)
|
| 842 |
+
|
| 843 |
+
self.heads = heads
|
| 844 |
+
self.dim_head = dim_head
|
| 845 |
+
self.scale = None
|
| 846 |
+
|
| 847 |
+
self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
| 848 |
+
self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
| 849 |
+
self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
|
| 850 |
+
|
| 851 |
+
self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
|
| 852 |
+
|
| 853 |
+
def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
|
| 854 |
+
q = self.to_q(x)
|
| 855 |
+
context = default(context, x)
|
| 856 |
+
k: Tensor = self.to_k(context)
|
| 857 |
+
if value is not None:
|
| 858 |
+
v = self.to_v(value)
|
| 859 |
+
del value
|
| 860 |
+
else:
|
| 861 |
+
v = self.to_v(context)
|
| 862 |
+
|
| 863 |
+
# apply custom scale by multiplying k by scale factor
|
| 864 |
+
if self.scale is not None:
|
| 865 |
+
k *= self.scale
|
| 866 |
+
|
| 867 |
+
# apply scale mask, if present
|
| 868 |
+
if scale_mask is not None:
|
| 869 |
+
k *= scale_mask
|
| 870 |
+
|
| 871 |
+
out = optimized_attention_mm(q, k, v, self.heads, mask)
|
| 872 |
+
return self.to_out(out)
|
| 873 |
+
|
| 874 |
+
|
| 875 |
+
class VersatileAttention(CrossAttentionMM):
|
| 876 |
+
def __init__(
|
| 877 |
+
self,
|
| 878 |
+
attention_mode=None,
|
| 879 |
+
cross_frame_attention_mode=None,
|
| 880 |
+
temporal_position_encoding=False,
|
| 881 |
+
temporal_position_encoding_max_len=24,
|
| 882 |
+
*args,
|
| 883 |
+
**kwargs,
|
| 884 |
+
):
|
| 885 |
+
super().__init__(*args, **kwargs)
|
| 886 |
+
assert attention_mode == "Temporal"
|
| 887 |
+
|
| 888 |
+
self.attention_mode = attention_mode
|
| 889 |
+
self.is_cross_attention = kwargs["context_dim"] is not None
|
| 890 |
+
|
| 891 |
+
self.pos_encoder = (
|
| 892 |
+
PositionalEncoding(
|
| 893 |
+
kwargs["query_dim"],
|
| 894 |
+
dropout=0.0,
|
| 895 |
+
max_len=temporal_position_encoding_max_len,
|
| 896 |
+
)
|
| 897 |
+
if (temporal_position_encoding and attention_mode == "Temporal")
|
| 898 |
+
else None
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
def extra_repr(self):
|
| 902 |
+
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
|
| 903 |
+
|
| 904 |
+
def set_scale_multiplier(self, multiplier: Union[float, None]):
|
| 905 |
+
if multiplier is None or math.isclose(multiplier, 1.0):
|
| 906 |
+
self.scale = None
|
| 907 |
+
else:
|
| 908 |
+
self.scale = multiplier
|
| 909 |
+
|
| 910 |
+
def set_sub_idxs(self, sub_idxs: list[int]):
|
| 911 |
+
if self.pos_encoder != None:
|
| 912 |
+
self.pos_encoder.set_sub_idxs(sub_idxs)
|
| 913 |
+
|
| 914 |
+
def forward(
|
| 915 |
+
self,
|
| 916 |
+
hidden_states: Tensor,
|
| 917 |
+
encoder_hidden_states=None,
|
| 918 |
+
attention_mask=None,
|
| 919 |
+
video_length=None,
|
| 920 |
+
scale_mask=None,
|
| 921 |
+
):
|
| 922 |
+
if self.attention_mode != "Temporal":
|
| 923 |
+
raise NotImplementedError
|
| 924 |
+
|
| 925 |
+
d = hidden_states.shape[1]
|
| 926 |
+
hidden_states = rearrange(
|
| 927 |
+
hidden_states, "(b f) d c -> (b d) f c", f=video_length
|
| 928 |
+
)
|
| 929 |
+
|
| 930 |
+
if self.pos_encoder is not None:
|
| 931 |
+
hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype)
|
| 932 |
+
|
| 933 |
+
encoder_hidden_states = (
|
| 934 |
+
repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
|
| 935 |
+
if encoder_hidden_states is not None
|
| 936 |
+
else encoder_hidden_states
|
| 937 |
+
)
|
| 938 |
+
|
| 939 |
+
hidden_states = super().forward(
|
| 940 |
+
hidden_states,
|
| 941 |
+
encoder_hidden_states,
|
| 942 |
+
value=None,
|
| 943 |
+
mask=attention_mask,
|
| 944 |
+
scale_mask=scale_mask,
|
| 945 |
+
)
|
| 946 |
+
|
| 947 |
+
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
|
| 948 |
+
|
| 949 |
+
return hidden_states
|
ComfyUI-Advanced-ControlNet/adv_control/control_svd.py
ADDED
|
@@ -0,0 +1,517 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
import comfy.model_detection
|
| 6 |
+
from comfy.utils import UNET_MAP_BASIC, UNET_MAP_RESNET, UNET_MAP_ATTENTIONS, TRANSFORMER_BLOCKS
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from comfy.ldm.modules.diffusionmodules.util import (
|
| 12 |
+
zero_module,
|
| 13 |
+
timestep_embedding,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
from comfy.ldm.modules.attention import SpatialVideoTransformer
|
| 17 |
+
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, VideoResBlock, Downsample
|
| 18 |
+
from comfy.ldm.util import exists
|
| 19 |
+
import comfy.ops
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class SVDControlNet(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
image_size,
|
| 26 |
+
in_channels,
|
| 27 |
+
model_channels,
|
| 28 |
+
hint_channels,
|
| 29 |
+
num_res_blocks,
|
| 30 |
+
dropout=0,
|
| 31 |
+
channel_mult=(1, 2, 4, 8),
|
| 32 |
+
conv_resample=True,
|
| 33 |
+
dims=2,
|
| 34 |
+
num_classes=None,
|
| 35 |
+
use_checkpoint=False,
|
| 36 |
+
dtype=torch.float32,
|
| 37 |
+
num_heads=-1,
|
| 38 |
+
num_head_channels=-1,
|
| 39 |
+
num_heads_upsample=-1,
|
| 40 |
+
use_scale_shift_norm=False,
|
| 41 |
+
resblock_updown=False,
|
| 42 |
+
use_new_attention_order=False,
|
| 43 |
+
use_spatial_transformer=False, # custom transformer support
|
| 44 |
+
transformer_depth=1, # custom transformer support
|
| 45 |
+
context_dim=None, # custom transformer support
|
| 46 |
+
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
| 47 |
+
legacy=True,
|
| 48 |
+
disable_self_attentions=None,
|
| 49 |
+
num_attention_blocks=None,
|
| 50 |
+
disable_middle_self_attn=False,
|
| 51 |
+
use_linear_in_transformer=False,
|
| 52 |
+
adm_in_channels=None,
|
| 53 |
+
transformer_depth_middle=None,
|
| 54 |
+
transformer_depth_output=None,
|
| 55 |
+
use_spatial_context=False,
|
| 56 |
+
extra_ff_mix_layer=False,
|
| 57 |
+
merge_strategy="fixed",
|
| 58 |
+
merge_factor=0.5,
|
| 59 |
+
video_kernel_size=3,
|
| 60 |
+
device=None,
|
| 61 |
+
operations=comfy.ops.disable_weight_init,
|
| 62 |
+
**kwargs,
|
| 63 |
+
):
|
| 64 |
+
super().__init__()
|
| 65 |
+
assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
|
| 66 |
+
if use_spatial_transformer:
|
| 67 |
+
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
| 68 |
+
|
| 69 |
+
if context_dim is not None:
|
| 70 |
+
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
| 71 |
+
# from omegaconf.listconfig import ListConfig
|
| 72 |
+
# if type(context_dim) == ListConfig:
|
| 73 |
+
# context_dim = list(context_dim)
|
| 74 |
+
|
| 75 |
+
if num_heads_upsample == -1:
|
| 76 |
+
num_heads_upsample = num_heads
|
| 77 |
+
|
| 78 |
+
if num_heads == -1:
|
| 79 |
+
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
| 80 |
+
|
| 81 |
+
if num_head_channels == -1:
|
| 82 |
+
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
| 83 |
+
|
| 84 |
+
self.dims = dims
|
| 85 |
+
self.image_size = image_size
|
| 86 |
+
self.in_channels = in_channels
|
| 87 |
+
self.model_channels = model_channels
|
| 88 |
+
|
| 89 |
+
if isinstance(num_res_blocks, int):
|
| 90 |
+
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
| 91 |
+
else:
|
| 92 |
+
if len(num_res_blocks) != len(channel_mult):
|
| 93 |
+
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
| 94 |
+
"as a list/tuple (per-level) with the same length as channel_mult")
|
| 95 |
+
self.num_res_blocks = num_res_blocks
|
| 96 |
+
|
| 97 |
+
if disable_self_attentions is not None:
|
| 98 |
+
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
| 99 |
+
assert len(disable_self_attentions) == len(channel_mult)
|
| 100 |
+
if num_attention_blocks is not None:
|
| 101 |
+
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
| 102 |
+
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
| 103 |
+
|
| 104 |
+
transformer_depth = transformer_depth[:]
|
| 105 |
+
|
| 106 |
+
self.dropout = dropout
|
| 107 |
+
self.channel_mult = channel_mult
|
| 108 |
+
self.conv_resample = conv_resample
|
| 109 |
+
self.num_classes = num_classes
|
| 110 |
+
self.use_checkpoint = use_checkpoint
|
| 111 |
+
self.dtype = dtype
|
| 112 |
+
self.num_heads = num_heads
|
| 113 |
+
self.num_head_channels = num_head_channels
|
| 114 |
+
self.num_heads_upsample = num_heads_upsample
|
| 115 |
+
self.predict_codebook_ids = n_embed is not None
|
| 116 |
+
|
| 117 |
+
time_embed_dim = model_channels * 4
|
| 118 |
+
self.time_embed = nn.Sequential(
|
| 119 |
+
operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
|
| 120 |
+
nn.SiLU(),
|
| 121 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if self.num_classes is not None:
|
| 125 |
+
if isinstance(self.num_classes, int):
|
| 126 |
+
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
| 127 |
+
elif self.num_classes == "continuous":
|
| 128 |
+
print("setting up linear c_adm embedding layer")
|
| 129 |
+
self.label_emb = nn.Linear(1, time_embed_dim)
|
| 130 |
+
elif self.num_classes == "sequential":
|
| 131 |
+
assert adm_in_channels is not None
|
| 132 |
+
self.label_emb = nn.Sequential(
|
| 133 |
+
nn.Sequential(
|
| 134 |
+
operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
|
| 135 |
+
nn.SiLU(),
|
| 136 |
+
operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
else:
|
| 140 |
+
raise ValueError()
|
| 141 |
+
|
| 142 |
+
self.input_blocks = nn.ModuleList(
|
| 143 |
+
[
|
| 144 |
+
TimestepEmbedSequential(
|
| 145 |
+
operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
| 146 |
+
)
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
|
| 150 |
+
|
| 151 |
+
self.input_hint_block = TimestepEmbedSequential(
|
| 152 |
+
operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
|
| 153 |
+
nn.SiLU(),
|
| 154 |
+
operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
|
| 155 |
+
nn.SiLU(),
|
| 156 |
+
operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
| 157 |
+
nn.SiLU(),
|
| 158 |
+
operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
|
| 159 |
+
nn.SiLU(),
|
| 160 |
+
operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
| 161 |
+
nn.SiLU(),
|
| 162 |
+
operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
|
| 163 |
+
nn.SiLU(),
|
| 164 |
+
operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
|
| 165 |
+
nn.SiLU(),
|
| 166 |
+
operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
self._feature_size = model_channels
|
| 170 |
+
input_block_chans = [model_channels]
|
| 171 |
+
ch = model_channels
|
| 172 |
+
ds = 1
|
| 173 |
+
for level, mult in enumerate(channel_mult):
|
| 174 |
+
for nr in range(self.num_res_blocks[level]):
|
| 175 |
+
layers = [
|
| 176 |
+
VideoResBlock(
|
| 177 |
+
ch,
|
| 178 |
+
time_embed_dim,
|
| 179 |
+
dropout,
|
| 180 |
+
out_channels=mult * model_channels,
|
| 181 |
+
dims=dims,
|
| 182 |
+
use_checkpoint=use_checkpoint,
|
| 183 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 184 |
+
dtype=self.dtype,
|
| 185 |
+
device=device,
|
| 186 |
+
operations=operations,
|
| 187 |
+
video_kernel_size=video_kernel_size,
|
| 188 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 189 |
+
)
|
| 190 |
+
]
|
| 191 |
+
ch = mult * model_channels
|
| 192 |
+
num_transformers = transformer_depth.pop(0)
|
| 193 |
+
if num_transformers > 0:
|
| 194 |
+
if num_head_channels == -1:
|
| 195 |
+
dim_head = ch // num_heads
|
| 196 |
+
else:
|
| 197 |
+
num_heads = ch // num_head_channels
|
| 198 |
+
dim_head = num_head_channels
|
| 199 |
+
if legacy:
|
| 200 |
+
#num_heads = 1
|
| 201 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 202 |
+
if exists(disable_self_attentions):
|
| 203 |
+
disabled_sa = disable_self_attentions[level]
|
| 204 |
+
else:
|
| 205 |
+
disabled_sa = False
|
| 206 |
+
|
| 207 |
+
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
| 208 |
+
layers.append(
|
| 209 |
+
SpatialVideoTransformer(
|
| 210 |
+
ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
|
| 211 |
+
disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
|
| 212 |
+
checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
|
| 213 |
+
use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
|
| 214 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 215 |
+
)
|
| 216 |
+
)
|
| 217 |
+
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
| 218 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
| 219 |
+
self._feature_size += ch
|
| 220 |
+
input_block_chans.append(ch)
|
| 221 |
+
if level != len(channel_mult) - 1:
|
| 222 |
+
out_ch = ch
|
| 223 |
+
self.input_blocks.append(
|
| 224 |
+
TimestepEmbedSequential(
|
| 225 |
+
VideoResBlock(
|
| 226 |
+
ch,
|
| 227 |
+
time_embed_dim,
|
| 228 |
+
dropout,
|
| 229 |
+
out_channels=out_ch,
|
| 230 |
+
dims=dims,
|
| 231 |
+
use_checkpoint=use_checkpoint,
|
| 232 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 233 |
+
down=True,
|
| 234 |
+
dtype=self.dtype,
|
| 235 |
+
device=device,
|
| 236 |
+
operations=operations,
|
| 237 |
+
video_kernel_size=video_kernel_size,
|
| 238 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 239 |
+
)
|
| 240 |
+
if resblock_updown
|
| 241 |
+
else Downsample(
|
| 242 |
+
ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
|
| 243 |
+
)
|
| 244 |
+
)
|
| 245 |
+
)
|
| 246 |
+
ch = out_ch
|
| 247 |
+
input_block_chans.append(ch)
|
| 248 |
+
self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
|
| 249 |
+
ds *= 2
|
| 250 |
+
self._feature_size += ch
|
| 251 |
+
|
| 252 |
+
if num_head_channels == -1:
|
| 253 |
+
dim_head = ch // num_heads
|
| 254 |
+
else:
|
| 255 |
+
num_heads = ch // num_head_channels
|
| 256 |
+
dim_head = num_head_channels
|
| 257 |
+
if legacy:
|
| 258 |
+
#num_heads = 1
|
| 259 |
+
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
| 260 |
+
mid_block = [
|
| 261 |
+
VideoResBlock(
|
| 262 |
+
ch,
|
| 263 |
+
time_embed_dim,
|
| 264 |
+
dropout,
|
| 265 |
+
dims=dims,
|
| 266 |
+
use_checkpoint=use_checkpoint,
|
| 267 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 268 |
+
dtype=self.dtype,
|
| 269 |
+
device=device,
|
| 270 |
+
operations=operations,
|
| 271 |
+
video_kernel_size=video_kernel_size,
|
| 272 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 273 |
+
)]
|
| 274 |
+
if transformer_depth_middle >= 0:
|
| 275 |
+
mid_block += [SpatialVideoTransformer( # always uses a self-attn
|
| 276 |
+
ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
|
| 277 |
+
disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
|
| 278 |
+
checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
|
| 279 |
+
use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
|
| 280 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 281 |
+
),
|
| 282 |
+
VideoResBlock(
|
| 283 |
+
ch,
|
| 284 |
+
time_embed_dim,
|
| 285 |
+
dropout,
|
| 286 |
+
dims=dims,
|
| 287 |
+
use_checkpoint=use_checkpoint,
|
| 288 |
+
use_scale_shift_norm=use_scale_shift_norm,
|
| 289 |
+
dtype=self.dtype,
|
| 290 |
+
device=device,
|
| 291 |
+
operations=operations,
|
| 292 |
+
video_kernel_size=video_kernel_size,
|
| 293 |
+
merge_strategy=merge_strategy, merge_factor=merge_factor,
|
| 294 |
+
)]
|
| 295 |
+
self.middle_block = TimestepEmbedSequential(*mid_block)
|
| 296 |
+
self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
|
| 297 |
+
self._feature_size += ch
|
| 298 |
+
|
| 299 |
+
def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
|
| 300 |
+
return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
|
| 301 |
+
|
| 302 |
+
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
| 303 |
+
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
|
| 304 |
+
emb = self.time_embed(t_emb)
|
| 305 |
+
|
| 306 |
+
cond = kwargs["cond"]
|
| 307 |
+
num_video_frames = cond["num_video_frames"]
|
| 308 |
+
image_only_indicator = cond.get("image_only_indicator", None)
|
| 309 |
+
time_context = cond.get("time_context", None)
|
| 310 |
+
del cond
|
| 311 |
+
|
| 312 |
+
guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
| 313 |
+
|
| 314 |
+
outs = []
|
| 315 |
+
|
| 316 |
+
hs = []
|
| 317 |
+
if self.num_classes is not None:
|
| 318 |
+
assert y.shape[0] == x.shape[0]
|
| 319 |
+
emb = emb + self.label_emb(y)
|
| 320 |
+
|
| 321 |
+
h = x
|
| 322 |
+
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
| 323 |
+
if guided_hint is not None:
|
| 324 |
+
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
| 325 |
+
h += guided_hint
|
| 326 |
+
guided_hint = None
|
| 327 |
+
else:
|
| 328 |
+
h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
| 329 |
+
outs.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
|
| 330 |
+
|
| 331 |
+
h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
|
| 332 |
+
outs.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
|
| 333 |
+
|
| 334 |
+
return outs
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
TEMPORAL_TRANSFORMER_BLOCKS = {
|
| 338 |
+
"norm_in.weight",
|
| 339 |
+
"norm_in.bias",
|
| 340 |
+
"ff_in.net.0.proj.weight",
|
| 341 |
+
"ff_in.net.0.proj.bias",
|
| 342 |
+
"ff_in.net.2.weight",
|
| 343 |
+
"ff_in.net.2.bias",
|
| 344 |
+
}
|
| 345 |
+
TEMPORAL_TRANSFORMER_BLOCKS.update(TRANSFORMER_BLOCKS)
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
TEMPORAL_UNET_MAP_ATTENTIONS = {
|
| 349 |
+
"time_mixer.mix_factor",
|
| 350 |
+
}
|
| 351 |
+
TEMPORAL_UNET_MAP_ATTENTIONS.update(UNET_MAP_ATTENTIONS)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
TEMPORAL_TRANSFORMER_MAP = {
|
| 355 |
+
"time_pos_embed.0.weight": "time_pos_embed.linear_1.weight",
|
| 356 |
+
"time_pos_embed.0.bias": "time_pos_embed.linear_1.bias",
|
| 357 |
+
"time_pos_embed.2.weight": "time_pos_embed.linear_2.weight",
|
| 358 |
+
"time_pos_embed.2.bias": "time_pos_embed.linear_2.bias",
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
TEMPORAL_RESNET = {
|
| 363 |
+
"time_mixer.mix_factor",
|
| 364 |
+
}
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def svd_unet_config_from_diffusers_unet(state_dict: dict[str, Tensor], dtype):
|
| 368 |
+
match = {}
|
| 369 |
+
transformer_depth = []
|
| 370 |
+
|
| 371 |
+
attn_res = 1
|
| 372 |
+
down_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}")
|
| 373 |
+
for i in range(down_blocks):
|
| 374 |
+
attn_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
|
| 375 |
+
for ab in range(attn_blocks):
|
| 376 |
+
transformer_count = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
|
| 377 |
+
transformer_depth.append(transformer_count)
|
| 378 |
+
if transformer_count > 0:
|
| 379 |
+
match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
|
| 380 |
+
|
| 381 |
+
attn_res *= 2
|
| 382 |
+
if attn_blocks == 0:
|
| 383 |
+
transformer_depth.append(0)
|
| 384 |
+
transformer_depth.append(0)
|
| 385 |
+
|
| 386 |
+
match["transformer_depth"] = transformer_depth
|
| 387 |
+
|
| 388 |
+
match["model_channels"] = state_dict["conv_in.weight"].shape[0]
|
| 389 |
+
match["in_channels"] = state_dict["conv_in.weight"].shape[1]
|
| 390 |
+
match["adm_in_channels"] = None
|
| 391 |
+
if "class_embedding.linear_1.weight" in state_dict:
|
| 392 |
+
match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
|
| 393 |
+
elif "add_embedding.linear_1.weight" in state_dict:
|
| 394 |
+
match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
|
| 395 |
+
|
| 396 |
+
# based on unet_config of SVD
|
| 397 |
+
SVD = {
|
| 398 |
+
'use_checkpoint': False,
|
| 399 |
+
'image_size': 32,
|
| 400 |
+
'use_spatial_transformer': True,
|
| 401 |
+
'legacy': False,
|
| 402 |
+
'num_classes': 'sequential',
|
| 403 |
+
'adm_in_channels': 768,
|
| 404 |
+
'dtype': dtype,
|
| 405 |
+
'in_channels': 8,
|
| 406 |
+
'out_channels': 4,
|
| 407 |
+
'model_channels': 320,
|
| 408 |
+
'num_res_blocks': [2, 2, 2, 2],
|
| 409 |
+
'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
| 410 |
+
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
| 411 |
+
'channel_mult': [1, 2, 4, 4],
|
| 412 |
+
'transformer_depth_middle': 1,
|
| 413 |
+
'use_linear_in_transformer': True,
|
| 414 |
+
'context_dim': 1024,
|
| 415 |
+
'extra_ff_mix_layer': True,
|
| 416 |
+
'use_spatial_context': True,
|
| 417 |
+
'merge_strategy': 'learned_with_images',
|
| 418 |
+
'merge_factor': 0.0,
|
| 419 |
+
'video_kernel_size': [3, 1, 1],
|
| 420 |
+
'use_temporal_attention': True,
|
| 421 |
+
'use_temporal_resblock': True,
|
| 422 |
+
'num_heads': -1,
|
| 423 |
+
'num_head_channels': 64,
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
supported_models = [SVD]
|
| 427 |
+
|
| 428 |
+
for unet_config in supported_models:
|
| 429 |
+
matches = True
|
| 430 |
+
for k in match:
|
| 431 |
+
if match[k] != unet_config[k]:
|
| 432 |
+
matches = False
|
| 433 |
+
break
|
| 434 |
+
if matches:
|
| 435 |
+
return comfy.model_detection.convert_config(unet_config)
|
| 436 |
+
return None
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def svd_unet_to_diffusers(unet_config):
|
| 440 |
+
num_res_blocks = unet_config["num_res_blocks"]
|
| 441 |
+
channel_mult = unet_config["channel_mult"]
|
| 442 |
+
transformer_depth = unet_config["transformer_depth"][:]
|
| 443 |
+
transformer_depth_output = unet_config["transformer_depth_output"][:]
|
| 444 |
+
num_blocks = len(channel_mult)
|
| 445 |
+
|
| 446 |
+
transformers_mid = unet_config.get("transformer_depth_middle", None)
|
| 447 |
+
|
| 448 |
+
diffusers_unet_map = {}
|
| 449 |
+
for x in range(num_blocks):
|
| 450 |
+
n = 1 + (num_res_blocks[x] + 1) * x
|
| 451 |
+
for i in range(num_res_blocks[x]):
|
| 452 |
+
for b in TEMPORAL_RESNET:
|
| 453 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, b)] = "input_blocks.{}.0.{}".format(n, b)
|
| 454 |
+
for b in UNET_MAP_RESNET:
|
| 455 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.spatial_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
| 456 |
+
diffusers_unet_map["down_blocks.{}.resnets.{}.temporal_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.time_stack.{}".format(n, b)
|
| 457 |
+
#diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
|
| 458 |
+
num_transformers = transformer_depth.pop(0)
|
| 459 |
+
if num_transformers > 0:
|
| 460 |
+
for b in TEMPORAL_UNET_MAP_ATTENTIONS:
|
| 461 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
|
| 462 |
+
for b in TEMPORAL_TRANSFORMER_MAP:
|
| 463 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, TEMPORAL_TRANSFORMER_MAP[b])] = "input_blocks.{}.1.{}".format(n, b)
|
| 464 |
+
for t in range(num_transformers):
|
| 465 |
+
for b in TRANSFORMER_BLOCKS:
|
| 466 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
| 467 |
+
for b in TEMPORAL_TRANSFORMER_BLOCKS:
|
| 468 |
+
diffusers_unet_map["down_blocks.{}.attentions.{}.temporal_transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.time_stack.{}.{}".format(n, t, b)
|
| 469 |
+
n += 1
|
| 470 |
+
for k in ["weight", "bias"]:
|
| 471 |
+
diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
|
| 472 |
+
|
| 473 |
+
i = 0
|
| 474 |
+
for b in TEMPORAL_UNET_MAP_ATTENTIONS:
|
| 475 |
+
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
|
| 476 |
+
for b in TEMPORAL_TRANSFORMER_MAP:
|
| 477 |
+
diffusers_unet_map["mid_block.attentions.{}.{}".format(i, TEMPORAL_TRANSFORMER_MAP[b])] = "middle_block.1.{}".format(b)
|
| 478 |
+
for t in range(transformers_mid):
|
| 479 |
+
for b in TRANSFORMER_BLOCKS:
|
| 480 |
+
diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
|
| 481 |
+
for b in TEMPORAL_TRANSFORMER_BLOCKS:
|
| 482 |
+
diffusers_unet_map["mid_block.attentions.{}.temporal_transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.time_stack.{}.{}".format(t, b)
|
| 483 |
+
|
| 484 |
+
for i, n in enumerate([0, 2]):
|
| 485 |
+
for b in TEMPORAL_RESNET:
|
| 486 |
+
diffusers_unet_map["mid_block.resnets.{}.{}".format(i, b)] = "middle_block.{}.{}".format(n, b)
|
| 487 |
+
for b in UNET_MAP_RESNET:
|
| 488 |
+
diffusers_unet_map["mid_block.resnets.{}.spatial_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
| 489 |
+
diffusers_unet_map["mid_block.resnets.{}.temporal_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.time_stack.{}".format(n, b)
|
| 490 |
+
#diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
|
| 491 |
+
|
| 492 |
+
num_res_blocks = list(reversed(num_res_blocks))
|
| 493 |
+
for x in range(num_blocks):
|
| 494 |
+
n = (num_res_blocks[x] + 1) * x
|
| 495 |
+
l = num_res_blocks[x] + 1
|
| 496 |
+
for i in range(l):
|
| 497 |
+
c = 0
|
| 498 |
+
for b in UNET_MAP_RESNET:
|
| 499 |
+
diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
|
| 500 |
+
c += 1
|
| 501 |
+
num_transformers = transformer_depth_output.pop()
|
| 502 |
+
if num_transformers > 0:
|
| 503 |
+
c += 1
|
| 504 |
+
for b in UNET_MAP_ATTENTIONS:
|
| 505 |
+
diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
|
| 506 |
+
for t in range(num_transformers):
|
| 507 |
+
for b in TRANSFORMER_BLOCKS:
|
| 508 |
+
diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
|
| 509 |
+
if i == l - 1:
|
| 510 |
+
for k in ["weight", "bias"]:
|
| 511 |
+
diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
|
| 512 |
+
n += 1
|
| 513 |
+
|
| 514 |
+
for k in UNET_MAP_BASIC:
|
| 515 |
+
diffusers_unet_map[k[1]] = k[0]
|
| 516 |
+
|
| 517 |
+
return diffusers_unet_map
|
ComfyUI-Advanced-ControlNet/adv_control/logger.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import copy
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ColoredFormatter(logging.Formatter):
|
| 7 |
+
COLORS = {
|
| 8 |
+
"DEBUG": "\033[0;36m", # CYAN
|
| 9 |
+
"INFO": "\033[0;32m", # GREEN
|
| 10 |
+
"WARNING": "\033[0;33m", # YELLOW
|
| 11 |
+
"ERROR": "\033[0;31m", # RED
|
| 12 |
+
"CRITICAL": "\033[0;37;41m", # WHITE ON RED
|
| 13 |
+
"RESET": "\033[0m", # RESET COLOR
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
def format(self, record):
|
| 17 |
+
colored_record = copy.copy(record)
|
| 18 |
+
levelname = colored_record.levelname
|
| 19 |
+
seq = self.COLORS.get(levelname, self.COLORS["RESET"])
|
| 20 |
+
colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
|
| 21 |
+
return super().format(colored_record)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Create a new logger
|
| 25 |
+
logger = logging.getLogger("Advanced-ControlNet")
|
| 26 |
+
logger.propagate = False
|
| 27 |
+
|
| 28 |
+
# Add handler if we don't have one.
|
| 29 |
+
if not logger.handlers:
|
| 30 |
+
handler = logging.StreamHandler(sys.stdout)
|
| 31 |
+
handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s"))
|
| 32 |
+
logger.addHandler(handler)
|
| 33 |
+
|
| 34 |
+
# Configure logger
|
| 35 |
+
loglevel = logging.INFO
|
| 36 |
+
logger.setLevel(loglevel)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
|
| 4 |
+
import folder_paths
|
| 5 |
+
from comfy.model_patcher import ModelPatcher
|
| 6 |
+
|
| 7 |
+
from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet
|
| 8 |
+
from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, BIGMAX
|
| 9 |
+
from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights, SoftControlNetWeights, CustomControlNetWeights,
|
| 10 |
+
SoftT2IAdapterWeights, CustomT2IAdapterWeights)
|
| 11 |
+
from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,
|
| 12 |
+
TimestepKeyframeNode, TimestepKeyframeInterpolationNode, TimestepKeyframeFromStrengthListNode)
|
| 13 |
+
from .nodes_sparsectrl import SparseCtrlMergedLoaderAdvanced, SparseCtrlLoaderAdvanced, SparseIndexMethodNode, SparseSpreadMethodNode, RgbSparseCtrlPreprocessor
|
| 14 |
+
from .nodes_reference import ReferenceControlNetNode, ReferenceControlFinetune, ReferencePreprocessorNode
|
| 15 |
+
from .nodes_loosecontrol import ControlNetLoaderWithLoraAdvanced
|
| 16 |
+
from .nodes_deprecated import LoadImagesFromDirectory
|
| 17 |
+
from .logger import logger
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class ControlNetLoaderAdvanced:
|
| 21 |
+
@classmethod
|
| 22 |
+
def INPUT_TYPES(s):
|
| 23 |
+
return {
|
| 24 |
+
"required": {
|
| 25 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 26 |
+
},
|
| 27 |
+
"optional": {
|
| 28 |
+
"timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
|
| 29 |
+
}
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 33 |
+
FUNCTION = "load_controlnet"
|
| 34 |
+
|
| 35 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
| 36 |
+
|
| 37 |
+
def load_controlnet(self, control_net_name,
|
| 38 |
+
timestep_keyframe: TimestepKeyframeGroup=None
|
| 39 |
+
):
|
| 40 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
| 41 |
+
controlnet = load_controlnet(controlnet_path, timestep_keyframe)
|
| 42 |
+
return (controlnet,)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class DiffControlNetLoaderAdvanced:
|
| 46 |
+
@classmethod
|
| 47 |
+
def INPUT_TYPES(s):
|
| 48 |
+
return {
|
| 49 |
+
"required": {
|
| 50 |
+
"model": ("MODEL",),
|
| 51 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), )
|
| 52 |
+
},
|
| 53 |
+
"optional": {
|
| 54 |
+
"timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 59 |
+
FUNCTION = "load_controlnet"
|
| 60 |
+
|
| 61 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
| 62 |
+
|
| 63 |
+
def load_controlnet(self, control_net_name, model,
|
| 64 |
+
timestep_keyframe: TimestepKeyframeGroup=None
|
| 65 |
+
):
|
| 66 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
| 67 |
+
controlnet = load_controlnet(controlnet_path, timestep_keyframe, model)
|
| 68 |
+
if is_advanced_controlnet(controlnet):
|
| 69 |
+
controlnet.verify_all_weights()
|
| 70 |
+
return (controlnet,)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class AdvancedControlNetApply:
|
| 74 |
+
@classmethod
|
| 75 |
+
def INPUT_TYPES(s):
|
| 76 |
+
return {
|
| 77 |
+
"required": {
|
| 78 |
+
"positive": ("CONDITIONING", ),
|
| 79 |
+
"negative": ("CONDITIONING", ),
|
| 80 |
+
"control_net": ("CONTROL_NET", ),
|
| 81 |
+
"image": ("IMAGE", ),
|
| 82 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
|
| 83 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 84 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
|
| 85 |
+
},
|
| 86 |
+
"optional": {
|
| 87 |
+
"mask_optional": ("MASK", ),
|
| 88 |
+
"timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
| 89 |
+
"latent_kf_override": ("LATENT_KEYFRAME", ),
|
| 90 |
+
"weights_override": ("CONTROL_NET_WEIGHTS", ),
|
| 91 |
+
"model_optional": ("MODEL",),
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
RETURN_TYPES = ("CONDITIONING","CONDITIONING","MODEL",)
|
| 96 |
+
RETURN_NAMES = ("positive", "negative", "model_opt")
|
| 97 |
+
FUNCTION = "apply_controlnet"
|
| 98 |
+
|
| 99 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
|
| 100 |
+
|
| 101 |
+
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
|
| 102 |
+
mask_optional: Tensor=None, model_optional: ModelPatcher=None,
|
| 103 |
+
timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
|
| 104 |
+
weights_override: ControlWeights=None):
|
| 105 |
+
if strength == 0:
|
| 106 |
+
return (positive, negative, model_optional)
|
| 107 |
+
if model_optional:
|
| 108 |
+
model_optional = model_optional.clone()
|
| 109 |
+
|
| 110 |
+
control_hint = image.movedim(-1,1)
|
| 111 |
+
cnets = {}
|
| 112 |
+
|
| 113 |
+
out = []
|
| 114 |
+
for conditioning in [positive, negative]:
|
| 115 |
+
c = []
|
| 116 |
+
for t in conditioning:
|
| 117 |
+
d = t[1].copy()
|
| 118 |
+
|
| 119 |
+
prev_cnet = d.get('control', None)
|
| 120 |
+
if prev_cnet in cnets:
|
| 121 |
+
c_net = cnets[prev_cnet]
|
| 122 |
+
else:
|
| 123 |
+
# copy, convert to advanced if needed, and set cond
|
| 124 |
+
c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent))
|
| 125 |
+
if is_advanced_controlnet(c_net):
|
| 126 |
+
# disarm node check
|
| 127 |
+
c_net.disarm()
|
| 128 |
+
# if model required, verify model is passed in, and if so patch it
|
| 129 |
+
if c_net.require_model:
|
| 130 |
+
if not model_optional:
|
| 131 |
+
raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
|
| 132 |
+
c_net.patch_model(model=model_optional)
|
| 133 |
+
# apply optional parameters and overrides, if provided
|
| 134 |
+
if timestep_kf is not None:
|
| 135 |
+
c_net.set_timestep_keyframes(timestep_kf)
|
| 136 |
+
if latent_kf_override is not None:
|
| 137 |
+
c_net.latent_keyframe_override = latent_kf_override
|
| 138 |
+
if weights_override is not None:
|
| 139 |
+
c_net.weights_override = weights_override
|
| 140 |
+
# verify weights are compatible
|
| 141 |
+
c_net.verify_all_weights()
|
| 142 |
+
# set cond hint mask
|
| 143 |
+
if mask_optional is not None:
|
| 144 |
+
mask_optional = mask_optional.clone()
|
| 145 |
+
# if not in the form of a batch, make it so
|
| 146 |
+
if len(mask_optional.shape) < 3:
|
| 147 |
+
mask_optional = mask_optional.unsqueeze(0)
|
| 148 |
+
c_net.set_cond_hint_mask(mask_optional)
|
| 149 |
+
c_net.set_previous_controlnet(prev_cnet)
|
| 150 |
+
cnets[prev_cnet] = c_net
|
| 151 |
+
|
| 152 |
+
d['control'] = c_net
|
| 153 |
+
d['control_apply_to_uncond'] = False
|
| 154 |
+
n = [t[0], d]
|
| 155 |
+
c.append(n)
|
| 156 |
+
out.append(c)
|
| 157 |
+
return (out[0], out[1], model_optional)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# NODE MAPPING
|
| 161 |
+
NODE_CLASS_MAPPINGS = {
|
| 162 |
+
# Keyframes
|
| 163 |
+
"TimestepKeyframe": TimestepKeyframeNode,
|
| 164 |
+
"ACN_TimestepKeyframeInterpolation": TimestepKeyframeInterpolationNode,
|
| 165 |
+
"ACN_TimestepKeyframeFromStrengthList": TimestepKeyframeFromStrengthListNode,
|
| 166 |
+
"LatentKeyframe": LatentKeyframeNode,
|
| 167 |
+
"LatentKeyframeTiming": LatentKeyframeInterpolationNode,
|
| 168 |
+
"LatentKeyframeBatchedGroup": LatentKeyframeBatchedGroupNode,
|
| 169 |
+
"LatentKeyframeGroup": LatentKeyframeGroupNode,
|
| 170 |
+
# Conditioning
|
| 171 |
+
"ACN_AdvancedControlNetApply": AdvancedControlNetApply,
|
| 172 |
+
# Loaders
|
| 173 |
+
"ControlNetLoaderAdvanced": ControlNetLoaderAdvanced,
|
| 174 |
+
"DiffControlNetLoaderAdvanced": DiffControlNetLoaderAdvanced,
|
| 175 |
+
# Weights
|
| 176 |
+
"ScaledSoftControlNetWeights": ScaledSoftUniversalWeights,
|
| 177 |
+
"ScaledSoftMaskedUniversalWeights": ScaledSoftMaskedUniversalWeights,
|
| 178 |
+
"SoftControlNetWeights": SoftControlNetWeights,
|
| 179 |
+
"CustomControlNetWeights": CustomControlNetWeights,
|
| 180 |
+
"SoftT2IAdapterWeights": SoftT2IAdapterWeights,
|
| 181 |
+
"CustomT2IAdapterWeights": CustomT2IAdapterWeights,
|
| 182 |
+
"ACN_DefaultUniversalWeights": DefaultWeights,
|
| 183 |
+
# SparseCtrl
|
| 184 |
+
"ACN_SparseCtrlRGBPreprocessor": RgbSparseCtrlPreprocessor,
|
| 185 |
+
"ACN_SparseCtrlLoaderAdvanced": SparseCtrlLoaderAdvanced,
|
| 186 |
+
"ACN_SparseCtrlMergedLoaderAdvanced": SparseCtrlMergedLoaderAdvanced,
|
| 187 |
+
"ACN_SparseCtrlIndexMethodNode": SparseIndexMethodNode,
|
| 188 |
+
"ACN_SparseCtrlSpreadMethodNode": SparseSpreadMethodNode,
|
| 189 |
+
# Reference
|
| 190 |
+
"ACN_ReferencePreprocessor": ReferencePreprocessorNode,
|
| 191 |
+
"ACN_ReferenceControlNet": ReferenceControlNetNode,
|
| 192 |
+
"ACN_ReferenceControlNetFinetune": ReferenceControlFinetune,
|
| 193 |
+
# LOOSEControl
|
| 194 |
+
#"ACN_ControlNetLoaderWithLoraAdvanced": ControlNetLoaderWithLoraAdvanced,
|
| 195 |
+
# Deprecated
|
| 196 |
+
"LoadImagesFromDirectory": LoadImagesFromDirectory,
|
| 197 |
+
}
|
| 198 |
+
|
| 199 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 200 |
+
# Keyframes
|
| 201 |
+
"TimestepKeyframe": "Timestep Keyframe 🛂🅐🅒🅝",
|
| 202 |
+
"ACN_TimestepKeyframeInterpolation": "Timestep Keyframe Interpolation 🛂🅐🅒🅝",
|
| 203 |
+
"ACN_TimestepKeyframeFromStrengthList": "Timestep Keyframe From List 🛂🅐🅒🅝",
|
| 204 |
+
"LatentKeyframe": "Latent Keyframe 🛂🅐🅒🅝",
|
| 205 |
+
"LatentKeyframeTiming": "Latent Keyframe Interpolation 🛂🅐🅒🅝",
|
| 206 |
+
"LatentKeyframeBatchedGroup": "Latent Keyframe From List 🛂🅐🅒🅝",
|
| 207 |
+
"LatentKeyframeGroup": "Latent Keyframe Group 🛂🅐🅒🅝",
|
| 208 |
+
# Conditioning
|
| 209 |
+
"ACN_AdvancedControlNetApply": "Apply Advanced ControlNet 🛂🅐🅒🅝",
|
| 210 |
+
# Loaders
|
| 211 |
+
"ControlNetLoaderAdvanced": "Load Advanced ControlNet Model 🛂🅐🅒🅝",
|
| 212 |
+
"DiffControlNetLoaderAdvanced": "Load Advanced ControlNet Model (diff) 🛂🅐🅒🅝",
|
| 213 |
+
# Weights
|
| 214 |
+
"ScaledSoftControlNetWeights": "Scaled Soft Weights 🛂🅐🅒🅝",
|
| 215 |
+
"ScaledSoftMaskedUniversalWeights": "Scaled Soft Masked Weights 🛂🅐🅒🅝",
|
| 216 |
+
"SoftControlNetWeights": "ControlNet Soft Weights 🛂🅐🅒🅝",
|
| 217 |
+
"CustomControlNetWeights": "ControlNet Custom Weights 🛂🅐🅒🅝",
|
| 218 |
+
"SoftT2IAdapterWeights": "T2IAdapter Soft Weights 🛂🅐🅒🅝",
|
| 219 |
+
"CustomT2IAdapterWeights": "T2IAdapter Custom Weights 🛂🅐🅒🅝",
|
| 220 |
+
"ACN_DefaultUniversalWeights": "Force Default Weights 🛂🅐🅒🅝",
|
| 221 |
+
# SparseCtrl
|
| 222 |
+
"ACN_SparseCtrlRGBPreprocessor": "RGB SparseCtrl 🛂🅐🅒🅝",
|
| 223 |
+
"ACN_SparseCtrlLoaderAdvanced": "Load SparseCtrl Model 🛂🅐🅒🅝",
|
| 224 |
+
"ACN_SparseCtrlMergedLoaderAdvanced": "🧪Load Merged SparseCtrl Model 🛂🅐🅒🅝",
|
| 225 |
+
"ACN_SparseCtrlIndexMethodNode": "SparseCtrl Index Method 🛂🅐🅒🅝",
|
| 226 |
+
"ACN_SparseCtrlSpreadMethodNode": "SparseCtrl Spread Method 🛂🅐🅒🅝",
|
| 227 |
+
# Reference
|
| 228 |
+
"ACN_ReferencePreprocessor": "Reference Preproccessor 🛂🅐🅒🅝",
|
| 229 |
+
"ACN_ReferenceControlNet": "Reference ControlNet 🛂🅐🅒🅝",
|
| 230 |
+
"ACN_ReferenceControlNetFinetune": "Reference ControlNet (Finetune) 🛂🅐🅒🅝",
|
| 231 |
+
# LOOSEControl
|
| 232 |
+
#"ACN_ControlNetLoaderWithLoraAdvanced": "Load Adv. ControlNet Model w/ LoRA 🛂🅐🅒🅝",
|
| 233 |
+
# Deprecated
|
| 234 |
+
"LoadImagesFromDirectory": "🚫Load Images [DEPRECATED] 🛂🅐🅒🅝",
|
| 235 |
+
}
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image, ImageOps
|
| 7 |
+
from .utils import BIGMAX
|
| 8 |
+
from .logger import logger
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LoadImagesFromDirectory:
|
| 12 |
+
@classmethod
|
| 13 |
+
def INPUT_TYPES(s):
|
| 14 |
+
return {
|
| 15 |
+
"required": {
|
| 16 |
+
"directory": ("STRING", {"default": ""}),
|
| 17 |
+
},
|
| 18 |
+
"optional": {
|
| 19 |
+
"image_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
| 20 |
+
"start_index": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
RETURN_TYPES = ("IMAGE", "MASK", "INT")
|
| 25 |
+
FUNCTION = "load_images"
|
| 26 |
+
|
| 27 |
+
CATEGORY = ""
|
| 28 |
+
|
| 29 |
+
def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0):
|
| 30 |
+
if not os.path.isdir(directory):
|
| 31 |
+
raise FileNotFoundError(f"Directory '{directory} cannot be found.'")
|
| 32 |
+
dir_files = os.listdir(directory)
|
| 33 |
+
if len(dir_files) == 0:
|
| 34 |
+
raise FileNotFoundError(f"No files in directory '{directory}'.")
|
| 35 |
+
|
| 36 |
+
dir_files = sorted(dir_files)
|
| 37 |
+
dir_files = [os.path.join(directory, x) for x in dir_files]
|
| 38 |
+
# start at start_index
|
| 39 |
+
dir_files = dir_files[start_index:]
|
| 40 |
+
|
| 41 |
+
images = []
|
| 42 |
+
masks = []
|
| 43 |
+
|
| 44 |
+
limit_images = False
|
| 45 |
+
if image_load_cap > 0:
|
| 46 |
+
limit_images = True
|
| 47 |
+
image_count = 0
|
| 48 |
+
|
| 49 |
+
for image_path in dir_files:
|
| 50 |
+
if os.path.isdir(image_path):
|
| 51 |
+
continue
|
| 52 |
+
if limit_images and image_count >= image_load_cap:
|
| 53 |
+
break
|
| 54 |
+
i = Image.open(image_path)
|
| 55 |
+
i = ImageOps.exif_transpose(i)
|
| 56 |
+
image = i.convert("RGB")
|
| 57 |
+
image = np.array(image).astype(np.float32) / 255.0
|
| 58 |
+
image = torch.from_numpy(image)[None,]
|
| 59 |
+
if 'A' in i.getbands():
|
| 60 |
+
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
|
| 61 |
+
mask = 1. - torch.from_numpy(mask)
|
| 62 |
+
else:
|
| 63 |
+
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
| 64 |
+
images.append(image)
|
| 65 |
+
masks.append(mask)
|
| 66 |
+
image_count += 1
|
| 67 |
+
|
| 68 |
+
if len(images) == 0:
|
| 69 |
+
raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.")
|
| 70 |
+
|
| 71 |
+
return (torch.cat(images, dim=0), torch.stack(masks, dim=0), image_count)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
import numpy as np
|
| 3 |
+
from collections.abc import Iterable
|
| 4 |
+
|
| 5 |
+
from .utils import ControlWeights, TimestepKeyframe, TimestepKeyframeGroup, LatentKeyframe, LatentKeyframeGroup, BIGMIN, BIGMAX
|
| 6 |
+
from .utils import StrengthInterpolation as SI
|
| 7 |
+
from .logger import logger
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TimestepKeyframeNode:
|
| 11 |
+
OUTDATED_DUMMY = -39
|
| 12 |
+
|
| 13 |
+
@classmethod
|
| 14 |
+
def INPUT_TYPES(s):
|
| 15 |
+
return {
|
| 16 |
+
"required": {
|
| 17 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
| 18 |
+
},
|
| 19 |
+
"optional": {
|
| 20 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
| 21 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 22 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
| 23 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
| 24 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 25 |
+
"inherit_missing": ("BOOLEAN", {"default": True}, ),
|
| 26 |
+
"guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}),
|
| 27 |
+
"mask_optional": ("MASK", ),
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
| 32 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
| 33 |
+
FUNCTION = "load_keyframe"
|
| 34 |
+
|
| 35 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 36 |
+
|
| 37 |
+
def load_keyframe(self,
|
| 38 |
+
start_percent: float,
|
| 39 |
+
strength: float=1.0,
|
| 40 |
+
cn_weights: ControlWeights=None, control_net_weights: ControlWeights=None, # old name
|
| 41 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
| 42 |
+
prev_timestep_kf: TimestepKeyframeGroup=None, prev_timestep_keyframe: TimestepKeyframeGroup=None, # old name
|
| 43 |
+
null_latent_kf_strength: float=0.0,
|
| 44 |
+
inherit_missing=True,
|
| 45 |
+
guarantee_steps=OUTDATED_DUMMY,
|
| 46 |
+
guarantee_usage=True, # old input
|
| 47 |
+
mask_optional=None,):
|
| 48 |
+
# if using outdated dummy value, means node on workflow is outdated and should appropriately convert behavior
|
| 49 |
+
if guarantee_steps == self.OUTDATED_DUMMY:
|
| 50 |
+
guarantee_steps = int(guarantee_usage)
|
| 51 |
+
control_net_weights = control_net_weights if control_net_weights else cn_weights
|
| 52 |
+
prev_timestep_keyframe = prev_timestep_keyframe if prev_timestep_keyframe else prev_timestep_kf
|
| 53 |
+
if not prev_timestep_keyframe:
|
| 54 |
+
prev_timestep_keyframe = TimestepKeyframeGroup()
|
| 55 |
+
else:
|
| 56 |
+
prev_timestep_keyframe = prev_timestep_keyframe.clone()
|
| 57 |
+
keyframe = TimestepKeyframe(start_percent=start_percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
| 58 |
+
control_weights=control_net_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
| 59 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional)
|
| 60 |
+
prev_timestep_keyframe.add(keyframe)
|
| 61 |
+
return (prev_timestep_keyframe,)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class TimestepKeyframeInterpolationNode:
|
| 65 |
+
@classmethod
|
| 66 |
+
def INPUT_TYPES(s):
|
| 67 |
+
return {
|
| 68 |
+
"required": {
|
| 69 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
|
| 70 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 71 |
+
"strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
| 72 |
+
"strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
| 73 |
+
"interpolation": (SI._LIST, ),
|
| 74 |
+
"intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}),
|
| 75 |
+
},
|
| 76 |
+
"optional": {
|
| 77 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
| 78 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
| 79 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
| 80 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
| 81 |
+
"inherit_missing": ("BOOLEAN", {"default": True},),
|
| 82 |
+
"mask_optional": ("MASK", ),
|
| 83 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
| 84 |
+
}
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
| 88 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
| 89 |
+
FUNCTION = "load_keyframe"
|
| 90 |
+
|
| 91 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 92 |
+
|
| 93 |
+
def load_keyframe(self,
|
| 94 |
+
start_percent: float, end_percent: float,
|
| 95 |
+
strength_start: float, strength_end: float, interpolation: str, intervals: int,
|
| 96 |
+
cn_weights: ControlWeights=None,
|
| 97 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
| 98 |
+
prev_timestep_kf: TimestepKeyframeGroup=None,
|
| 99 |
+
null_latent_kf_strength: float=0.0,
|
| 100 |
+
inherit_missing=True,
|
| 101 |
+
guarantee_steps=1,
|
| 102 |
+
mask_optional=None, print_keyframes=False):
|
| 103 |
+
if not prev_timestep_kf:
|
| 104 |
+
prev_timestep_kf = TimestepKeyframeGroup()
|
| 105 |
+
else:
|
| 106 |
+
prev_timestep_kf = prev_timestep_kf.clone()
|
| 107 |
+
|
| 108 |
+
percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=SI.LINEAR)
|
| 109 |
+
strengths = SI.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation)
|
| 110 |
+
|
| 111 |
+
is_first = True
|
| 112 |
+
for percent, strength in zip(percents, strengths):
|
| 113 |
+
guarantee_steps = 0
|
| 114 |
+
if is_first:
|
| 115 |
+
guarantee_steps = 1
|
| 116 |
+
is_first = False
|
| 117 |
+
prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
| 118 |
+
control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
| 119 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
|
| 120 |
+
if print_keyframes:
|
| 121 |
+
logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
|
| 122 |
+
return (prev_timestep_kf,)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class TimestepKeyframeFromStrengthListNode:
|
| 126 |
+
@classmethod
|
| 127 |
+
def INPUT_TYPES(s):
|
| 128 |
+
return {
|
| 129 |
+
"required": {
|
| 130 |
+
"float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
| 131 |
+
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
|
| 132 |
+
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 133 |
+
},
|
| 134 |
+
"optional": {
|
| 135 |
+
"prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
|
| 136 |
+
"cn_weights": ("CONTROL_NET_WEIGHTS", ),
|
| 137 |
+
"latent_keyframe": ("LATENT_KEYFRAME", ),
|
| 138 |
+
"null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
|
| 139 |
+
"inherit_missing": ("BOOLEAN", {"default": True},),
|
| 140 |
+
"mask_optional": ("MASK", ),
|
| 141 |
+
"print_keyframes": ("BOOLEAN", {"default": False}),
|
| 142 |
+
}
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
RETURN_NAMES = ("TIMESTEP_KF", )
|
| 146 |
+
RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
|
| 147 |
+
FUNCTION = "load_keyframe"
|
| 148 |
+
|
| 149 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 150 |
+
|
| 151 |
+
def load_keyframe(self,
|
| 152 |
+
start_percent: float, end_percent: float,
|
| 153 |
+
float_strengths: float,
|
| 154 |
+
cn_weights: ControlWeights=None,
|
| 155 |
+
latent_keyframe: LatentKeyframeGroup=None,
|
| 156 |
+
prev_timestep_kf: TimestepKeyframeGroup=None,
|
| 157 |
+
null_latent_kf_strength: float=0.0,
|
| 158 |
+
inherit_missing=True,
|
| 159 |
+
guarantee_steps=1,
|
| 160 |
+
mask_optional=None, print_keyframes=False):
|
| 161 |
+
if not prev_timestep_kf:
|
| 162 |
+
prev_timestep_kf = TimestepKeyframeGroup()
|
| 163 |
+
else:
|
| 164 |
+
prev_timestep_kf = prev_timestep_kf.clone()
|
| 165 |
+
|
| 166 |
+
if type(float_strengths) in (float, int):
|
| 167 |
+
float_strengths = [float(float_strengths)]
|
| 168 |
+
elif isinstance(float_strengths, Iterable):
|
| 169 |
+
pass
|
| 170 |
+
else:
|
| 171 |
+
raise Exception(f"strengths_float must be either an iterable input or a float, but was {type(float_strengths).__repr__}.")
|
| 172 |
+
percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=len(float_strengths), method=SI.LINEAR)
|
| 173 |
+
|
| 174 |
+
is_first = True
|
| 175 |
+
for percent, strength in zip(percents, float_strengths):
|
| 176 |
+
guarantee_steps = 0
|
| 177 |
+
if is_first:
|
| 178 |
+
guarantee_steps = 1
|
| 179 |
+
is_first = False
|
| 180 |
+
prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
|
| 181 |
+
control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
|
| 182 |
+
guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
|
| 183 |
+
if print_keyframes:
|
| 184 |
+
logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
|
| 185 |
+
return (prev_timestep_kf,)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class LatentKeyframeNode:
|
| 189 |
+
@classmethod
|
| 190 |
+
def INPUT_TYPES(s):
|
| 191 |
+
return {
|
| 192 |
+
"required": {
|
| 193 |
+
"batch_index": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
| 194 |
+
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 195 |
+
},
|
| 196 |
+
"optional": {
|
| 197 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
RETURN_NAMES = ("LATENT_KF", )
|
| 202 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
| 203 |
+
FUNCTION = "load_keyframe"
|
| 204 |
+
|
| 205 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 206 |
+
|
| 207 |
+
def load_keyframe(self,
|
| 208 |
+
batch_index: int,
|
| 209 |
+
strength: float,
|
| 210 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
| 211 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
| 212 |
+
):
|
| 213 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
| 214 |
+
if not prev_latent_keyframe:
|
| 215 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
| 216 |
+
else:
|
| 217 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
| 218 |
+
keyframe = LatentKeyframe(batch_index, strength)
|
| 219 |
+
prev_latent_keyframe.add(keyframe)
|
| 220 |
+
return (prev_latent_keyframe,)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class LatentKeyframeGroupNode:
|
| 224 |
+
@classmethod
|
| 225 |
+
def INPUT_TYPES(s):
|
| 226 |
+
return {
|
| 227 |
+
"required": {
|
| 228 |
+
"index_strengths": ("STRING", {"multiline": True, "default": ""}),
|
| 229 |
+
},
|
| 230 |
+
"optional": {
|
| 231 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
| 232 |
+
"latent_optional": ("LATENT", ),
|
| 233 |
+
"print_keyframes": ("BOOLEAN", {"default": False})
|
| 234 |
+
}
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
RETURN_NAMES = ("LATENT_KF", )
|
| 238 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
| 239 |
+
FUNCTION = "load_keyframes"
|
| 240 |
+
|
| 241 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 242 |
+
|
| 243 |
+
def validate_index(self, index: int, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
|
| 244 |
+
# if part of range, do nothing
|
| 245 |
+
if is_range:
|
| 246 |
+
return index
|
| 247 |
+
# otherwise, validate index
|
| 248 |
+
# validate not out of range - only when latent_count is passed in
|
| 249 |
+
if latent_count > 0 and index > latent_count-1:
|
| 250 |
+
raise IndexError(f"Index '{index}' out of range for the total {latent_count} latents.")
|
| 251 |
+
# if negative, validate not out of range
|
| 252 |
+
if index < 0:
|
| 253 |
+
if not allow_negative:
|
| 254 |
+
raise IndexError(f"Negative indeces not allowed, but was {index}.")
|
| 255 |
+
conv_index = latent_count+index
|
| 256 |
+
if conv_index < 0:
|
| 257 |
+
raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for the total {latent_count} latents.")
|
| 258 |
+
index = conv_index
|
| 259 |
+
return index
|
| 260 |
+
|
| 261 |
+
def convert_to_index_int(self, raw_index: str, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
|
| 262 |
+
try:
|
| 263 |
+
return self.validate_index(int(raw_index), latent_count=latent_count, is_range=is_range, allow_negative=allow_negative)
|
| 264 |
+
except ValueError as e:
|
| 265 |
+
raise ValueError(f"index '{raw_index}' must be an integer.", e)
|
| 266 |
+
|
| 267 |
+
def convert_to_latent_keyframes(self, latent_indeces: str, latent_count: int) -> set[LatentKeyframe]:
|
| 268 |
+
if not latent_indeces:
|
| 269 |
+
return set()
|
| 270 |
+
int_latent_indeces = [i for i in range(0, latent_count)]
|
| 271 |
+
allow_negative = latent_count > 0
|
| 272 |
+
chosen_indeces = set()
|
| 273 |
+
# parse string - allow positive ints, negative ints, and ranges separated by ':'
|
| 274 |
+
groups = latent_indeces.split(",")
|
| 275 |
+
groups = [g.strip() for g in groups]
|
| 276 |
+
for g in groups:
|
| 277 |
+
# parse strengths - default to 1.0 if no strength given
|
| 278 |
+
strength = 1.0
|
| 279 |
+
if '=' in g:
|
| 280 |
+
g, strength_str = g.split("=", 1)
|
| 281 |
+
g = g.strip()
|
| 282 |
+
try:
|
| 283 |
+
strength = float(strength_str.strip())
|
| 284 |
+
except ValueError as e:
|
| 285 |
+
raise ValueError(f"strength '{strength_str}' must be a float.", e)
|
| 286 |
+
if strength < 0:
|
| 287 |
+
raise ValueError(f"Strength '{strength}' cannot be negative.")
|
| 288 |
+
# parse range of indeces (e.g. 2:16)
|
| 289 |
+
if ':' in g:
|
| 290 |
+
index_range = g.split(":", 1)
|
| 291 |
+
index_range = [r.strip() for r in index_range]
|
| 292 |
+
start_index = self.convert_to_index_int(index_range[0], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
|
| 293 |
+
end_index = self.convert_to_index_int(index_range[1], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
|
| 294 |
+
# if latents were passed in, base indeces on known latent count
|
| 295 |
+
if len(int_latent_indeces) > 0:
|
| 296 |
+
for i in int_latent_indeces[start_index:end_index]:
|
| 297 |
+
chosen_indeces.add(LatentKeyframe(i, strength))
|
| 298 |
+
# otherwise, assume indeces are valid
|
| 299 |
+
else:
|
| 300 |
+
for i in range(start_index, end_index):
|
| 301 |
+
chosen_indeces.add(LatentKeyframe(i, strength))
|
| 302 |
+
# parse individual indeces
|
| 303 |
+
else:
|
| 304 |
+
chosen_indeces.add(LatentKeyframe(self.convert_to_index_int(g, latent_count=latent_count, allow_negative=allow_negative), strength))
|
| 305 |
+
return chosen_indeces
|
| 306 |
+
|
| 307 |
+
def load_keyframes(self,
|
| 308 |
+
index_strengths: str,
|
| 309 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
| 310 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
| 311 |
+
latent_image_opt=None,
|
| 312 |
+
print_keyframes=False):
|
| 313 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
| 314 |
+
if not prev_latent_keyframe:
|
| 315 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
| 316 |
+
else:
|
| 317 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
| 318 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
| 319 |
+
|
| 320 |
+
latent_count = -1
|
| 321 |
+
if latent_image_opt:
|
| 322 |
+
latent_count = latent_image_opt['samples'].size()[0]
|
| 323 |
+
latent_keyframes = self.convert_to_latent_keyframes(index_strengths, latent_count=latent_count)
|
| 324 |
+
|
| 325 |
+
for latent_keyframe in latent_keyframes:
|
| 326 |
+
curr_latent_keyframe.add(latent_keyframe)
|
| 327 |
+
|
| 328 |
+
if print_keyframes:
|
| 329 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
| 330 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
| 331 |
+
|
| 332 |
+
# replace values with prev_latent_keyframes
|
| 333 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
| 334 |
+
curr_latent_keyframe.add(latent_keyframe)
|
| 335 |
+
|
| 336 |
+
return (curr_latent_keyframe,)
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class LatentKeyframeInterpolationNode:
|
| 340 |
+
@classmethod
|
| 341 |
+
def INPUT_TYPES(s):
|
| 342 |
+
return {
|
| 343 |
+
"required": {
|
| 344 |
+
"batch_index_from": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
| 345 |
+
"batch_index_to_excl": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
|
| 346 |
+
"strength_from": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 347 |
+
"strength_to": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 348 |
+
"interpolation": (SI._LIST, ),
|
| 349 |
+
},
|
| 350 |
+
"optional": {
|
| 351 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
| 352 |
+
"print_keyframes": ("BOOLEAN", {"default": False})
|
| 353 |
+
}
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
RETURN_NAMES = ("LATENT_KF", )
|
| 357 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
| 358 |
+
FUNCTION = "load_keyframe"
|
| 359 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 360 |
+
|
| 361 |
+
def load_keyframe(self,
|
| 362 |
+
batch_index_from: int,
|
| 363 |
+
strength_from: float,
|
| 364 |
+
batch_index_to_excl: int,
|
| 365 |
+
strength_to: float,
|
| 366 |
+
interpolation: str,
|
| 367 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
| 368 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
| 369 |
+
print_keyframes=False):
|
| 370 |
+
|
| 371 |
+
if (batch_index_from > batch_index_to_excl):
|
| 372 |
+
raise ValueError("batch_index_from must be less than or equal to batch_index_to.")
|
| 373 |
+
|
| 374 |
+
if (batch_index_from < 0 and batch_index_to_excl >= 0):
|
| 375 |
+
raise ValueError("batch_index_from and batch_index_to must be either both positive or both negative.")
|
| 376 |
+
|
| 377 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
| 378 |
+
if not prev_latent_keyframe:
|
| 379 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
| 380 |
+
else:
|
| 381 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
| 382 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
| 383 |
+
|
| 384 |
+
steps = batch_index_to_excl - batch_index_from
|
| 385 |
+
diff = strength_to - strength_from
|
| 386 |
+
if interpolation == SI.LINEAR:
|
| 387 |
+
weights = np.linspace(strength_from, strength_to, steps)
|
| 388 |
+
elif interpolation == SI.EASE_IN:
|
| 389 |
+
index = np.linspace(0, 1, steps)
|
| 390 |
+
weights = diff * np.power(index, 2) + strength_from
|
| 391 |
+
elif interpolation == SI.EASE_OUT:
|
| 392 |
+
index = np.linspace(0, 1, steps)
|
| 393 |
+
weights = diff * (1 - np.power(1 - index, 2)) + strength_from
|
| 394 |
+
elif interpolation == SI.EASE_IN_OUT:
|
| 395 |
+
index = np.linspace(0, 1, steps)
|
| 396 |
+
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from
|
| 397 |
+
|
| 398 |
+
for i in range(steps):
|
| 399 |
+
keyframe = LatentKeyframe(batch_index_from + i, float(weights[i]))
|
| 400 |
+
curr_latent_keyframe.add(keyframe)
|
| 401 |
+
|
| 402 |
+
if print_keyframes:
|
| 403 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
| 404 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
| 405 |
+
|
| 406 |
+
# replace values with prev_latent_keyframes
|
| 407 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
| 408 |
+
curr_latent_keyframe.add(latent_keyframe)
|
| 409 |
+
|
| 410 |
+
return (curr_latent_keyframe,)
|
| 411 |
+
|
| 412 |
+
|
| 413 |
+
class LatentKeyframeBatchedGroupNode:
|
| 414 |
+
@classmethod
|
| 415 |
+
def INPUT_TYPES(s):
|
| 416 |
+
return {
|
| 417 |
+
"required": {
|
| 418 |
+
"float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
|
| 419 |
+
},
|
| 420 |
+
"optional": {
|
| 421 |
+
"prev_latent_kf": ("LATENT_KEYFRAME", ),
|
| 422 |
+
"print_keyframes": ("BOOLEAN", {"default": False})
|
| 423 |
+
}
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
RETURN_NAMES = ("LATENT_KF", )
|
| 427 |
+
RETURN_TYPES = ("LATENT_KEYFRAME", )
|
| 428 |
+
FUNCTION = "load_keyframe"
|
| 429 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
|
| 430 |
+
|
| 431 |
+
def load_keyframe(self, float_strengths: Union[float, list[float]],
|
| 432 |
+
prev_latent_kf: LatentKeyframeGroup=None,
|
| 433 |
+
prev_latent_keyframe: LatentKeyframeGroup=None, # old name
|
| 434 |
+
print_keyframes=False):
|
| 435 |
+
prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
|
| 436 |
+
if not prev_latent_keyframe:
|
| 437 |
+
prev_latent_keyframe = LatentKeyframeGroup()
|
| 438 |
+
else:
|
| 439 |
+
prev_latent_keyframe = prev_latent_keyframe.clone()
|
| 440 |
+
curr_latent_keyframe = LatentKeyframeGroup()
|
| 441 |
+
|
| 442 |
+
# if received a normal float input, do nothing
|
| 443 |
+
if type(float_strengths) in (float, int):
|
| 444 |
+
logger.info("No batched float_strengths passed into Latent Keyframe Batch Group node; will not create any new keyframes.")
|
| 445 |
+
# if iterable, attempt to create LatentKeyframes with chosen strengths
|
| 446 |
+
elif isinstance(float_strengths, Iterable):
|
| 447 |
+
for idx, strength in enumerate(float_strengths):
|
| 448 |
+
keyframe = LatentKeyframe(idx, strength)
|
| 449 |
+
curr_latent_keyframe.add(keyframe)
|
| 450 |
+
else:
|
| 451 |
+
raise ValueError(f"Expected strengths to be an iterable input, but was {type(float_strengths).__repr__}.")
|
| 452 |
+
|
| 453 |
+
if print_keyframes:
|
| 454 |
+
for keyframe in curr_latent_keyframe.keyframes:
|
| 455 |
+
logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
|
| 456 |
+
|
| 457 |
+
# replace values with prev_latent_keyframes
|
| 458 |
+
for latent_keyframe in prev_latent_keyframe.keyframes:
|
| 459 |
+
curr_latent_keyframe.add(latent_keyframe)
|
| 460 |
+
|
| 461 |
+
return (curr_latent_keyframe,)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import folder_paths
|
| 2 |
+
import comfy.utils
|
| 3 |
+
import comfy.model_detection
|
| 4 |
+
import comfy.model_management
|
| 5 |
+
import comfy.lora
|
| 6 |
+
from comfy.model_patcher import ModelPatcher
|
| 7 |
+
|
| 8 |
+
from .utils import TimestepKeyframeGroup
|
| 9 |
+
from .control import ControlNetAdvanced, load_controlnet
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def convert_cn_lora_from_diffusers(cn_model: ModelPatcher, lora_path: str):
|
| 15 |
+
lora_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
|
| 16 |
+
unet_dtype = comfy.model_management.unet_dtype()
|
| 17 |
+
for key, value in lora_data.items():
|
| 18 |
+
lora_data[key] = value.to(unet_dtype)
|
| 19 |
+
diffusers_keys = comfy.utils.unet_to_diffusers(cn_model.model.state_dict())
|
| 20 |
+
|
| 21 |
+
#lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, dtype=unet_dtype)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
#key_map = comfy.lora.model_lora_keys_unet(cn_model.model, key_map)
|
| 26 |
+
lora_data = comfy.lora.load_lora(lora_data, to_load=diffusers_keys)
|
| 27 |
+
|
| 28 |
+
# TODO: detect if diffusers for sure? not sure if needed at this time, since cn loras are
|
| 29 |
+
# only used currently for LOOSEControl, and those are all in diffusers format
|
| 30 |
+
#unet_dtype = comfy.model_management.unet_dtype()
|
| 31 |
+
#lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, unet_dtype)
|
| 32 |
+
return lora_data
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class ControlNetLoaderWithLoraAdvanced:
|
| 36 |
+
@classmethod
|
| 37 |
+
def INPUT_TYPES(s):
|
| 38 |
+
return {
|
| 39 |
+
"required": {
|
| 40 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 41 |
+
"cn_lora_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 42 |
+
"cn_lora_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
|
| 43 |
+
},
|
| 44 |
+
"optional": {
|
| 45 |
+
"timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
|
| 46 |
+
}
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 50 |
+
FUNCTION = "load_controlnet"
|
| 51 |
+
|
| 52 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/LOOSEControl"
|
| 53 |
+
|
| 54 |
+
def load_controlnet(self, control_net_name, cn_lora_name, cn_lora_strength: float,
|
| 55 |
+
timestep_keyframe: TimestepKeyframeGroup=None
|
| 56 |
+
):
|
| 57 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
| 58 |
+
controlnet: ControlNetAdvanced = load_controlnet(controlnet_path, timestep_keyframe)
|
| 59 |
+
if not isinstance(controlnet, ControlNetAdvanced):
|
| 60 |
+
raise ValueError("Type {} is not compatible with CN LoRA features at this time.")
|
| 61 |
+
# now, try to load CN LoRA
|
| 62 |
+
lora_path = folder_paths.get_full_path("controlnet", cn_lora_name)
|
| 63 |
+
lora_data = convert_cn_lora_from_diffusers(cn_model=controlnet.control_model_wrapped, lora_path=lora_path)
|
| 64 |
+
# apply patches to wrapped control_model
|
| 65 |
+
controlnet.control_model_wrapped.add_patches(lora_data, strength_patch=cn_lora_strength)
|
| 66 |
+
# all done
|
| 67 |
+
return (controlnet,)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
|
| 3 |
+
from nodes import VAEEncode
|
| 4 |
+
import comfy.utils
|
| 5 |
+
from comfy.sd import VAE
|
| 6 |
+
|
| 7 |
+
from .control_reference import ReferenceAdvanced, ReferenceOptions, ReferenceType, ReferencePreprocWrapper
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# node for ReferenceCN
|
| 11 |
+
class ReferenceControlNetNode:
|
| 12 |
+
@classmethod
|
| 13 |
+
def INPUT_TYPES(s):
|
| 14 |
+
return {
|
| 15 |
+
"required": {
|
| 16 |
+
"reference_type": (ReferenceType._LIST,),
|
| 17 |
+
"style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 18 |
+
"ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 19 |
+
},
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 23 |
+
FUNCTION = "load_controlnet"
|
| 24 |
+
|
| 25 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
|
| 26 |
+
|
| 27 |
+
def load_controlnet(self, reference_type: str, style_fidelity: float, ref_weight: float):
|
| 28 |
+
ref_opts = ReferenceOptions.create_combo(reference_type=reference_type, style_fidelity=style_fidelity, ref_weight=ref_weight)
|
| 29 |
+
controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
|
| 30 |
+
return (controlnet,)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ReferenceControlFinetune:
|
| 34 |
+
@classmethod
|
| 35 |
+
def INPUT_TYPES(s):
|
| 36 |
+
return {
|
| 37 |
+
"required": {
|
| 38 |
+
"attn_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 39 |
+
"attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 40 |
+
"attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 41 |
+
"adain_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 42 |
+
"adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 43 |
+
"adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
|
| 44 |
+
},
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 48 |
+
FUNCTION = "load_controlnet"
|
| 49 |
+
|
| 50 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
|
| 51 |
+
|
| 52 |
+
def load_controlnet(self,
|
| 53 |
+
attn_style_fidelity: float, attn_ref_weight: float, attn_strength: float,
|
| 54 |
+
adain_style_fidelity: float, adain_ref_weight: float, adain_strength: float):
|
| 55 |
+
ref_opts = ReferenceOptions(reference_type=ReferenceType.ATTN_ADAIN,
|
| 56 |
+
attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength,
|
| 57 |
+
adain_style_fidelity=adain_style_fidelity, adain_ref_weight=adain_ref_weight, adain_strength=adain_strength)
|
| 58 |
+
controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
|
| 59 |
+
return (controlnet,)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ReferencePreprocessorNode:
|
| 63 |
+
@classmethod
|
| 64 |
+
def INPUT_TYPES(s):
|
| 65 |
+
return {
|
| 66 |
+
"required": {
|
| 67 |
+
"image": ("IMAGE", ),
|
| 68 |
+
"vae": ("VAE", ),
|
| 69 |
+
"latent_size": ("LATENT", ),
|
| 70 |
+
}
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
RETURN_TYPES = ("IMAGE",)
|
| 74 |
+
RETURN_NAMES = ("proc_IMAGE",)
|
| 75 |
+
FUNCTION = "preprocess_images"
|
| 76 |
+
|
| 77 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference/preprocess"
|
| 78 |
+
|
| 79 |
+
def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
|
| 80 |
+
# first, resize image to match latents
|
| 81 |
+
image = image.movedim(-1,1)
|
| 82 |
+
image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
|
| 83 |
+
image = image.movedim(1,-1)
|
| 84 |
+
# then, vae encode
|
| 85 |
+
try:
|
| 86 |
+
image = vae.vae_encode_crop_pixels(image)
|
| 87 |
+
except Exception:
|
| 88 |
+
image = VAEEncode.vae_encode_crop_pixels(image)
|
| 89 |
+
encoded = vae.encode(image[:,:,:,:3])
|
| 90 |
+
return (ReferencePreprocWrapper(condhint=encoded),)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
|
| 3 |
+
import folder_paths
|
| 4 |
+
from nodes import VAEEncode
|
| 5 |
+
import comfy.utils
|
| 6 |
+
from comfy.sd import VAE
|
| 7 |
+
|
| 8 |
+
from .utils import TimestepKeyframeGroup
|
| 9 |
+
from .control_sparsectrl import SparseMethod, SparseIndexMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
|
| 10 |
+
from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# node for SparseCtrl loading
|
| 14 |
+
class SparseCtrlLoaderAdvanced:
|
| 15 |
+
@classmethod
|
| 16 |
+
def INPUT_TYPES(s):
|
| 17 |
+
return {
|
| 18 |
+
"required": {
|
| 19 |
+
"sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 20 |
+
"use_motion": ("BOOLEAN", {"default": True}, ),
|
| 21 |
+
"motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 22 |
+
"motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 23 |
+
},
|
| 24 |
+
"optional": {
|
| 25 |
+
"sparse_method": ("SPARSE_METHOD", ),
|
| 26 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
| 27 |
+
}
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 31 |
+
FUNCTION = "load_controlnet"
|
| 32 |
+
|
| 33 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
| 34 |
+
|
| 35 |
+
def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
|
| 36 |
+
sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
|
| 37 |
+
sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale)
|
| 38 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
|
| 39 |
+
return (sparsectrl,)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class SparseCtrlMergedLoaderAdvanced:
|
| 43 |
+
@classmethod
|
| 44 |
+
def INPUT_TYPES(s):
|
| 45 |
+
return {
|
| 46 |
+
"required": {
|
| 47 |
+
"sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 48 |
+
"control_net_name": (folder_paths.get_filename_list("controlnet"), ),
|
| 49 |
+
"use_motion": ("BOOLEAN", {"default": True}, ),
|
| 50 |
+
"motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 51 |
+
"motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 52 |
+
},
|
| 53 |
+
"optional": {
|
| 54 |
+
"sparse_method": ("SPARSE_METHOD", ),
|
| 55 |
+
"tk_optional": ("TIMESTEP_KEYFRAME", ),
|
| 56 |
+
}
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
RETURN_TYPES = ("CONTROL_NET", )
|
| 60 |
+
FUNCTION = "load_controlnet"
|
| 61 |
+
|
| 62 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/experimental"
|
| 63 |
+
|
| 64 |
+
def load_controlnet(self, sparsectrl_name: str, control_net_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
|
| 65 |
+
sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
|
| 66 |
+
controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
|
| 67 |
+
sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale, merged=True)
|
| 68 |
+
# first, load normal controlnet
|
| 69 |
+
controlnet = load_controlnet(controlnet_path, timestep_keyframe=tk_optional)
|
| 70 |
+
# confirm that controlnet is ControlNetAdvanced
|
| 71 |
+
if controlnet is None or type(controlnet) != ControlNetAdvanced:
|
| 72 |
+
raise ValueError(f"controlnet_path must point to a normal ControlNet, but instead: {type(controlnet).__name__}")
|
| 73 |
+
# next, load sparsectrl, making sure to load motion portion
|
| 74 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=SparseSettings.default())
|
| 75 |
+
# now, combine state dicts
|
| 76 |
+
new_state_dict = controlnet.control_model.state_dict()
|
| 77 |
+
for key, value in sparsectrl.control_model.motion_holder.motion_wrapper.state_dict().items():
|
| 78 |
+
new_state_dict[key] = value
|
| 79 |
+
# now, reload sparsectrl with real settings
|
| 80 |
+
sparsectrl = load_sparsectrl(sparsectrl_path, controlnet_data=new_state_dict, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
|
| 81 |
+
return (sparsectrl,)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class SparseIndexMethodNode:
|
| 85 |
+
@classmethod
|
| 86 |
+
def INPUT_TYPES(s):
|
| 87 |
+
return {
|
| 88 |
+
"required": {
|
| 89 |
+
"indexes": ("STRING", {"default": "0"}),
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
RETURN_TYPES = ("SPARSE_METHOD",)
|
| 94 |
+
FUNCTION = "get_method"
|
| 95 |
+
|
| 96 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
| 97 |
+
|
| 98 |
+
def get_method(self, indexes: str):
|
| 99 |
+
idxs = []
|
| 100 |
+
unique_idxs = set()
|
| 101 |
+
# get indeces from string
|
| 102 |
+
str_idxs = [x.strip() for x in indexes.strip().split(",")]
|
| 103 |
+
for str_idx in str_idxs:
|
| 104 |
+
try:
|
| 105 |
+
idx = int(str_idx)
|
| 106 |
+
if idx in unique_idxs:
|
| 107 |
+
raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
|
| 108 |
+
idxs.append(idx)
|
| 109 |
+
unique_idxs.add(idx)
|
| 110 |
+
except ValueError:
|
| 111 |
+
raise ValueError(f"'{str_idx}' is not a valid integer index.")
|
| 112 |
+
if len(idxs) == 0:
|
| 113 |
+
raise ValueError(f"No indexes were listed in Sparse Index Method.")
|
| 114 |
+
return (SparseIndexMethod(idxs),)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class SparseSpreadMethodNode:
|
| 118 |
+
@classmethod
|
| 119 |
+
def INPUT_TYPES(s):
|
| 120 |
+
return {
|
| 121 |
+
"required": {
|
| 122 |
+
"spread": (SparseSpreadMethod.LIST,),
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
RETURN_TYPES = ("SPARSE_METHOD",)
|
| 127 |
+
FUNCTION = "get_method"
|
| 128 |
+
|
| 129 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
|
| 130 |
+
|
| 131 |
+
def get_method(self, spread: str):
|
| 132 |
+
return (SparseSpreadMethod(spread=spread),)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
class RgbSparseCtrlPreprocessor:
|
| 136 |
+
@classmethod
|
| 137 |
+
def INPUT_TYPES(s):
|
| 138 |
+
return {
|
| 139 |
+
"required": {
|
| 140 |
+
"image": ("IMAGE", ),
|
| 141 |
+
"vae": ("VAE", ),
|
| 142 |
+
"latent_size": ("LATENT", ),
|
| 143 |
+
}
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
RETURN_TYPES = ("IMAGE",)
|
| 147 |
+
RETURN_NAMES = ("proc_IMAGE",)
|
| 148 |
+
FUNCTION = "preprocess_images"
|
| 149 |
+
|
| 150 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/preprocess"
|
| 151 |
+
|
| 152 |
+
def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
|
| 153 |
+
# first, resize image to match latents
|
| 154 |
+
image = image.movedim(-1,1)
|
| 155 |
+
image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
|
| 156 |
+
image = image.movedim(1,-1)
|
| 157 |
+
# then, vae encode
|
| 158 |
+
try:
|
| 159 |
+
image = vae.vae_encode_crop_pixels(image)
|
| 160 |
+
except Exception:
|
| 161 |
+
image = VAEEncode.vae_encode_crop_pixels(image)
|
| 162 |
+
encoded = vae.encode(image[:,:,:,:3])
|
| 163 |
+
return (PreprocSparseRGBWrapper(condhint=encoded),)
|
ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor
|
| 2 |
+
import torch
|
| 3 |
+
from .utils import TimestepKeyframe, TimestepKeyframeGroup, ControlWeights, get_properly_arranged_t2i_weights, linear_conversion
|
| 4 |
+
from .logger import logger
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
WEIGHTS_RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DefaultWeights:
|
| 11 |
+
@classmethod
|
| 12 |
+
def INPUT_TYPES(s):
|
| 13 |
+
return {
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 17 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 18 |
+
FUNCTION = "load_weights"
|
| 19 |
+
|
| 20 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
| 21 |
+
|
| 22 |
+
def load_weights(self):
|
| 23 |
+
weights = ControlWeights.default()
|
| 24 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class ScaledSoftMaskedUniversalWeights:
|
| 28 |
+
@classmethod
|
| 29 |
+
def INPUT_TYPES(s):
|
| 30 |
+
return {
|
| 31 |
+
"required": {
|
| 32 |
+
"mask": ("MASK", ),
|
| 33 |
+
"min_base_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
| 34 |
+
"max_base_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
| 35 |
+
#"lock_min": ("BOOLEAN", {"default": False}, ),
|
| 36 |
+
#"lock_max": ("BOOLEAN", {"default": False}, ),
|
| 37 |
+
},
|
| 38 |
+
"optional": {
|
| 39 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 44 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 45 |
+
FUNCTION = "load_weights"
|
| 46 |
+
|
| 47 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
| 48 |
+
|
| 49 |
+
def load_weights(self, mask: Tensor, min_base_multiplier: float, max_base_multiplier: float, lock_min=False, lock_max=False,
|
| 50 |
+
uncond_multiplier: float=1.0):
|
| 51 |
+
# normalize mask
|
| 52 |
+
mask = mask.clone()
|
| 53 |
+
x_min = 0.0 if lock_min else mask.min()
|
| 54 |
+
x_max = 1.0 if lock_max else mask.max()
|
| 55 |
+
if x_min == x_max:
|
| 56 |
+
mask = torch.ones_like(mask) * max_base_multiplier
|
| 57 |
+
else:
|
| 58 |
+
mask = linear_conversion(mask, x_min, x_max, min_base_multiplier, max_base_multiplier)
|
| 59 |
+
weights = ControlWeights.universal_mask(weight_mask=mask, uncond_multiplier=uncond_multiplier)
|
| 60 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class ScaledSoftUniversalWeights:
|
| 64 |
+
@classmethod
|
| 65 |
+
def INPUT_TYPES(s):
|
| 66 |
+
return {
|
| 67 |
+
"required": {
|
| 68 |
+
"base_multiplier": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 1.0, "step": 0.001}, ),
|
| 69 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
| 70 |
+
},
|
| 71 |
+
"optional": {
|
| 72 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 73 |
+
}
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 77 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 78 |
+
FUNCTION = "load_weights"
|
| 79 |
+
|
| 80 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
|
| 81 |
+
|
| 82 |
+
def load_weights(self, base_multiplier, flip_weights, uncond_multiplier: float=1.0):
|
| 83 |
+
weights = ControlWeights.universal(base_multiplier=base_multiplier, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 84 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class SoftControlNetWeights:
|
| 88 |
+
@classmethod
|
| 89 |
+
def INPUT_TYPES(s):
|
| 90 |
+
return {
|
| 91 |
+
"required": {
|
| 92 |
+
"weight_00": ("FLOAT", {"default": 0.09941396206337118, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 93 |
+
"weight_01": ("FLOAT", {"default": 0.12050177219802567, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 94 |
+
"weight_02": ("FLOAT", {"default": 0.14606275417942507, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 95 |
+
"weight_03": ("FLOAT", {"default": 0.17704576264172736, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 96 |
+
"weight_04": ("FLOAT", {"default": 0.214600924414215, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 97 |
+
"weight_05": ("FLOAT", {"default": 0.26012233262329093, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 98 |
+
"weight_06": ("FLOAT", {"default": 0.3152997971191405, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 99 |
+
"weight_07": ("FLOAT", {"default": 0.3821815722656249, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 100 |
+
"weight_08": ("FLOAT", {"default": 0.4632503906249999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 101 |
+
"weight_09": ("FLOAT", {"default": 0.561515625, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 102 |
+
"weight_10": ("FLOAT", {"default": 0.6806249999999999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 103 |
+
"weight_11": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 104 |
+
"weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 105 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
| 106 |
+
},
|
| 107 |
+
"optional": {
|
| 108 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 113 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 114 |
+
FUNCTION = "load_weights"
|
| 115 |
+
|
| 116 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
|
| 117 |
+
|
| 118 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
| 119 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
|
| 120 |
+
uncond_multiplier: float=1.0):
|
| 121 |
+
weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
| 122 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12]
|
| 123 |
+
weights = ControlWeights.controlnet(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 124 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
class CustomControlNetWeights:
|
| 128 |
+
@classmethod
|
| 129 |
+
def INPUT_TYPES(s):
|
| 130 |
+
return {
|
| 131 |
+
"required": {
|
| 132 |
+
"weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 133 |
+
"weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 134 |
+
"weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 135 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 136 |
+
"weight_04": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 137 |
+
"weight_05": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 138 |
+
"weight_06": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 139 |
+
"weight_07": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 140 |
+
"weight_08": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 141 |
+
"weight_09": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 142 |
+
"weight_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 143 |
+
"weight_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 144 |
+
"weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 145 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
| 146 |
+
},
|
| 147 |
+
"optional": {
|
| 148 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 153 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 154 |
+
FUNCTION = "load_weights"
|
| 155 |
+
|
| 156 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
|
| 157 |
+
|
| 158 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
| 159 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
|
| 160 |
+
uncond_multiplier: float=1.0):
|
| 161 |
+
weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
|
| 162 |
+
weight_07, weight_08, weight_09, weight_10, weight_11, weight_12]
|
| 163 |
+
weights = ControlWeights.controlnet(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 164 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
class SoftT2IAdapterWeights:
|
| 168 |
+
@classmethod
|
| 169 |
+
def INPUT_TYPES(s):
|
| 170 |
+
return {
|
| 171 |
+
"required": {
|
| 172 |
+
"weight_00": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 173 |
+
"weight_01": ("FLOAT", {"default": 0.62, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 174 |
+
"weight_02": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 175 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 176 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
| 177 |
+
},
|
| 178 |
+
"optional": {
|
| 179 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 180 |
+
}
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 184 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 185 |
+
FUNCTION = "load_weights"
|
| 186 |
+
|
| 187 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
|
| 188 |
+
|
| 189 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
|
| 190 |
+
uncond_multiplier: float=1.0):
|
| 191 |
+
weights = [weight_00, weight_01, weight_02, weight_03]
|
| 192 |
+
weights = get_properly_arranged_t2i_weights(weights)
|
| 193 |
+
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 194 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
class CustomT2IAdapterWeights:
|
| 198 |
+
@classmethod
|
| 199 |
+
def INPUT_TYPES(s):
|
| 200 |
+
return {
|
| 201 |
+
"required": {
|
| 202 |
+
"weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 203 |
+
"weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 204 |
+
"weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 205 |
+
"weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
|
| 206 |
+
"flip_weights": ("BOOLEAN", {"default": False}),
|
| 207 |
+
},
|
| 208 |
+
"optional": {
|
| 209 |
+
"uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
|
| 214 |
+
RETURN_NAMES = WEIGHTS_RETURN_NAMES
|
| 215 |
+
FUNCTION = "load_weights"
|
| 216 |
+
|
| 217 |
+
CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
|
| 218 |
+
|
| 219 |
+
def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
|
| 220 |
+
uncond_multiplier: float=1.0):
|
| 221 |
+
weights = [weight_00, weight_01, weight_02, weight_03]
|
| 222 |
+
weights = get_properly_arranged_t2i_weights(weights)
|
| 223 |
+
weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 224 |
+
return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
|
ComfyUI-Advanced-ControlNet/adv_control/utils.py
ADDED
|
@@ -0,0 +1,915 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from copy import deepcopy
|
| 2 |
+
from typing import Callable, Union
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor
|
| 5 |
+
import torch.nn.functional
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import comfy.ops
|
| 10 |
+
import comfy.utils
|
| 11 |
+
import comfy.sample
|
| 12 |
+
import comfy.samplers
|
| 13 |
+
import comfy.model_base
|
| 14 |
+
|
| 15 |
+
from comfy.controlnet import ControlBase, broadcast_image_to
|
| 16 |
+
from comfy.model_patcher import ModelPatcher
|
| 17 |
+
|
| 18 |
+
from .logger import logger
|
| 19 |
+
|
| 20 |
+
BIGMIN = -(2**53-1)
|
| 21 |
+
BIGMAX = (2**53-1)
|
| 22 |
+
|
| 23 |
+
def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
|
| 24 |
+
def load_torch_file_with_dict(*args, **kwargs):
|
| 25 |
+
# immediately restore load_torch_file to original version
|
| 26 |
+
comfy.utils.load_torch_file = orig_load_torch_file
|
| 27 |
+
return controlnet_data
|
| 28 |
+
return load_torch_file_with_dict
|
| 29 |
+
|
| 30 |
+
# wrapping len function so that it will save the thing len is trying to get the length of;
|
| 31 |
+
# this will be assumed to be the cond_or_uncond variable;
|
| 32 |
+
# automatically restores len to original function after running
|
| 33 |
+
def wrapper_len_factory(orig_len: Callable) -> Callable:
|
| 34 |
+
def wrapper_len(*args, **kwargs):
|
| 35 |
+
cond_or_uncond = args[0]
|
| 36 |
+
real_length = orig_len(*args, **kwargs)
|
| 37 |
+
if real_length > 0 and type(cond_or_uncond) == list and (cond_or_uncond[0] in [0, 1]):
|
| 38 |
+
try:
|
| 39 |
+
to_return = IntWithCondOrUncond(real_length)
|
| 40 |
+
setattr(to_return, "cond_or_uncond", cond_or_uncond)
|
| 41 |
+
return to_return
|
| 42 |
+
finally:
|
| 43 |
+
__builtins__["len"] = orig_len
|
| 44 |
+
else:
|
| 45 |
+
return real_length
|
| 46 |
+
return wrapper_len
|
| 47 |
+
|
| 48 |
+
# wrapping cond_cat function so that it will wrap around len function to get cond_or_uncond variable value
|
| 49 |
+
# from comfy.samplers.calc_conds_batch
|
| 50 |
+
def wrapper_cond_cat_factory(orig_cond_cat: Callable):
|
| 51 |
+
def wrapper_cond_cat(*args, **kwargs):
|
| 52 |
+
__builtins__["len"] = wrapper_len_factory(__builtins__["len"])
|
| 53 |
+
return orig_cond_cat(*args, **kwargs)
|
| 54 |
+
return wrapper_cond_cat
|
| 55 |
+
orig_cond_cat = comfy.samplers.cond_cat
|
| 56 |
+
comfy.samplers.cond_cat = wrapper_cond_cat_factory(orig_cond_cat)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# wrapping apply_model so that len function will be cleaned up fairly soon after being injected
|
| 60 |
+
def apply_model_uncond_cleanup_factory(orig_apply_model, orig_len):
|
| 61 |
+
def apply_model_uncond_cleanup_wrapper(self, *args, **kwargs):
|
| 62 |
+
__builtins__["len"] = orig_len
|
| 63 |
+
return orig_apply_model(self, *args, **kwargs)
|
| 64 |
+
return apply_model_uncond_cleanup_wrapper
|
| 65 |
+
global_orig_len = __builtins__["len"]
|
| 66 |
+
orig_apply_model = comfy.model_base.BaseModel.apply_model
|
| 67 |
+
comfy.model_base.BaseModel.apply_model = apply_model_uncond_cleanup_factory(orig_apply_model, global_orig_len)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def uncond_multiplier_check_cn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
|
| 71 |
+
def contains_uncond_multiplier(control: Union[ControlBase, 'AdvancedControlBase']):
|
| 72 |
+
if control is None:
|
| 73 |
+
return False
|
| 74 |
+
if not isinstance(control, AdvancedControlBase):
|
| 75 |
+
return contains_uncond_multiplier(control.previous_controlnet)
|
| 76 |
+
# check if weights_override has an uncond_multiplier
|
| 77 |
+
if control.weights_override is not None and control.weights_override.has_uncond_multiplier:
|
| 78 |
+
return True
|
| 79 |
+
# check if any timestep_keyframes have an uncond_multiplier on their weights
|
| 80 |
+
if control.timestep_keyframes is not None:
|
| 81 |
+
for tk in control.timestep_keyframes.keyframes:
|
| 82 |
+
if tk.has_control_weights() and tk.control_weights.has_uncond_multiplier:
|
| 83 |
+
return True
|
| 84 |
+
return contains_uncond_multiplier(control.previous_controlnet)
|
| 85 |
+
|
| 86 |
+
# check if positive or negative conds contain Adv. Cns that use multiply_negative on weights
|
| 87 |
+
def uncond_multiplier_check_cn_sample(model: ModelPatcher, *args, **kwargs):
|
| 88 |
+
positive = args[-3]
|
| 89 |
+
negative = args[-2]
|
| 90 |
+
has_uncond_multiplier = False
|
| 91 |
+
if positive is not None:
|
| 92 |
+
for cond in positive:
|
| 93 |
+
if "control" in cond[1]:
|
| 94 |
+
has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
|
| 95 |
+
if has_uncond_multiplier:
|
| 96 |
+
break
|
| 97 |
+
if negative is not None and not has_uncond_multiplier:
|
| 98 |
+
for cond in negative:
|
| 99 |
+
if "control" in cond[1]:
|
| 100 |
+
has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
|
| 101 |
+
if has_uncond_multiplier:
|
| 102 |
+
break
|
| 103 |
+
try:
|
| 104 |
+
# if uncond_multiplier found, continue to use wrapped version of function
|
| 105 |
+
if has_uncond_multiplier:
|
| 106 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
| 107 |
+
# otherwise, use original version of function to prevent even the smallest of slowdowns (0.XX%)
|
| 108 |
+
try:
|
| 109 |
+
wrapped_cond_cat = comfy.samplers.cond_cat
|
| 110 |
+
comfy.samplers.cond_cat = orig_cond_cat
|
| 111 |
+
return orig_comfy_sample(model, *args, **kwargs)
|
| 112 |
+
finally:
|
| 113 |
+
comfy.samplers.cond_cat = wrapped_cond_cat
|
| 114 |
+
finally:
|
| 115 |
+
# make sure len function is unwrapped by the time sampling is done, just in case
|
| 116 |
+
__builtins__["len"] = global_orig_len
|
| 117 |
+
return uncond_multiplier_check_cn_sample
|
| 118 |
+
# inject sample functions
|
| 119 |
+
comfy.sample.sample = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample)
|
| 120 |
+
comfy.sample.sample_custom = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample_custom, is_custom=True)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class IntWithCondOrUncond(int):
|
| 124 |
+
def __new__(cls, *args, **kwargs):
|
| 125 |
+
return super(IntWithCondOrUncond, cls).__new__(cls, *args, **kwargs)
|
| 126 |
+
|
| 127 |
+
def __init__(self, *args, **kwargs):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.cond_or_uncond = None
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def get_properly_arranged_t2i_weights(initial_weights: list[float]):
|
| 134 |
+
new_weights = []
|
| 135 |
+
new_weights.extend([initial_weights[0]]*3)
|
| 136 |
+
new_weights.extend([initial_weights[1]]*3)
|
| 137 |
+
new_weights.extend([initial_weights[2]]*3)
|
| 138 |
+
new_weights.extend([initial_weights[3]]*3)
|
| 139 |
+
return new_weights
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class ControlWeightType:
|
| 143 |
+
DEFAULT = "default"
|
| 144 |
+
UNIVERSAL = "universal"
|
| 145 |
+
T2IADAPTER = "t2iadapter"
|
| 146 |
+
CONTROLNET = "controlnet"
|
| 147 |
+
CONTROLLORA = "controllora"
|
| 148 |
+
CONTROLLLLITE = "controllllite"
|
| 149 |
+
SVD_CONTROLNET = "svd_controlnet"
|
| 150 |
+
SPARSECTRL = "sparsectrl"
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class ControlWeights:
|
| 154 |
+
def __init__(self, weight_type: str, base_multiplier: float=1.0, flip_weights: bool=False, weights: list[float]=None, weight_mask: Tensor=None,
|
| 155 |
+
uncond_multiplier=1.0):
|
| 156 |
+
self.weight_type = weight_type
|
| 157 |
+
self.base_multiplier = base_multiplier
|
| 158 |
+
self.flip_weights = flip_weights
|
| 159 |
+
self.weights = weights
|
| 160 |
+
if self.weights is not None and self.flip_weights:
|
| 161 |
+
self.weights.reverse()
|
| 162 |
+
self.weight_mask = weight_mask
|
| 163 |
+
self.uncond_multiplier = float(uncond_multiplier)
|
| 164 |
+
self.has_uncond_multiplier = not math.isclose(self.uncond_multiplier, 1.0)
|
| 165 |
+
|
| 166 |
+
def get(self, idx: int, default=1.0) -> Union[float, Tensor]:
|
| 167 |
+
# if weights is not none, return index
|
| 168 |
+
if self.weights is not None:
|
| 169 |
+
# this implies weights list is not aligning with expectations - will need to adjust code
|
| 170 |
+
if idx >= len(self.weights):
|
| 171 |
+
return default
|
| 172 |
+
return self.weights[idx]
|
| 173 |
+
return 1.0
|
| 174 |
+
|
| 175 |
+
def copy_with_new_weights(self, new_weights: list[float]):
|
| 176 |
+
return ControlWeights(weight_type=self.weight_type, base_multiplier=self.base_multiplier, flip_weights=self.flip_weights,
|
| 177 |
+
weights=new_weights, weight_mask=self.weight_mask, uncond_multiplier=self.uncond_multiplier)
|
| 178 |
+
|
| 179 |
+
@classmethod
|
| 180 |
+
def default(cls):
|
| 181 |
+
return cls(ControlWeightType.DEFAULT)
|
| 182 |
+
|
| 183 |
+
@classmethod
|
| 184 |
+
def universal(cls, base_multiplier: float, flip_weights: bool=False, uncond_multiplier: float=1.0):
|
| 185 |
+
return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 186 |
+
|
| 187 |
+
@classmethod
|
| 188 |
+
def universal_mask(cls, weight_mask: Tensor, uncond_multiplier: float=1.0):
|
| 189 |
+
return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask, uncond_multiplier=uncond_multiplier)
|
| 190 |
+
|
| 191 |
+
@classmethod
|
| 192 |
+
def t2iadapter(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
|
| 193 |
+
if weights is None:
|
| 194 |
+
weights = [1.0]*12
|
| 195 |
+
return cls(ControlWeightType.T2IADAPTER, weights=weights,flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 196 |
+
|
| 197 |
+
@classmethod
|
| 198 |
+
def controlnet(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
|
| 199 |
+
if weights is None:
|
| 200 |
+
weights = [1.0]*13
|
| 201 |
+
return cls(ControlWeightType.CONTROLNET, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 202 |
+
|
| 203 |
+
@classmethod
|
| 204 |
+
def controllora(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
|
| 205 |
+
if weights is None:
|
| 206 |
+
weights = [1.0]*10
|
| 207 |
+
return cls(ControlWeightType.CONTROLLORA, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 208 |
+
|
| 209 |
+
@classmethod
|
| 210 |
+
def controllllite(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
|
| 211 |
+
if weights is None:
|
| 212 |
+
# TODO: make this have a real value
|
| 213 |
+
weights = [1.0]*200
|
| 214 |
+
return cls(ControlWeightType.CONTROLLLLITE, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
class StrengthInterpolation:
|
| 218 |
+
LINEAR = "linear"
|
| 219 |
+
EASE_IN = "ease-in"
|
| 220 |
+
EASE_OUT = "ease-out"
|
| 221 |
+
EASE_IN_OUT = "ease-in-out"
|
| 222 |
+
NONE = "none"
|
| 223 |
+
|
| 224 |
+
_LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
|
| 225 |
+
_LIST_WITH_NONE = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT, NONE]
|
| 226 |
+
|
| 227 |
+
@classmethod
|
| 228 |
+
def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
|
| 229 |
+
diff = num_to - num_from
|
| 230 |
+
if method == cls.LINEAR:
|
| 231 |
+
weights = torch.linspace(num_from, num_to, length)
|
| 232 |
+
elif method == cls.EASE_IN:
|
| 233 |
+
index = torch.linspace(0, 1, length)
|
| 234 |
+
weights = diff * np.power(index, 2) + num_from
|
| 235 |
+
elif method == cls.EASE_OUT:
|
| 236 |
+
index = torch.linspace(0, 1, length)
|
| 237 |
+
weights = diff * (1 - np.power(1 - index, 2)) + num_from
|
| 238 |
+
elif method == cls.EASE_IN_OUT:
|
| 239 |
+
index = torch.linspace(0, 1, length)
|
| 240 |
+
weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
|
| 241 |
+
else:
|
| 242 |
+
raise ValueError(f"Unrecognized interpolation method '{method}'.")
|
| 243 |
+
if reverse:
|
| 244 |
+
weights = weights.flip(dims=(0,))
|
| 245 |
+
return weights
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class LatentKeyframe:
|
| 249 |
+
def __init__(self, batch_index: int, strength: float) -> None:
|
| 250 |
+
self.batch_index = batch_index
|
| 251 |
+
self.strength = strength
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# always maintain sorted state (by batch_index of LatentKeyframe)
|
| 255 |
+
class LatentKeyframeGroup:
|
| 256 |
+
def __init__(self) -> None:
|
| 257 |
+
self.keyframes: list[LatentKeyframe] = []
|
| 258 |
+
|
| 259 |
+
def add(self, keyframe: LatentKeyframe) -> None:
|
| 260 |
+
added = False
|
| 261 |
+
# replace existing keyframe if same batch_index
|
| 262 |
+
for i in range(len(self.keyframes)):
|
| 263 |
+
if self.keyframes[i].batch_index == keyframe.batch_index:
|
| 264 |
+
self.keyframes[i] = keyframe
|
| 265 |
+
added = True
|
| 266 |
+
break
|
| 267 |
+
if not added:
|
| 268 |
+
self.keyframes.append(keyframe)
|
| 269 |
+
self.keyframes.sort(key=lambda k: k.batch_index)
|
| 270 |
+
|
| 271 |
+
def get_index(self, index: int) -> Union[LatentKeyframe, None]:
|
| 272 |
+
try:
|
| 273 |
+
return self.keyframes[index]
|
| 274 |
+
except IndexError:
|
| 275 |
+
return None
|
| 276 |
+
|
| 277 |
+
def __getitem__(self, index) -> LatentKeyframe:
|
| 278 |
+
return self.keyframes[index]
|
| 279 |
+
|
| 280 |
+
def is_empty(self) -> bool:
|
| 281 |
+
return len(self.keyframes) == 0
|
| 282 |
+
|
| 283 |
+
def clone(self) -> 'LatentKeyframeGroup':
|
| 284 |
+
cloned = LatentKeyframeGroup()
|
| 285 |
+
for tk in self.keyframes:
|
| 286 |
+
cloned.add(tk)
|
| 287 |
+
return cloned
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
class TimestepKeyframe:
|
| 291 |
+
def __init__(self,
|
| 292 |
+
start_percent: float = 0.0,
|
| 293 |
+
strength: float = 1.0,
|
| 294 |
+
control_weights: ControlWeights = None,
|
| 295 |
+
latent_keyframes: LatentKeyframeGroup = None,
|
| 296 |
+
null_latent_kf_strength: float = 0.0,
|
| 297 |
+
inherit_missing: bool = True,
|
| 298 |
+
guarantee_steps: int = 1,
|
| 299 |
+
mask_hint_orig: Tensor = None) -> None:
|
| 300 |
+
self.start_percent = float(start_percent)
|
| 301 |
+
self.start_t = 999999999.9
|
| 302 |
+
self.strength = strength
|
| 303 |
+
self.control_weights = control_weights
|
| 304 |
+
self.latent_keyframes = latent_keyframes
|
| 305 |
+
self.null_latent_kf_strength = null_latent_kf_strength
|
| 306 |
+
self.inherit_missing = inherit_missing
|
| 307 |
+
self.guarantee_steps = guarantee_steps
|
| 308 |
+
self.mask_hint_orig = mask_hint_orig
|
| 309 |
+
|
| 310 |
+
def has_control_weights(self):
|
| 311 |
+
return self.control_weights is not None
|
| 312 |
+
|
| 313 |
+
def has_latent_keyframes(self):
|
| 314 |
+
return self.latent_keyframes is not None
|
| 315 |
+
|
| 316 |
+
def has_mask_hint(self):
|
| 317 |
+
return self.mask_hint_orig is not None
|
| 318 |
+
|
| 319 |
+
|
| 320 |
+
@staticmethod
|
| 321 |
+
def default() -> 'TimestepKeyframe':
|
| 322 |
+
return TimestepKeyframe(start_percent=0.0, guarantee_steps=0)
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# always maintain sorted state (by start_percent of TimestepKeyFrame)
|
| 326 |
+
class TimestepKeyframeGroup:
|
| 327 |
+
def __init__(self) -> None:
|
| 328 |
+
self.keyframes: list[TimestepKeyframe] = []
|
| 329 |
+
self.keyframes.append(TimestepKeyframe.default())
|
| 330 |
+
|
| 331 |
+
def add(self, keyframe: TimestepKeyframe) -> None:
|
| 332 |
+
# add to end of list, then sort
|
| 333 |
+
self.keyframes.append(keyframe)
|
| 334 |
+
self.keyframes = get_sorted_list_via_attr(self.keyframes, attr="start_percent")
|
| 335 |
+
|
| 336 |
+
def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
|
| 337 |
+
try:
|
| 338 |
+
return self.keyframes[index]
|
| 339 |
+
except IndexError:
|
| 340 |
+
return None
|
| 341 |
+
|
| 342 |
+
def has_index(self, index: int) -> int:
|
| 343 |
+
return index >=0 and index < len(self.keyframes)
|
| 344 |
+
|
| 345 |
+
def __getitem__(self, index) -> TimestepKeyframe:
|
| 346 |
+
return self.keyframes[index]
|
| 347 |
+
|
| 348 |
+
def __len__(self) -> int:
|
| 349 |
+
return len(self.keyframes)
|
| 350 |
+
|
| 351 |
+
def is_empty(self) -> bool:
|
| 352 |
+
return len(self.keyframes) == 0
|
| 353 |
+
|
| 354 |
+
def clone(self) -> 'TimestepKeyframeGroup':
|
| 355 |
+
cloned = TimestepKeyframeGroup()
|
| 356 |
+
# already sorted, so don't use add function to make cloning quicker
|
| 357 |
+
for tk in self.keyframes:
|
| 358 |
+
cloned.keyframes.append(tk)
|
| 359 |
+
return cloned
|
| 360 |
+
|
| 361 |
+
@classmethod
|
| 362 |
+
def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
|
| 363 |
+
group = cls()
|
| 364 |
+
group.keyframes[0] = keyframe
|
| 365 |
+
return group
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class AbstractPreprocWrapper:
|
| 369 |
+
error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
|
| 370 |
+
def __init__(self, condhint: Tensor):
|
| 371 |
+
self.condhint = condhint
|
| 372 |
+
|
| 373 |
+
def movedim(self, *args, **kwargs):
|
| 374 |
+
return self
|
| 375 |
+
|
| 376 |
+
def __getattr__(self, *args, **kwargs):
|
| 377 |
+
raise AttributeError(self.error_msg)
|
| 378 |
+
|
| 379 |
+
def __setattr__(self, name, value):
|
| 380 |
+
if name != "condhint":
|
| 381 |
+
raise AttributeError(self.error_msg)
|
| 382 |
+
super().__setattr__(name, value)
|
| 383 |
+
|
| 384 |
+
def __iter__(self, *args, **kwargs):
|
| 385 |
+
raise AttributeError(self.error_msg)
|
| 386 |
+
|
| 387 |
+
def __next__(self, *args, **kwargs):
|
| 388 |
+
raise AttributeError(self.error_msg)
|
| 389 |
+
|
| 390 |
+
def __len__(self, *args, **kwargs):
|
| 391 |
+
raise AttributeError(self.error_msg)
|
| 392 |
+
|
| 393 |
+
def __getitem__(self, *args, **kwargs):
|
| 394 |
+
raise AttributeError(self.error_msg)
|
| 395 |
+
|
| 396 |
+
def __setitem__(self, *args, **kwargs):
|
| 397 |
+
raise AttributeError(self.error_msg)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
# depending on model, AnimateDiff may inject into GroupNorm, so make sure GroupNorm will be clean
|
| 401 |
+
class disable_weight_init_clean_groupnorm(comfy.ops.disable_weight_init):
|
| 402 |
+
class GroupNorm(comfy.ops.disable_weight_init.GroupNorm):
|
| 403 |
+
def forward_comfy_cast_weights(self, input):
|
| 404 |
+
weight, bias = comfy.ops.cast_bias_weight(self, input)
|
| 405 |
+
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
|
| 406 |
+
|
| 407 |
+
def forward(self, input):
|
| 408 |
+
if self.comfy_cast_weights:
|
| 409 |
+
return self.forward_comfy_cast_weights(input)
|
| 410 |
+
else:
|
| 411 |
+
return torch.nn.functional.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
|
| 412 |
+
|
| 413 |
+
class manual_cast_clean_groupnorm(comfy.ops.manual_cast):
|
| 414 |
+
class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
|
| 415 |
+
comfy_cast_weights = True
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
# adapted from comfy/sample.py
|
| 419 |
+
def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
|
| 420 |
+
mask = mask.clone()
|
| 421 |
+
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
|
| 422 |
+
if match_dim1:
|
| 423 |
+
mask = torch.cat([mask] * shape[1], dim=1)
|
| 424 |
+
return mask
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
# applies min-max normalization, from:
|
| 428 |
+
# https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
|
| 429 |
+
def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
|
| 430 |
+
x_min, x_max = x.min(), x.max()
|
| 431 |
+
return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
|
| 432 |
+
|
| 433 |
+
def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
|
| 434 |
+
return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def broadcast_image_to_full(tensor, target_batch_size, batched_number, except_one=True):
|
| 438 |
+
current_batch_size = tensor.shape[0]
|
| 439 |
+
#print(current_batch_size, target_batch_size)
|
| 440 |
+
if except_one and current_batch_size == 1:
|
| 441 |
+
return tensor
|
| 442 |
+
|
| 443 |
+
per_batch = target_batch_size // batched_number
|
| 444 |
+
tensor = tensor[:per_batch]
|
| 445 |
+
|
| 446 |
+
if per_batch > tensor.shape[0]:
|
| 447 |
+
tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
|
| 448 |
+
|
| 449 |
+
current_batch_size = tensor.shape[0]
|
| 450 |
+
if current_batch_size == target_batch_size:
|
| 451 |
+
return tensor
|
| 452 |
+
else:
|
| 453 |
+
return torch.cat([tensor] * batched_number, dim=0)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# from https://stackoverflow.com/a/24621200
|
| 457 |
+
def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
|
| 458 |
+
'''
|
| 459 |
+
Deepcopy an object, except for a given list of attributes, which should
|
| 460 |
+
be shared between the original object and its copy.
|
| 461 |
+
|
| 462 |
+
obj is some object
|
| 463 |
+
shared_attribute_names: A list of strings identifying the attributes that
|
| 464 |
+
should be shared between the original and its copy.
|
| 465 |
+
memo is the dictionary passed into __deepcopy__. Ignore this argument if
|
| 466 |
+
not calling from within __deepcopy__.
|
| 467 |
+
'''
|
| 468 |
+
assert isinstance(shared_attribute_names, (list, tuple))
|
| 469 |
+
|
| 470 |
+
shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}
|
| 471 |
+
|
| 472 |
+
if hasattr(obj, '__deepcopy__'):
|
| 473 |
+
# Do hack to prevent infinite recursion in call to deepcopy
|
| 474 |
+
deepcopy_method = obj.__deepcopy__
|
| 475 |
+
obj.__deepcopy__ = None
|
| 476 |
+
|
| 477 |
+
for attr in shared_attribute_names:
|
| 478 |
+
del obj.__dict__[attr]
|
| 479 |
+
|
| 480 |
+
clone = deepcopy(obj)
|
| 481 |
+
|
| 482 |
+
for attr, val in shared_attributes.items():
|
| 483 |
+
setattr(obj, attr, val)
|
| 484 |
+
setattr(clone, attr, val)
|
| 485 |
+
|
| 486 |
+
if hasattr(obj, '__deepcopy__'):
|
| 487 |
+
# Undo hack
|
| 488 |
+
obj.__deepcopy__ = deepcopy_method
|
| 489 |
+
del clone.__deepcopy__
|
| 490 |
+
|
| 491 |
+
return clone
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def get_sorted_list_via_attr(objects: list, attr: str) -> list:
|
| 495 |
+
if not objects:
|
| 496 |
+
return objects
|
| 497 |
+
elif len(objects) <= 1:
|
| 498 |
+
return [x for x in objects]
|
| 499 |
+
# now that we know we have to sort, do it following these rules:
|
| 500 |
+
# a) if objects have same value of attribute, maintain their relative order
|
| 501 |
+
# b) perform sorting of the groups of objects with same attributes
|
| 502 |
+
unique_attrs = {}
|
| 503 |
+
for o in objects:
|
| 504 |
+
val_attr = getattr(o, attr)
|
| 505 |
+
attr_list: list = unique_attrs.get(val_attr, list())
|
| 506 |
+
attr_list.append(o)
|
| 507 |
+
if val_attr not in unique_attrs:
|
| 508 |
+
unique_attrs[val_attr] = attr_list
|
| 509 |
+
# now that we have the unique attr values grouped together in relative order, sort them by key
|
| 510 |
+
sorted_attrs = dict(sorted(unique_attrs.items()))
|
| 511 |
+
# now flatten out the dict into a list to return
|
| 512 |
+
sorted_list = []
|
| 513 |
+
for object_list in sorted_attrs.values():
|
| 514 |
+
sorted_list.extend(object_list)
|
| 515 |
+
return sorted_list
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
class WeightTypeException(TypeError):
|
| 519 |
+
"Raised when weight not compatible with AdvancedControlBase object"
|
| 520 |
+
pass
|
| 521 |
+
|
| 522 |
+
|
| 523 |
+
class AdvancedControlBase:
|
| 524 |
+
def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights, require_model=False):
|
| 525 |
+
self.base = base
|
| 526 |
+
self.compatible_weights = [ControlWeightType.UNIVERSAL]
|
| 527 |
+
self.add_compatible_weight(weights_default.weight_type)
|
| 528 |
+
# mask for which parts of controlnet output to keep
|
| 529 |
+
self.mask_cond_hint_original = None
|
| 530 |
+
self.mask_cond_hint = None
|
| 531 |
+
self.tk_mask_cond_hint_original = None
|
| 532 |
+
self.tk_mask_cond_hint = None
|
| 533 |
+
self.weight_mask_cond_hint = None
|
| 534 |
+
# actual index values
|
| 535 |
+
self.sub_idxs = None
|
| 536 |
+
self.full_latent_length = 0
|
| 537 |
+
self.context_length = 0
|
| 538 |
+
# timesteps
|
| 539 |
+
self.t: Tensor = None
|
| 540 |
+
self.batched_number: Union[int, IntWithCondOrUncond] = None
|
| 541 |
+
self.batch_size: int = 0
|
| 542 |
+
# weights + override
|
| 543 |
+
self.weights: ControlWeights = None
|
| 544 |
+
self.weights_default: ControlWeights = weights_default
|
| 545 |
+
self.weights_override: ControlWeights = None
|
| 546 |
+
# latent keyframe + override
|
| 547 |
+
self.latent_keyframes: LatentKeyframeGroup = None
|
| 548 |
+
self.latent_keyframe_override: LatentKeyframeGroup = None
|
| 549 |
+
# initialize timestep_keyframes
|
| 550 |
+
self.set_timestep_keyframes(timestep_keyframes)
|
| 551 |
+
# override some functions
|
| 552 |
+
self.get_control = self.get_control_inject
|
| 553 |
+
self.control_merge = self.control_merge_inject
|
| 554 |
+
self.pre_run = self.pre_run_inject
|
| 555 |
+
self.cleanup = self.cleanup_inject
|
| 556 |
+
self.set_previous_controlnet = self.set_previous_controlnet_inject
|
| 557 |
+
# require model to be passed into Apply Advanced ControlNet 🛂🅐🅒🅝 node
|
| 558 |
+
self.require_model = require_model
|
| 559 |
+
# disarm - when set to False, used to force usage of Apply Advanced ControlNet 🛂🅐🅒🅝 node (which will set it to True)
|
| 560 |
+
self.disarmed = not require_model
|
| 561 |
+
|
| 562 |
+
def patch_model(self, model: ModelPatcher):
|
| 563 |
+
pass
|
| 564 |
+
|
| 565 |
+
def add_compatible_weight(self, control_weight_type: str):
|
| 566 |
+
self.compatible_weights.append(control_weight_type)
|
| 567 |
+
|
| 568 |
+
def verify_all_weights(self, throw_error=True):
|
| 569 |
+
# first, check if override exists - if so, only need to check the override
|
| 570 |
+
if self.weights_override is not None:
|
| 571 |
+
if self.weights_override.weight_type not in self.compatible_weights:
|
| 572 |
+
msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__}" + \
|
| 573 |
+
f"only supports {self.compatible_weights} weights."
|
| 574 |
+
raise WeightTypeException(msg)
|
| 575 |
+
# otherwise, check all timestep keyframe weights
|
| 576 |
+
else:
|
| 577 |
+
for tk in self.timestep_keyframes.keyframes:
|
| 578 |
+
if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
|
| 579 |
+
msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type" + \
|
| 580 |
+
f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
|
| 581 |
+
raise WeightTypeException(msg)
|
| 582 |
+
|
| 583 |
+
def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
|
| 584 |
+
self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
|
| 585 |
+
# prepare first timestep_keyframe related stuff
|
| 586 |
+
self._current_timestep_keyframe = None
|
| 587 |
+
self._current_timestep_index = -1
|
| 588 |
+
self._current_used_steps = 0
|
| 589 |
+
self.weights = None
|
| 590 |
+
self.latent_keyframes = None
|
| 591 |
+
|
| 592 |
+
def prepare_current_timestep(self, t: Tensor, batched_number: int):
|
| 593 |
+
self.t = float(t[0])
|
| 594 |
+
self.batched_number = batched_number
|
| 595 |
+
self.batch_size = len(t)
|
| 596 |
+
# get current step percent
|
| 597 |
+
curr_t: float = self.t
|
| 598 |
+
prev_index = self._current_timestep_index
|
| 599 |
+
# if met guaranteed steps (or no current keyframe), look for next keyframe in case need to switch
|
| 600 |
+
if self._current_timestep_keyframe is None or self._current_used_steps >= self._current_timestep_keyframe.guarantee_steps:
|
| 601 |
+
# if has next index, loop through and see if need to switch
|
| 602 |
+
if self.timestep_keyframes.has_index(self._current_timestep_index+1):
|
| 603 |
+
for i in range(self._current_timestep_index+1, len(self.timestep_keyframes)):
|
| 604 |
+
eval_tk = self.timestep_keyframes[i]
|
| 605 |
+
# check if start percent is less or equal to curr_t
|
| 606 |
+
if eval_tk.start_t >= curr_t:
|
| 607 |
+
self._current_timestep_index = i
|
| 608 |
+
self._current_timestep_keyframe = eval_tk
|
| 609 |
+
self._current_used_steps = 0
|
| 610 |
+
# keep track of control weights, latent keyframes, and masks,
|
| 611 |
+
# accounting for inherit_missing
|
| 612 |
+
if self._current_timestep_keyframe.has_control_weights():
|
| 613 |
+
self.weights = self._current_timestep_keyframe.control_weights
|
| 614 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
| 615 |
+
self.weights = self.weights_default
|
| 616 |
+
if self._current_timestep_keyframe.has_latent_keyframes():
|
| 617 |
+
self.latent_keyframes = self._current_timestep_keyframe.latent_keyframes
|
| 618 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
| 619 |
+
self.latent_keyframes = None
|
| 620 |
+
if self._current_timestep_keyframe.has_mask_hint():
|
| 621 |
+
self.tk_mask_cond_hint_original = self._current_timestep_keyframe.mask_hint_orig
|
| 622 |
+
elif not self._current_timestep_keyframe.inherit_missing:
|
| 623 |
+
del self.tk_mask_cond_hint_original
|
| 624 |
+
self.tk_mask_cond_hint_original = None
|
| 625 |
+
# if guarantee_steps greater than zero, stop searching for other keyframes
|
| 626 |
+
if self._current_timestep_keyframe.guarantee_steps > 0:
|
| 627 |
+
break
|
| 628 |
+
# if eval_tk is outside of percent range, stop looking further
|
| 629 |
+
else:
|
| 630 |
+
break
|
| 631 |
+
|
| 632 |
+
# update steps current keyframe is used
|
| 633 |
+
self._current_used_steps += 1
|
| 634 |
+
# if index changed, apply overrides
|
| 635 |
+
if prev_index != self._current_timestep_index:
|
| 636 |
+
if self.weights_override is not None:
|
| 637 |
+
self.weights = self.weights_override
|
| 638 |
+
if self.latent_keyframe_override is not None:
|
| 639 |
+
self.latent_keyframes = self.latent_keyframe_override
|
| 640 |
+
|
| 641 |
+
# make sure weights and latent_keyframes are in a workable state
|
| 642 |
+
# Note: each AdvancedControlBase should create their own get_universal_weights class
|
| 643 |
+
self.prepare_weights()
|
| 644 |
+
|
| 645 |
+
def prepare_weights(self):
|
| 646 |
+
if self.weights is None or self.weights.weight_type == ControlWeightType.DEFAULT:
|
| 647 |
+
self.weights = self.weights_default
|
| 648 |
+
elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
|
| 649 |
+
# if universal and weight_mask present, no need to convert
|
| 650 |
+
if self.weights.weight_mask is not None:
|
| 651 |
+
return
|
| 652 |
+
self.weights = self.get_universal_weights()
|
| 653 |
+
|
| 654 |
+
def get_universal_weights(self) -> ControlWeights:
|
| 655 |
+
return self.weights
|
| 656 |
+
|
| 657 |
+
def set_cond_hint_mask(self, mask_hint):
|
| 658 |
+
self.mask_cond_hint_original = mask_hint
|
| 659 |
+
return self
|
| 660 |
+
|
| 661 |
+
def pre_run_inject(self, model, percent_to_timestep_function):
|
| 662 |
+
self.base.pre_run(model, percent_to_timestep_function)
|
| 663 |
+
self.pre_run_advanced(model, percent_to_timestep_function)
|
| 664 |
+
|
| 665 |
+
def pre_run_advanced(self, model, percent_to_timestep_function):
|
| 666 |
+
# for each timestep keyframe, calculate the start_t
|
| 667 |
+
for tk in self.timestep_keyframes.keyframes:
|
| 668 |
+
tk.start_t = percent_to_timestep_function(tk.start_percent)
|
| 669 |
+
# clear variables
|
| 670 |
+
self.cleanup_advanced()
|
| 671 |
+
|
| 672 |
+
def set_previous_controlnet_inject(self, *args, **kwargs):
|
| 673 |
+
to_return = self.base.set_previous_controlnet(*args, **kwargs)
|
| 674 |
+
if not self.disarmed:
|
| 675 |
+
raise Exception(f"Type '{type(self).__name__}' must be used with Apply Advanced ControlNet 🛂🅐🅒🅝 node (with model_optional passed in); otherwise, it will not work.")
|
| 676 |
+
return to_return
|
| 677 |
+
|
| 678 |
+
def disarm(self):
|
| 679 |
+
self.disarmed = True
|
| 680 |
+
|
| 681 |
+
def should_run(self):
|
| 682 |
+
if math.isclose(self.strength, 0.0) or math.isclose(self._current_timestep_keyframe.strength, 0.0):
|
| 683 |
+
return False
|
| 684 |
+
if self.timestep_range is not None:
|
| 685 |
+
if self.t > self.timestep_range[0] or self.t < self.timestep_range[1]:
|
| 686 |
+
return False
|
| 687 |
+
return True
|
| 688 |
+
|
| 689 |
+
def get_control_inject(self, x_noisy, t, cond, batched_number):
|
| 690 |
+
# prepare timestep and everything related
|
| 691 |
+
self.prepare_current_timestep(t=t, batched_number=batched_number)
|
| 692 |
+
# if should not perform any actions for the controlnet, exit without doing any work
|
| 693 |
+
if self.strength == 0.0 or self._current_timestep_keyframe.strength == 0.0:
|
| 694 |
+
return self.default_control_actions(x_noisy, t, cond, batched_number)
|
| 695 |
+
# otherwise, perform normal function
|
| 696 |
+
return self.get_control_advanced(x_noisy, t, cond, batched_number)
|
| 697 |
+
|
| 698 |
+
def get_control_advanced(self, x_noisy, t, cond, batched_number):
|
| 699 |
+
return self.default_control_actions(x_noisy, t, cond, batched_number)
|
| 700 |
+
|
| 701 |
+
def default_control_actions(self, x_noisy, t, cond, batched_number):
|
| 702 |
+
control_prev = None
|
| 703 |
+
if self.previous_controlnet is not None:
|
| 704 |
+
control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
|
| 705 |
+
return control_prev
|
| 706 |
+
|
| 707 |
+
def calc_weight(self, idx: int, x: Tensor, layers: int) -> Union[float, Tensor]:
|
| 708 |
+
if self.weights.weight_mask is not None:
|
| 709 |
+
# prepare weight mask
|
| 710 |
+
self.prepare_weight_mask_cond_hint(x, self.batched_number)
|
| 711 |
+
# adjust mask for current layer and return
|
| 712 |
+
return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, layers=layers))
|
| 713 |
+
return self.weights.get(idx=idx)
|
| 714 |
+
|
| 715 |
+
def get_calc_pow(self, idx: int, layers: int) -> int:
|
| 716 |
+
return (layers-1)-idx
|
| 717 |
+
|
| 718 |
+
def calc_latent_keyframe_mults(self, x: Tensor, batched_number: int) -> Tensor:
|
| 719 |
+
# apply strengths, and get batch indeces to null out
|
| 720 |
+
# AKA latents that should not be influenced by ControlNet
|
| 721 |
+
final_mults = [1.0] * x.shape[0]
|
| 722 |
+
if self.latent_keyframes:
|
| 723 |
+
latent_count = x.shape[0] // batched_number
|
| 724 |
+
indeces_to_null = set(range(latent_count))
|
| 725 |
+
mapped_indeces = None
|
| 726 |
+
# if expecting subdivision, will need to translate between subset and actual idx values
|
| 727 |
+
if self.sub_idxs:
|
| 728 |
+
mapped_indeces = {}
|
| 729 |
+
for i, actual in enumerate(self.sub_idxs):
|
| 730 |
+
mapped_indeces[actual] = i
|
| 731 |
+
for keyframe in self.latent_keyframes:
|
| 732 |
+
real_index = keyframe.batch_index
|
| 733 |
+
# if negative, count from end
|
| 734 |
+
if real_index < 0:
|
| 735 |
+
real_index += latent_count if self.sub_idxs is None else self.full_latent_length
|
| 736 |
+
|
| 737 |
+
# if not mapping indeces, what you see is what you get
|
| 738 |
+
if mapped_indeces is None:
|
| 739 |
+
if real_index in indeces_to_null:
|
| 740 |
+
indeces_to_null.remove(real_index)
|
| 741 |
+
# otherwise, see if batch_index is even included in this set of latents
|
| 742 |
+
else:
|
| 743 |
+
real_index = mapped_indeces.get(real_index, None)
|
| 744 |
+
if real_index is None:
|
| 745 |
+
continue
|
| 746 |
+
indeces_to_null.remove(real_index)
|
| 747 |
+
|
| 748 |
+
# if real_index is outside the bounds of latents, don't apply
|
| 749 |
+
if real_index >= latent_count or real_index < 0:
|
| 750 |
+
continue
|
| 751 |
+
|
| 752 |
+
# apply strength for each batched cond/uncond
|
| 753 |
+
for b in range(batched_number):
|
| 754 |
+
final_mults[(latent_count*b)+real_index] = keyframe.strength
|
| 755 |
+
# null them out by multiplying by null_latent_kf_strength
|
| 756 |
+
for batch_index in indeces_to_null:
|
| 757 |
+
# apply null for each batched cond/uncond
|
| 758 |
+
for b in range(batched_number):
|
| 759 |
+
final_mults[(latent_count*b)+batch_index] = self._current_timestep_keyframe.null_latent_kf_strength
|
| 760 |
+
# convert final_mults into tensor and match expected dimension count
|
| 761 |
+
final_tensor = torch.tensor(final_mults, dtype=x.dtype, device=x.device)
|
| 762 |
+
while len(final_tensor.shape) < len(x.shape):
|
| 763 |
+
final_tensor = final_tensor.unsqueeze(-1)
|
| 764 |
+
return final_tensor
|
| 765 |
+
|
| 766 |
+
def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
|
| 767 |
+
# handle weight's uncond_multiplier, if applicable
|
| 768 |
+
if self.weights.has_uncond_multiplier:
|
| 769 |
+
cond_or_uncond = self.batched_number.cond_or_uncond
|
| 770 |
+
actual_length = x.size(0) // batched_number
|
| 771 |
+
for idx, cond_type in enumerate(cond_or_uncond):
|
| 772 |
+
# if uncond, set to weight's uncond_multiplier
|
| 773 |
+
if cond_type == 1:
|
| 774 |
+
x[actual_length*idx:actual_length*(idx+1)] *= self.weights.uncond_multiplier
|
| 775 |
+
|
| 776 |
+
if self.latent_keyframes is not None:
|
| 777 |
+
x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
|
| 778 |
+
# apply masks, resizing mask to required dims
|
| 779 |
+
if self.mask_cond_hint is not None:
|
| 780 |
+
masks = prepare_mask_batch(self.mask_cond_hint, x.shape)
|
| 781 |
+
x[:] = x[:] * masks
|
| 782 |
+
if self.tk_mask_cond_hint is not None:
|
| 783 |
+
masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape)
|
| 784 |
+
x[:] = x[:] * masks
|
| 785 |
+
# apply timestep keyframe strengths
|
| 786 |
+
if self._current_timestep_keyframe.strength != 1.0:
|
| 787 |
+
x[:] *= self._current_timestep_keyframe.strength
|
| 788 |
+
|
| 789 |
+
def control_merge_inject(self: 'AdvancedControlBase', control_input, control_output, control_prev, output_dtype):
|
| 790 |
+
out = {'input':[], 'middle':[], 'output': []}
|
| 791 |
+
|
| 792 |
+
if control_input is not None:
|
| 793 |
+
for i in range(len(control_input)):
|
| 794 |
+
key = 'input'
|
| 795 |
+
x = control_input[i]
|
| 796 |
+
if x is not None:
|
| 797 |
+
self.apply_advanced_strengths_and_masks(x, self.batched_number)
|
| 798 |
+
|
| 799 |
+
x *= self.strength * self.calc_weight(i, x, len(control_input))
|
| 800 |
+
if x.dtype != output_dtype:
|
| 801 |
+
x = x.to(output_dtype)
|
| 802 |
+
out[key].insert(0, x)
|
| 803 |
+
|
| 804 |
+
if control_output is not None:
|
| 805 |
+
for i in range(len(control_output)):
|
| 806 |
+
if i == (len(control_output) - 1):
|
| 807 |
+
key = 'middle'
|
| 808 |
+
index = 0
|
| 809 |
+
else:
|
| 810 |
+
key = 'output'
|
| 811 |
+
index = i
|
| 812 |
+
x = control_output[i]
|
| 813 |
+
if x is not None:
|
| 814 |
+
self.apply_advanced_strengths_and_masks(x, self.batched_number)
|
| 815 |
+
|
| 816 |
+
if self.global_average_pooling:
|
| 817 |
+
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
| 818 |
+
|
| 819 |
+
x *= self.strength * self.calc_weight(i, x, len(control_output))
|
| 820 |
+
if x.dtype != output_dtype:
|
| 821 |
+
x = x.to(output_dtype)
|
| 822 |
+
|
| 823 |
+
out[key].append(x)
|
| 824 |
+
if control_prev is not None:
|
| 825 |
+
for x in ['input', 'middle', 'output']:
|
| 826 |
+
o = out[x]
|
| 827 |
+
for i in range(len(control_prev[x])):
|
| 828 |
+
prev_val = control_prev[x][i]
|
| 829 |
+
if i >= len(o):
|
| 830 |
+
o.append(prev_val)
|
| 831 |
+
elif prev_val is not None:
|
| 832 |
+
if o[i] is None:
|
| 833 |
+
o[i] = prev_val
|
| 834 |
+
else:
|
| 835 |
+
if o[i].shape[0] < prev_val.shape[0]:
|
| 836 |
+
o[i] = prev_val + o[i]
|
| 837 |
+
else:
|
| 838 |
+
o[i] += prev_val
|
| 839 |
+
return out
|
| 840 |
+
|
| 841 |
+
def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
| 842 |
+
self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
| 843 |
+
self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
| 844 |
+
|
| 845 |
+
def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
| 846 |
+
return self._prepare_mask("tk_mask_cond_hint", self._current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
|
| 847 |
+
|
| 848 |
+
def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
|
| 849 |
+
return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)
|
| 850 |
+
|
| 851 |
+
def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
|
| 852 |
+
# make mask appropriate dimensions, if present
|
| 853 |
+
if orig_mask is not None:
|
| 854 |
+
out_mask = getattr(self, attr_name)
|
| 855 |
+
multiplier = 1 if direct_attn else 8
|
| 856 |
+
if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * multiplier != out_mask.shape[1] or x_noisy.shape[3] * multiplier != out_mask.shape[2]:
|
| 857 |
+
self._reset_attr(attr_name)
|
| 858 |
+
del out_mask
|
| 859 |
+
# TODO: perform upscale on only the sub_idxs masks at a time instead of all to conserve RAM
|
| 860 |
+
# resize mask and match batch count
|
| 861 |
+
out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier)
|
| 862 |
+
actual_latent_length = x_noisy.shape[0] // batched_number
|
| 863 |
+
out_mask = comfy.utils.repeat_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
|
| 864 |
+
if self.sub_idxs is not None:
|
| 865 |
+
out_mask = out_mask[self.sub_idxs]
|
| 866 |
+
# make cond_hint_mask length match x_noise
|
| 867 |
+
if x_noisy.shape[0] != out_mask.shape[0]:
|
| 868 |
+
out_mask = broadcast_image_to(out_mask, x_noisy.shape[0], batched_number)
|
| 869 |
+
# default dtype to be same as x_noisy
|
| 870 |
+
if dtype is None:
|
| 871 |
+
dtype = x_noisy.dtype
|
| 872 |
+
setattr(self, attr_name, out_mask.to(dtype=dtype).to(self.device))
|
| 873 |
+
del out_mask
|
| 874 |
+
|
| 875 |
+
def _reset_attr(self, attr_name, new_value=None):
|
| 876 |
+
if hasattr(self, attr_name):
|
| 877 |
+
delattr(self, attr_name)
|
| 878 |
+
setattr(self, attr_name, new_value)
|
| 879 |
+
|
| 880 |
+
def cleanup_inject(self):
|
| 881 |
+
self.base.cleanup()
|
| 882 |
+
self.cleanup_advanced()
|
| 883 |
+
|
| 884 |
+
def cleanup_advanced(self):
|
| 885 |
+
self.sub_idxs = None
|
| 886 |
+
self.full_latent_length = 0
|
| 887 |
+
self.context_length = 0
|
| 888 |
+
self.t = None
|
| 889 |
+
self.batched_number = None
|
| 890 |
+
self.batch_size = 0
|
| 891 |
+
self.weights = None
|
| 892 |
+
self.latent_keyframes = None
|
| 893 |
+
# timestep stuff
|
| 894 |
+
self._current_timestep_keyframe = None
|
| 895 |
+
self._current_timestep_index = -1
|
| 896 |
+
self._current_used_steps = 0
|
| 897 |
+
# clear mask hints
|
| 898 |
+
if self.mask_cond_hint is not None:
|
| 899 |
+
del self.mask_cond_hint
|
| 900 |
+
self.mask_cond_hint = None
|
| 901 |
+
if self.tk_mask_cond_hint_original is not None:
|
| 902 |
+
del self.tk_mask_cond_hint_original
|
| 903 |
+
self.tk_mask_cond_hint_original = None
|
| 904 |
+
if self.tk_mask_cond_hint is not None:
|
| 905 |
+
del self.tk_mask_cond_hint
|
| 906 |
+
self.tk_mask_cond_hint = None
|
| 907 |
+
if self.weight_mask_cond_hint is not None:
|
| 908 |
+
del self.weight_mask_cond_hint
|
| 909 |
+
self.weight_mask_cond_hint = None
|
| 910 |
+
|
| 911 |
+
def copy_to_advanced(self, copied: 'AdvancedControlBase'):
|
| 912 |
+
copied.mask_cond_hint_original = self.mask_cond_hint_original
|
| 913 |
+
copied.weights_override = self.weights_override
|
| 914 |
+
copied.latent_keyframe_override = self.latent_keyframe_override
|
| 915 |
+
copied.disarmed = self.disarmed
|
ComfyUI-Advanced-ControlNet/pyproject.toml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "comfyui-advanced-controlnet"
|
| 3 |
+
description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
|
| 4 |
+
version = "1.0.2"
|
| 5 |
+
license = "LICENSE"
|
| 6 |
+
dependencies = []
|
| 7 |
+
|
| 8 |
+
[project.urls]
|
| 9 |
+
Repository = "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet"
|
| 10 |
+
|
| 11 |
+
# Used by Comfy Registry https://comfyregistry.org
|
| 12 |
+
[tool.comfy]
|
| 13 |
+
PublisherId = "kosinkadink"
|
| 14 |
+
DisplayName = "ComfyUI-Advanced-ControlNet"
|
| 15 |
+
Icon = ""
|
ComfyUI-Advanced-ControlNet/requirements.txt
ADDED
|
File without changes
|
ComfyUI-BrushNet/BIG_IMAGE.md
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
[workflow](example/BrushNet_cut_for_inpaint.json)
|
| 4 |
+
|
| 5 |
+
When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
|
| 6 |
+
I created a node for such workflow, see example.
|
ComfyUI-BrushNet/CN.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## ControlNet Canny Edge
|
| 2 |
+
|
| 3 |
+
Let's take the pestered cake and try to inpaint it again. Now I would like to use a sleeping cat for it:
|
| 4 |
+
|
| 5 |
+

|
| 6 |
+
|
| 7 |
+
I use Canny Edge node from [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux). Don't forget to resize canny edge mask to 512 pixels:
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
Let's look at the result:
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
The first problem I see here is some kind of object behind the cat. Such objects appear since the inpainting mask strictly aligns with the removed object, the cake in our case. To remove such artifact we should expand our mask a little:
|
| 16 |
+
|
| 17 |
+

|
| 18 |
+
|
| 19 |
+
Now. what's up with cat back and tail? Let's see the inpainting mask and canny edge mask side to side:
|
| 20 |
+
|
| 21 |
+

|
| 22 |
+
|
| 23 |
+
The inpainting works (mostly) only in masked (white) area, so we cut off cat's back. **The ControlNet mask should be inside the inpaint mask.**
|
| 24 |
+
|
| 25 |
+
To address the issue I resized the mask to 256 pixels:
|
| 26 |
+
|
| 27 |
+

|
| 28 |
+
|
| 29 |
+
This is better but still have a room for improvement. The problem with edge mask downsampling is that edge lines tend to be broken and after some size we will got a mess:
|
| 30 |
+
|
| 31 |
+

|
| 32 |
+
|
| 33 |
+
Look at the edge mask, at this resolution it is so broken:
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
ComfyUI-BrushNet/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
ComfyUI-BrushNet/PARAMS.md
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Start At and End At parameters usage
|
| 2 |
+
|
| 3 |
+
### start_at
|
| 4 |
+
|
| 5 |
+
Let's start with a ELLA outpaint [workflow](example/BrushNet_with_ELLA.json) and switch off Blend Inpaint node:
|
| 6 |
+
|
| 7 |
+

|
| 8 |
+
|
| 9 |
+
For this example I use "wargaming shop showcase" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps. This is the result:
|
| 10 |
+
|
| 11 |
+

|
| 12 |
+
|
| 13 |
+
The `start_at` BrushNet node parameter allows us to delay BrushNet inference for some steps, so the base model will do all the job. Let's see what the result will be without BrushNet. For this I set up `start_at` parameter to 20 - it should be more then `steps` in KSampler node:
|
| 14 |
+
|
| 15 |
+

|
| 16 |
+
|
| 17 |
+
So, if we apply BrushNet from the beginning (`start_at` equals 0), the resulting scene will be heavily influenced by BrushNet image. The more we increase this parameter, the more scene will be based on prompt. Let's compare:
|
| 18 |
+
|
| 19 |
+
| `start_at` = 1 | `start_at` = 2 | `start_at` = 3 |
|
| 20 |
+
|:--------------:|:--------------:|:--------------:|
|
| 21 |
+
|  |  |  |
|
| 22 |
+
| `start_at` = 4 | `start_at` = 5 | `start_at` = 6 |
|
| 23 |
+
|  |  |  |
|
| 24 |
+
| `start_at` = 7 | `start_at` = 8 | `start_at` = 9 |
|
| 25 |
+
|  |  |  |
|
| 26 |
+
|
| 27 |
+
Look how the floor is aligned with toy's base - at some step it looses consistency. The results will depend on type of sampler and number of KSampler steps, of course.
|
| 28 |
+
|
| 29 |
+
### end_at
|
| 30 |
+
|
| 31 |
+
The `end_at` parameter switches off BrushNet at the last steps. If you use deterministic sampler it will only influences details on last steps, but stochastic samplers can change the whole scene. For a description of samplers see, for example, Matteo Spinelli's [video on ComfyUI basics](https://youtu.be/_C7kR2TFIX0?t=516).
|
| 32 |
+
|
| 33 |
+
Here I use basic BrushNet inpaint [example](example/BrushNet_basic.json), with "intricate teapot" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps:
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
There are almost no changes when we set 'end_at' paramter to 10, but starting from it:
|
| 38 |
+
|
| 39 |
+
| `end_at` = 10 | `end_at` = 9 | `end_at` = 8 |
|
| 40 |
+
|:--------------:|:--------------:|:--------------:|
|
| 41 |
+
|  |  |  |
|
| 42 |
+
| `end_at` = 7 | `end_at` = 6 | `end_at` = 5 |
|
| 43 |
+
|  |  |  |
|
| 44 |
+
| `end_at` = 4 | `end_at` = 3 | `end_at` = 2 |
|
| 45 |
+
|  |  |  |
|
| 46 |
+
|
| 47 |
+
You can see how the scene was completely redrawn.
|
ComfyUI-BrushNet/RAUNET.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
During investigation of compatibility issues with [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main) and [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion) nodes I stumbled upon some quite hard problems. There are `FreeU` nodes in ComfyUI, but no such for HiDiffusion, so I decided to implement RAUNet on base of my BrushNet implementation. **blepping**, I am sorry. :)
|
| 2 |
+
|
| 3 |
+
### RAUNet
|
| 4 |
+
|
| 5 |
+
What is RAUNet? I know many of you saw and generate images with a lot of limbs, fingers and faces all morphed together.
|
| 6 |
+
|
| 7 |
+
The authors of HiDiffusion invent simple, yet efficient trick to alleviate this problem. Here is an example:
|
| 8 |
+
|
| 9 |
+

|
| 10 |
+
|
| 11 |
+
[workflow](example/RAUNet_basic.json)
|
| 12 |
+
|
| 13 |
+
The left picture is created using ZavyChromaXL checkpoint on 2048x2048 canvas. The right one uses RAUNet.
|
| 14 |
+
|
| 15 |
+
In my experience the node is helpful but quite sensitive to its parameters. And there is no universal solution - you should adjust them for every new image you generate. It also lowers model's imagination, you usually get only what you described in the prompt. Look at the example: in first you have a forest in the background, but RAUNet deleted all except fox which is described in the prompt.
|
| 16 |
+
|
| 17 |
+
From the [paper](https://arxiv.org/abs/2311.17528): Diffusion models denoise from structures to details. RAU-Net introduces additional downsampling and upsampling operations, leading to a certain degree of information loss. In the early stages of denoising, RAU-Net can generate reasonable structures with minimal impact from information loss. However, in the later stages of denoising when generating fine details, the information loss in RAU-Net results in the loss of image details and a degradation in quality.
|
| 18 |
+
|
| 19 |
+
### Parameters
|
| 20 |
+
|
| 21 |
+
There are two independent parts in this node: DU (Downsample/Upsample) and XA (CrossAttention). The four parameters are the start and end steps for applying these parts.
|
| 22 |
+
|
| 23 |
+
The Downsample/Upsample part lowers models degrees of freedom. If you apply it a lot (for more steps) the resulting images will have a lot of symmetries.
|
| 24 |
+
|
| 25 |
+
The CrossAttension part lowers number of objects which model tracks in image.
|
| 26 |
+
|
| 27 |
+
Usually you apply DU and after several steps apply XA, sometimes you will need only XA, you should try it yourself.
|
| 28 |
+
|
| 29 |
+
### Compatibility
|
| 30 |
+
|
| 31 |
+
It is compatible with BrushNet and most other nodes.
|
| 32 |
+
|
| 33 |
+
This is ControlNet example. The lower image is pure model, the upper is after using RAUNet. You can see small fox and two tails in lower image.
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
[workflow](example/RAUNet_with_CN.json)
|
| 38 |
+
|
| 39 |
+
The node can be implemented for any model. Right now it can be applied to SD15 and SDXL models.
|
ComfyUI-BrushNet/README.md
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## ComfyUI-BrushNet
|
| 2 |
+
|
| 3 |
+
These are custom nodes for ComfyUI native implementation of
|
| 4 |
+
|
| 5 |
+
- Brushnet: ["BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"](https://arxiv.org/abs/2403.06976)
|
| 6 |
+
- PowerPaint: [A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting](https://arxiv.org/abs/2312.03594)
|
| 7 |
+
- HiDiffusion: [HiDiffusion: Unlocking Higher-Resolution Creativity and Efficiency in Pretrained Diffusion Models](https://arxiv.org/abs/2311.17528)
|
| 8 |
+
|
| 9 |
+
My contribution is limited to the ComfyUI adaptation, and all credit goes to the authors of the papers.
|
| 10 |
+
|
| 11 |
+
## Updates
|
| 12 |
+
|
| 13 |
+
May 16, 2024. Internal rework to improve compatibility with other nodes. [RAUNet](RAUNET.md) is implemented.
|
| 14 |
+
|
| 15 |
+
May 12, 2024. CutForInpaint node, see [example](BIG_IMAGE.md).
|
| 16 |
+
|
| 17 |
+
May 11, 2024. Image batch is implemented. You can even add BrushNet to AnimateDiff vid2vid workflow, but they don't work together - they are different models and both try to patch UNet. Added some more examples.
|
| 18 |
+
|
| 19 |
+
May 6, 2024. PowerPaint v2 model is implemented. After update your workflow probably will not work. Don't panic! Check `end_at` parameter of BrushNode, if it equals 1, change it to some big number. Read about parameters in Usage section below.
|
| 20 |
+
|
| 21 |
+
May 2, 2024. BrushNet SDXL is live. It needs positive and negative conditioning though, so workflow changes a little, see example.
|
| 22 |
+
|
| 23 |
+
Apr 28, 2024. Another rework, sorry for inconvenience. But now BrushNet is native to ComfyUI. Famous cubiq's [IPAdapter Plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) is now working with BrushNet! I hope... :) Please, report any bugs you found.
|
| 24 |
+
|
| 25 |
+
Apr 18, 2024. Complete rework, no more custom `diffusers` library. It is possible to use LoRA models.
|
| 26 |
+
|
| 27 |
+
Apr 11, 2024. Initial commit.
|
| 28 |
+
|
| 29 |
+
## Plans
|
| 30 |
+
|
| 31 |
+
- [x] BrushNet SDXL
|
| 32 |
+
- [x] PowerPaint v2
|
| 33 |
+
- [x] Image batch
|
| 34 |
+
|
| 35 |
+
## Installation
|
| 36 |
+
|
| 37 |
+
Clone the repo into the `custom_nodes` directory and install the requirements:
|
| 38 |
+
|
| 39 |
+
```
|
| 40 |
+
git clone https://github.com/nullquant/ComfyUI-BrushNet.git
|
| 41 |
+
pip install -r requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Checkpoints of BrushNet can be downloaded from [here](https://drive.google.com/drive/folders/1fqmS1CEOvXCxNWFrsSYd_jHYXxrydh1n?usp=drive_link).
|
| 45 |
+
|
| 46 |
+
The checkpoint in `segmentation_mask_brushnet_ckpt` provides checkpoints trained on BrushData, which has segmentation prior (mask are with the same shape of objects). The `random_mask_brushnet_ckpt` provides a more general ckpt for random mask shape.
|
| 47 |
+
|
| 48 |
+
`segmentation_mask_brushnet_ckpt` and `random_mask_brushnet_ckpt` contains BrushNet for SD 1.5 models while
|
| 49 |
+
`segmentation_mask_brushnet_ckpt_sdxl_v0` and `random_mask_brushnet_ckpt_sdxl_v0` for SDXL.
|
| 50 |
+
|
| 51 |
+
You should place `diffusion_pytorch_model.safetensors` files to your `models/inpaint` folder. You can also specify `inpaint` folder in your `extra_model_paths.yaml`.
|
| 52 |
+
|
| 53 |
+
For PowerPaint you should download three files. Both `diffusion_pytorch_model.safetensors` and `pytorch_model.bin` from [here](https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/tree/main/PowerPaint_Brushnet) should be placed in your `models/inpaint` folder.
|
| 54 |
+
|
| 55 |
+
Also you need SD1.5 text encoder model `model.fp16.safetensors` from [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder). It should be placed in your `models/clip` folder.
|
| 56 |
+
|
| 57 |
+
This is a structure of my `models/inpaint` folder:
|
| 58 |
+
|
| 59 |
+

|
| 60 |
+
|
| 61 |
+
Yours can be different.
|
| 62 |
+
|
| 63 |
+
## Usage
|
| 64 |
+
|
| 65 |
+
Below is an example for the intended workflow. The [workflow](example/BrushNet_basic.json) for the example can be found inside the 'example' directory.
|
| 66 |
+
|
| 67 |
+

|
| 68 |
+
|
| 69 |
+
<details>
|
| 70 |
+
<summary>SDXL</summary>
|
| 71 |
+
|
| 72 |
+

|
| 73 |
+
|
| 74 |
+
[workflow](example/BrushNet_SDXL_basic.json)
|
| 75 |
+
|
| 76 |
+
</details>
|
| 77 |
+
|
| 78 |
+
<details>
|
| 79 |
+
<summary>IPAdapter plus</summary>
|
| 80 |
+
|
| 81 |
+

|
| 82 |
+
|
| 83 |
+
[workflow](example/BrushNet_with_IPA.json)
|
| 84 |
+
|
| 85 |
+
</details>
|
| 86 |
+
|
| 87 |
+
<details>
|
| 88 |
+
<summary>LoRA</summary>
|
| 89 |
+
|
| 90 |
+

|
| 91 |
+
|
| 92 |
+
[workflow](example/BrushNet_with_LoRA.json)
|
| 93 |
+
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
<details>
|
| 97 |
+
<summary>Blending inpaint</summary>
|
| 98 |
+
|
| 99 |
+

|
| 100 |
+
|
| 101 |
+
Sometimes inference and VAE broke image, so you need to blend inpaint image with the original: [workflow](example/BrushNet_inpaint.json). You can see blurred and broken text after inpainting in the first image and how I suppose to repair it.
|
| 102 |
+
|
| 103 |
+
</details>
|
| 104 |
+
|
| 105 |
+
<details>
|
| 106 |
+
<summary>ControlNet</summary>
|
| 107 |
+
|
| 108 |
+

|
| 109 |
+
|
| 110 |
+
[workflow](example/BrushNet_with_CN.json)
|
| 111 |
+
|
| 112 |
+
[ControlNet canny edge](CN.md)
|
| 113 |
+
|
| 114 |
+
</details>
|
| 115 |
+
|
| 116 |
+
<details>
|
| 117 |
+
<summary>ELLA outpaint</summary>
|
| 118 |
+
|
| 119 |
+

|
| 120 |
+
|
| 121 |
+
[workflow](example/BrushNet_with_ELLA.json)
|
| 122 |
+
|
| 123 |
+
</details>
|
| 124 |
+
|
| 125 |
+
<details>
|
| 126 |
+
<summary>Upscale</summary>
|
| 127 |
+
|
| 128 |
+

|
| 129 |
+
|
| 130 |
+
[workflow](example/BrushNet_SDXL_upscale.json)
|
| 131 |
+
|
| 132 |
+
To upscale you should use base model, not BrushNet. The same is true for conditioning. Latent upscaling between BrushNet and KSampler will not work or will give you wierd results. These limitations are due to structure of BrushNet and its influence on UNet calculations.
|
| 133 |
+
|
| 134 |
+
</details>
|
| 135 |
+
|
| 136 |
+
<details>
|
| 137 |
+
<summary>Image batch</summary>
|
| 138 |
+
|
| 139 |
+

|
| 140 |
+
|
| 141 |
+
[workflow](example/BrushNet_image_batch.json)
|
| 142 |
+
|
| 143 |
+
If you have OOM problems, you can use Evolved Sampling from [AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved):
|
| 144 |
+
|
| 145 |
+

|
| 146 |
+
|
| 147 |
+
[workflow](example/BrushNet_image_big_batch.json)
|
| 148 |
+
|
| 149 |
+
In Context Options set context_length to number of images which can be loaded into VRAM. Images will be processed in chunks of this size.
|
| 150 |
+
|
| 151 |
+
</details>
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
<details>
|
| 155 |
+
<summary>Big image inpaint</summary>
|
| 156 |
+
|
| 157 |
+

|
| 158 |
+
|
| 159 |
+
[workflow](example/BrushNet_cut_for_inpaint.json)
|
| 160 |
+
|
| 161 |
+
When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
|
| 162 |
+
I created a node for such workflow, see example.
|
| 163 |
+
|
| 164 |
+
</details>
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
<details>
|
| 168 |
+
<summary>PowerPaint outpaint</summary>
|
| 169 |
+
|
| 170 |
+

|
| 171 |
+
|
| 172 |
+
[workflow](example/PowerPaint_outpaint.json)
|
| 173 |
+
|
| 174 |
+
</details>
|
| 175 |
+
|
| 176 |
+
<details>
|
| 177 |
+
<summary>PowerPaint object removal</summary>
|
| 178 |
+
|
| 179 |
+

|
| 180 |
+
|
| 181 |
+
[workflow](example/PowerPaint_object_removal.json)
|
| 182 |
+
|
| 183 |
+
It is often hard to completely remove the object, especially if it is at the front:
|
| 184 |
+
|
| 185 |
+

|
| 186 |
+
|
| 187 |
+
You should try to add object description to negative prompt and describe empty scene, like here:
|
| 188 |
+
|
| 189 |
+

|
| 190 |
+
|
| 191 |
+
</details>
|
| 192 |
+
|
| 193 |
+
### Parameters
|
| 194 |
+
|
| 195 |
+
#### Brushnet Loader
|
| 196 |
+
|
| 197 |
+
- `dtype`, defaults to `torch.float16`. The torch.dtype of BrushNet. If you have old GPU or NVIDIA 16 series card try to switch to `torch.float32`.
|
| 198 |
+
|
| 199 |
+
#### Brushnet
|
| 200 |
+
|
| 201 |
+
- `scale`, defaults to 1.0: The "strength" of BrushNet. The outputs of the BrushNet are multiplied by `scale` before they are added to the residual in the original unet.
|
| 202 |
+
- `start_at`, defaults to 0: step at which the BrushNet starts applying.
|
| 203 |
+
- `end_at`, defaults to 10000: step at which the BrushNet stops applying.
|
| 204 |
+
|
| 205 |
+
[Here](PARAMS.md) are examples of use these two last parameters.
|
| 206 |
+
|
| 207 |
+
#### PowerPaint
|
| 208 |
+
|
| 209 |
+
- `CLIP`: PowerPaint CLIP that should be passed from PowerPaintCLIPLoader node.
|
| 210 |
+
- `fitting`: PowerPaint fitting degree.
|
| 211 |
+
- `function`: PowerPaint function, see its [page](https://github.com/open-mmlab/PowerPaint) for details.
|
| 212 |
+
|
| 213 |
+
When using certain network functions, the authors of PowerPaint recommend adding phrases to the prompt:
|
| 214 |
+
|
| 215 |
+
- object removal: `empty scene blur`
|
| 216 |
+
- context aware: `empty scene`
|
| 217 |
+
- outpainting: `empty scene`
|
| 218 |
+
|
| 219 |
+
Many of ComfyUI users use custom text generation nodes, CLIP nodes and a lot of other conditioning. I don't want to break all of these nodes, so I didn't add prompt updating and instead rely on users. Also my own experiments show that these additions to prompt are not strictly necessary.
|
| 220 |
+
|
| 221 |
+
The latent image can be from BrushNet node or not, but it should be the same size as original image (divided by 8 in latent space).
|
| 222 |
+
|
| 223 |
+
The both conditioning `positive` and `negative` in BrushNet and PowerPaint nodes are used for calculation inside, but then simply copied to output.
|
| 224 |
+
|
| 225 |
+
Be advised, not all workflows and nodes will work with BrushNet due to its structure. Also put model changes before BrushNet nodes, not after. If you need model to work with image after BrushNet inference use base one (see Upscale example below).
|
| 226 |
+
|
| 227 |
+
#### RAUNet
|
| 228 |
+
|
| 229 |
+
- `du_start`, defaults to 0: step at which the Downsample/Upsample resize starts applying.
|
| 230 |
+
- `du_end`, defaults to 4: step at which the Downsample/Upsample resize stops applying.
|
| 231 |
+
- `xa_start`, defaults to 4: step at which the CrossAttention resize starts applying.
|
| 232 |
+
- `xa_end`, defaults to 10: step at which the CrossAttention resize stops applying.
|
| 233 |
+
|
| 234 |
+
For an examples and explanation, please look [here](RAUNET.md).
|
| 235 |
+
|
| 236 |
+
## Limitations
|
| 237 |
+
|
| 238 |
+
BrushNet has some limitations (from the [paper](https://arxiv.org/abs/2403.06976)):
|
| 239 |
+
|
| 240 |
+
- The quality and content generated by the model are heavily dependent on the chosen base model.
|
| 241 |
+
The results can exhibit incoherence if, for example, the given image is a natural image while the base model primarily focuses on anime.
|
| 242 |
+
- Even with BrushNet, we still observe poor generation results in cases where the given mask has an unusually shaped
|
| 243 |
+
or irregular form, or when the given text does not align well with the masked image.
|
| 244 |
+
|
| 245 |
+
## Notes
|
| 246 |
+
|
| 247 |
+
Unfortunately, due to the nature of BrushNet code some nodes are not compatible with these, since we are trying to patch the same ComfyUI's functions.
|
| 248 |
+
|
| 249 |
+
List of known uncompartible nodes.
|
| 250 |
+
|
| 251 |
+
- [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main)
|
| 252 |
+
- [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion)
|
| 253 |
+
|
| 254 |
+
## Credits
|
| 255 |
+
|
| 256 |
+
The code is based on
|
| 257 |
+
|
| 258 |
+
- [BrushNet](https://github.com/TencentARC/BrushNet)
|
| 259 |
+
- [PowerPaint](https://github.com/zhuang2002/PowerPaint)
|
| 260 |
+
- [HiDiffusion](https://github.com/megvii-research/HiDiffusion)
|
| 261 |
+
- [diffusers](https://github.com/huggingface/diffusers)
|
ComfyUI-BrushNet/__init__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, PowerPaintCLIPLoader, PowerPaint, CutForInpaint
|
| 2 |
+
from .raunet_nodes import RAUNet
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
@author: nullquant
|
| 6 |
+
@title: BrushNet
|
| 7 |
+
@nickname: BrushName nodes
|
| 8 |
+
@description: These are custom nodes for ComfyUI native implementation of BrushNet, PowerPaint and RAUNet models
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
# A dictionary that contains all nodes you want to export with their names
|
| 12 |
+
# NOTE: names should be globally unique
|
| 13 |
+
NODE_CLASS_MAPPINGS = {
|
| 14 |
+
"BrushNetLoader": BrushNetLoader,
|
| 15 |
+
"BrushNet": BrushNet,
|
| 16 |
+
"BlendInpaint": BlendInpaint,
|
| 17 |
+
"PowerPaintCLIPLoader": PowerPaintCLIPLoader,
|
| 18 |
+
"PowerPaint": PowerPaint,
|
| 19 |
+
"CutForInpaint": CutForInpaint,
|
| 20 |
+
"RAUNet": RAUNet,
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
# A dictionary that contains the friendly/humanly readable titles for the nodes
|
| 24 |
+
NODE_DISPLAY_NAME_MAPPINGS = {
|
| 25 |
+
"BrushNetLoader": "BrushNet Loader",
|
| 26 |
+
"BrushNet": "BrushNet",
|
| 27 |
+
"BlendInpaint": "Blend Inpaint",
|
| 28 |
+
"PowerPaintCLIPLoader": "PowerPaint CLIP Loader",
|
| 29 |
+
"PowerPaint": "PowerPaint",
|
| 30 |
+
"CutForInpaint": "Cut For Inpaint",
|
| 31 |
+
"RAUNet": "RAUNet",
|
| 32 |
+
}
|
ComfyUI-BrushNet/brushnet/brushnet.json
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "BrushNetModel",
|
| 3 |
+
"_diffusers_version": "0.27.0.dev0",
|
| 4 |
+
"_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"addition_embed_type": null,
|
| 7 |
+
"addition_embed_type_num_heads": 64,
|
| 8 |
+
"addition_time_embed_dim": null,
|
| 9 |
+
"attention_head_dim": 8,
|
| 10 |
+
"block_out_channels": [
|
| 11 |
+
320,
|
| 12 |
+
640,
|
| 13 |
+
1280,
|
| 14 |
+
1280
|
| 15 |
+
],
|
| 16 |
+
"brushnet_conditioning_channel_order": "rgb",
|
| 17 |
+
"class_embed_type": null,
|
| 18 |
+
"conditioning_channels": 5,
|
| 19 |
+
"conditioning_embedding_out_channels": [
|
| 20 |
+
16,
|
| 21 |
+
32,
|
| 22 |
+
96,
|
| 23 |
+
256
|
| 24 |
+
],
|
| 25 |
+
"cross_attention_dim": 768,
|
| 26 |
+
"down_block_types": [
|
| 27 |
+
"DownBlock2D",
|
| 28 |
+
"DownBlock2D",
|
| 29 |
+
"DownBlock2D",
|
| 30 |
+
"DownBlock2D"
|
| 31 |
+
],
|
| 32 |
+
"downsample_padding": 1,
|
| 33 |
+
"encoder_hid_dim": null,
|
| 34 |
+
"encoder_hid_dim_type": null,
|
| 35 |
+
"flip_sin_to_cos": true,
|
| 36 |
+
"freq_shift": 0,
|
| 37 |
+
"global_pool_conditions": false,
|
| 38 |
+
"in_channels": 4,
|
| 39 |
+
"layers_per_block": 2,
|
| 40 |
+
"mid_block_scale_factor": 1,
|
| 41 |
+
"mid_block_type": "MidBlock2D",
|
| 42 |
+
"norm_eps": 1e-05,
|
| 43 |
+
"norm_num_groups": 32,
|
| 44 |
+
"num_attention_heads": null,
|
| 45 |
+
"num_class_embeds": null,
|
| 46 |
+
"only_cross_attention": false,
|
| 47 |
+
"projection_class_embeddings_input_dim": null,
|
| 48 |
+
"resnet_time_scale_shift": "default",
|
| 49 |
+
"transformer_layers_per_block": 1,
|
| 50 |
+
"up_block_types": [
|
| 51 |
+
"UpBlock2D",
|
| 52 |
+
"UpBlock2D",
|
| 53 |
+
"UpBlock2D",
|
| 54 |
+
"UpBlock2D"
|
| 55 |
+
],
|
| 56 |
+
"upcast_attention": false,
|
| 57 |
+
"use_linear_projection": false
|
| 58 |
+
}
|
ComfyUI-BrushNet/brushnet/brushnet.py
ADDED
|
@@ -0,0 +1,948 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
from diffusers.utils import BaseOutput, logging
|
| 10 |
+
from diffusers.models.attention_processor import (
|
| 11 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 12 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 13 |
+
AttentionProcessor,
|
| 14 |
+
AttnAddedKVProcessor,
|
| 15 |
+
AttnProcessor,
|
| 16 |
+
)
|
| 17 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
| 18 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 19 |
+
|
| 20 |
+
from .unet_2d_blocks import (
|
| 21 |
+
CrossAttnDownBlock2D,
|
| 22 |
+
DownBlock2D,
|
| 23 |
+
UNetMidBlock2D,
|
| 24 |
+
UNetMidBlock2DCrossAttn,
|
| 25 |
+
get_down_block,
|
| 26 |
+
get_mid_block,
|
| 27 |
+
get_up_block,
|
| 28 |
+
MidBlock2D
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from .unet_2d_condition import UNet2DConditionModel
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class BrushNetOutput(BaseOutput):
|
| 39 |
+
"""
|
| 40 |
+
The output of [`BrushNetModel`].
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
up_block_res_samples (`tuple[torch.Tensor]`):
|
| 44 |
+
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
| 45 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
| 46 |
+
used to condition the original UNet's upsampling activations.
|
| 47 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
| 48 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
| 49 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
| 50 |
+
used to condition the original UNet's downsampling activations.
|
| 51 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
| 52 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
| 53 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
| 54 |
+
Output can be used to condition the original UNet's middle block activation.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
up_block_res_samples: Tuple[torch.Tensor]
|
| 58 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
| 59 |
+
mid_block_res_sample: torch.Tensor
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class BrushNetModel(ModelMixin, ConfigMixin):
|
| 63 |
+
"""
|
| 64 |
+
A BrushNet model.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
in_channels (`int`, defaults to 4):
|
| 68 |
+
The number of channels in the input sample.
|
| 69 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 70 |
+
Whether to flip the sin to cos in the time embedding.
|
| 71 |
+
freq_shift (`int`, defaults to 0):
|
| 72 |
+
The frequency shift to apply to the time embedding.
|
| 73 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 74 |
+
The tuple of downsample blocks to use.
|
| 75 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
| 76 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
| 77 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
| 78 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
| 79 |
+
The tuple of upsample blocks to use.
|
| 80 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
| 81 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
| 82 |
+
The tuple of output channels for each block.
|
| 83 |
+
layers_per_block (`int`, defaults to 2):
|
| 84 |
+
The number of layers per block.
|
| 85 |
+
downsample_padding (`int`, defaults to 1):
|
| 86 |
+
The padding to use for the downsampling convolution.
|
| 87 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
| 88 |
+
The scale factor to use for the mid block.
|
| 89 |
+
act_fn (`str`, defaults to "silu"):
|
| 90 |
+
The activation function to use.
|
| 91 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 92 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
| 93 |
+
in post-processing.
|
| 94 |
+
norm_eps (`float`, defaults to 1e-5):
|
| 95 |
+
The epsilon to use for the normalization.
|
| 96 |
+
cross_attention_dim (`int`, defaults to 1280):
|
| 97 |
+
The dimension of the cross attention features.
|
| 98 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
| 99 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 100 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 101 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 102 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
| 103 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
| 104 |
+
dimension to `cross_attention_dim`.
|
| 105 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
| 106 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
| 107 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
| 108 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
| 109 |
+
The dimension of the attention heads.
|
| 110 |
+
use_linear_projection (`bool`, defaults to `False`):
|
| 111 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 112 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
| 113 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 114 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
| 115 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
| 116 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
| 117 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
| 118 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 119 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 120 |
+
upcast_attention (`bool`, defaults to `False`):
|
| 121 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
| 122 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
| 123 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
| 124 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
| 125 |
+
`class_embed_type="projection"`.
|
| 126 |
+
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
| 127 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
| 128 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
| 129 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
| 130 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
| 131 |
+
TODO(Patrick) - unused parameter.
|
| 132 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
| 133 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
| 134 |
+
"""
|
| 135 |
+
|
| 136 |
+
_supports_gradient_checkpointing = True
|
| 137 |
+
|
| 138 |
+
@register_to_config
|
| 139 |
+
def __init__(
|
| 140 |
+
self,
|
| 141 |
+
in_channels: int = 4,
|
| 142 |
+
conditioning_channels: int = 5,
|
| 143 |
+
flip_sin_to_cos: bool = True,
|
| 144 |
+
freq_shift: int = 0,
|
| 145 |
+
down_block_types: Tuple[str, ...] = (
|
| 146 |
+
"DownBlock2D",
|
| 147 |
+
"DownBlock2D",
|
| 148 |
+
"DownBlock2D",
|
| 149 |
+
"DownBlock2D",
|
| 150 |
+
),
|
| 151 |
+
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
| 152 |
+
up_block_types: Tuple[str, ...] = (
|
| 153 |
+
"UpBlock2D",
|
| 154 |
+
"UpBlock2D",
|
| 155 |
+
"UpBlock2D",
|
| 156 |
+
"UpBlock2D",
|
| 157 |
+
),
|
| 158 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 159 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
| 160 |
+
layers_per_block: int = 2,
|
| 161 |
+
downsample_padding: int = 1,
|
| 162 |
+
mid_block_scale_factor: float = 1,
|
| 163 |
+
act_fn: str = "silu",
|
| 164 |
+
norm_num_groups: Optional[int] = 32,
|
| 165 |
+
norm_eps: float = 1e-5,
|
| 166 |
+
cross_attention_dim: int = 1280,
|
| 167 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
| 168 |
+
encoder_hid_dim: Optional[int] = None,
|
| 169 |
+
encoder_hid_dim_type: Optional[str] = None,
|
| 170 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
| 171 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 172 |
+
use_linear_projection: bool = False,
|
| 173 |
+
class_embed_type: Optional[str] = None,
|
| 174 |
+
addition_embed_type: Optional[str] = None,
|
| 175 |
+
addition_time_embed_dim: Optional[int] = None,
|
| 176 |
+
num_class_embeds: Optional[int] = None,
|
| 177 |
+
upcast_attention: bool = False,
|
| 178 |
+
resnet_time_scale_shift: str = "default",
|
| 179 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
| 180 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
| 181 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
| 182 |
+
global_pool_conditions: bool = False,
|
| 183 |
+
addition_embed_type_num_heads: int = 64,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
|
| 187 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
| 188 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
| 189 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
| 190 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
| 191 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
| 192 |
+
# which is why we correct for the naming here.
|
| 193 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
| 194 |
+
|
| 195 |
+
# Check inputs
|
| 196 |
+
if len(down_block_types) != len(up_block_types):
|
| 197 |
+
raise ValueError(
|
| 198 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if len(block_out_channels) != len(down_block_types):
|
| 202 |
+
raise ValueError(
|
| 203 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
| 212 |
+
raise ValueError(
|
| 213 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
if isinstance(transformer_layers_per_block, int):
|
| 217 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
| 218 |
+
|
| 219 |
+
# input
|
| 220 |
+
conv_in_kernel = 3
|
| 221 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
| 222 |
+
self.conv_in_condition = nn.Conv2d(
|
| 223 |
+
in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# time
|
| 227 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 228 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 229 |
+
timestep_input_dim = block_out_channels[0]
|
| 230 |
+
self.time_embedding = TimestepEmbedding(
|
| 231 |
+
timestep_input_dim,
|
| 232 |
+
time_embed_dim,
|
| 233 |
+
act_fn=act_fn,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
| 237 |
+
encoder_hid_dim_type = "text_proj"
|
| 238 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
| 239 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
| 240 |
+
|
| 241 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
| 242 |
+
raise ValueError(
|
| 243 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
if encoder_hid_dim_type == "text_proj":
|
| 247 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
| 248 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
| 249 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 250 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 251 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
| 252 |
+
self.encoder_hid_proj = TextImageProjection(
|
| 253 |
+
text_embed_dim=encoder_hid_dim,
|
| 254 |
+
image_embed_dim=cross_attention_dim,
|
| 255 |
+
cross_attention_dim=cross_attention_dim,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
elif encoder_hid_dim_type is not None:
|
| 259 |
+
raise ValueError(
|
| 260 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
self.encoder_hid_proj = None
|
| 264 |
+
|
| 265 |
+
# class embedding
|
| 266 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 267 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 268 |
+
elif class_embed_type == "timestep":
|
| 269 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 270 |
+
elif class_embed_type == "identity":
|
| 271 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 272 |
+
elif class_embed_type == "projection":
|
| 273 |
+
if projection_class_embeddings_input_dim is None:
|
| 274 |
+
raise ValueError(
|
| 275 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
| 276 |
+
)
|
| 277 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
| 278 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
| 279 |
+
# 2. it projects from an arbitrary input dimension.
|
| 280 |
+
#
|
| 281 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
| 282 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
| 283 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
| 284 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 285 |
+
else:
|
| 286 |
+
self.class_embedding = None
|
| 287 |
+
|
| 288 |
+
if addition_embed_type == "text":
|
| 289 |
+
if encoder_hid_dim is not None:
|
| 290 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
| 291 |
+
else:
|
| 292 |
+
text_time_embedding_from_dim = cross_attention_dim
|
| 293 |
+
|
| 294 |
+
self.add_embedding = TextTimeEmbedding(
|
| 295 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
| 296 |
+
)
|
| 297 |
+
elif addition_embed_type == "text_image":
|
| 298 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 299 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 300 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
| 301 |
+
self.add_embedding = TextImageTimeEmbedding(
|
| 302 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
| 303 |
+
)
|
| 304 |
+
elif addition_embed_type == "text_time":
|
| 305 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
| 306 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 307 |
+
|
| 308 |
+
elif addition_embed_type is not None:
|
| 309 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
| 310 |
+
|
| 311 |
+
self.down_blocks = nn.ModuleList([])
|
| 312 |
+
self.brushnet_down_blocks = nn.ModuleList([])
|
| 313 |
+
|
| 314 |
+
if isinstance(only_cross_attention, bool):
|
| 315 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 316 |
+
|
| 317 |
+
if isinstance(attention_head_dim, int):
|
| 318 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 319 |
+
|
| 320 |
+
if isinstance(num_attention_heads, int):
|
| 321 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
| 322 |
+
|
| 323 |
+
# down
|
| 324 |
+
output_channel = block_out_channels[0]
|
| 325 |
+
|
| 326 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 327 |
+
brushnet_block = zero_module(brushnet_block)
|
| 328 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 329 |
+
|
| 330 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 331 |
+
input_channel = output_channel
|
| 332 |
+
output_channel = block_out_channels[i]
|
| 333 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 334 |
+
|
| 335 |
+
down_block = get_down_block(
|
| 336 |
+
down_block_type,
|
| 337 |
+
num_layers=layers_per_block,
|
| 338 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
| 339 |
+
in_channels=input_channel,
|
| 340 |
+
out_channels=output_channel,
|
| 341 |
+
temb_channels=time_embed_dim,
|
| 342 |
+
add_downsample=not is_final_block,
|
| 343 |
+
resnet_eps=norm_eps,
|
| 344 |
+
resnet_act_fn=act_fn,
|
| 345 |
+
resnet_groups=norm_num_groups,
|
| 346 |
+
cross_attention_dim=cross_attention_dim,
|
| 347 |
+
num_attention_heads=num_attention_heads[i],
|
| 348 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 349 |
+
downsample_padding=downsample_padding,
|
| 350 |
+
use_linear_projection=use_linear_projection,
|
| 351 |
+
only_cross_attention=only_cross_attention[i],
|
| 352 |
+
upcast_attention=upcast_attention,
|
| 353 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 354 |
+
)
|
| 355 |
+
self.down_blocks.append(down_block)
|
| 356 |
+
|
| 357 |
+
for _ in range(layers_per_block):
|
| 358 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 359 |
+
brushnet_block = zero_module(brushnet_block)
|
| 360 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 361 |
+
|
| 362 |
+
if not is_final_block:
|
| 363 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 364 |
+
brushnet_block = zero_module(brushnet_block)
|
| 365 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 366 |
+
|
| 367 |
+
# mid
|
| 368 |
+
mid_block_channel = block_out_channels[-1]
|
| 369 |
+
|
| 370 |
+
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
| 371 |
+
brushnet_block = zero_module(brushnet_block)
|
| 372 |
+
self.brushnet_mid_block = brushnet_block
|
| 373 |
+
|
| 374 |
+
self.mid_block = get_mid_block(
|
| 375 |
+
mid_block_type,
|
| 376 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
| 377 |
+
in_channels=mid_block_channel,
|
| 378 |
+
temb_channels=time_embed_dim,
|
| 379 |
+
resnet_eps=norm_eps,
|
| 380 |
+
resnet_act_fn=act_fn,
|
| 381 |
+
output_scale_factor=mid_block_scale_factor,
|
| 382 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 383 |
+
cross_attention_dim=cross_attention_dim,
|
| 384 |
+
num_attention_heads=num_attention_heads[-1],
|
| 385 |
+
resnet_groups=norm_num_groups,
|
| 386 |
+
use_linear_projection=use_linear_projection,
|
| 387 |
+
upcast_attention=upcast_attention,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# count how many layers upsample the images
|
| 391 |
+
self.num_upsamplers = 0
|
| 392 |
+
|
| 393 |
+
# up
|
| 394 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 395 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
| 396 |
+
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
| 397 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 398 |
+
|
| 399 |
+
output_channel = reversed_block_out_channels[0]
|
| 400 |
+
|
| 401 |
+
self.up_blocks = nn.ModuleList([])
|
| 402 |
+
self.brushnet_up_blocks = nn.ModuleList([])
|
| 403 |
+
|
| 404 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 405 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 406 |
+
|
| 407 |
+
prev_output_channel = output_channel
|
| 408 |
+
output_channel = reversed_block_out_channels[i]
|
| 409 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 410 |
+
|
| 411 |
+
# add upsample block for all BUT final layer
|
| 412 |
+
if not is_final_block:
|
| 413 |
+
add_upsample = True
|
| 414 |
+
self.num_upsamplers += 1
|
| 415 |
+
else:
|
| 416 |
+
add_upsample = False
|
| 417 |
+
|
| 418 |
+
up_block = get_up_block(
|
| 419 |
+
up_block_type,
|
| 420 |
+
num_layers=layers_per_block+1,
|
| 421 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
| 422 |
+
in_channels=input_channel,
|
| 423 |
+
out_channels=output_channel,
|
| 424 |
+
prev_output_channel=prev_output_channel,
|
| 425 |
+
temb_channels=time_embed_dim,
|
| 426 |
+
add_upsample=add_upsample,
|
| 427 |
+
resnet_eps=norm_eps,
|
| 428 |
+
resnet_act_fn=act_fn,
|
| 429 |
+
resolution_idx=i,
|
| 430 |
+
resnet_groups=norm_num_groups,
|
| 431 |
+
cross_attention_dim=cross_attention_dim,
|
| 432 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
| 433 |
+
use_linear_projection=use_linear_projection,
|
| 434 |
+
only_cross_attention=only_cross_attention[i],
|
| 435 |
+
upcast_attention=upcast_attention,
|
| 436 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 437 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 438 |
+
)
|
| 439 |
+
self.up_blocks.append(up_block)
|
| 440 |
+
prev_output_channel = output_channel
|
| 441 |
+
|
| 442 |
+
for _ in range(layers_per_block+1):
|
| 443 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 444 |
+
brushnet_block = zero_module(brushnet_block)
|
| 445 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
| 446 |
+
|
| 447 |
+
if not is_final_block:
|
| 448 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 449 |
+
brushnet_block = zero_module(brushnet_block)
|
| 450 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
@classmethod
|
| 454 |
+
def from_unet(
|
| 455 |
+
cls,
|
| 456 |
+
unet: UNet2DConditionModel,
|
| 457 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
| 458 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
| 459 |
+
load_weights_from_unet: bool = True,
|
| 460 |
+
conditioning_channels: int = 5,
|
| 461 |
+
):
|
| 462 |
+
r"""
|
| 463 |
+
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
| 464 |
+
|
| 465 |
+
Parameters:
|
| 466 |
+
unet (`UNet2DConditionModel`):
|
| 467 |
+
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
| 468 |
+
where applicable.
|
| 469 |
+
"""
|
| 470 |
+
transformer_layers_per_block = (
|
| 471 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
| 472 |
+
)
|
| 473 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
| 474 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
| 475 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
| 476 |
+
addition_time_embed_dim = (
|
| 477 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
brushnet = cls(
|
| 481 |
+
in_channels=unet.config.in_channels,
|
| 482 |
+
conditioning_channels=conditioning_channels,
|
| 483 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
| 484 |
+
freq_shift=unet.config.freq_shift,
|
| 485 |
+
down_block_types=["DownBlock2D" for block_name in unet.config.down_block_types],
|
| 486 |
+
mid_block_type='MidBlock2D',
|
| 487 |
+
up_block_types=["UpBlock2D" for block_name in unet.config.down_block_types],
|
| 488 |
+
only_cross_attention=unet.config.only_cross_attention,
|
| 489 |
+
block_out_channels=unet.config.block_out_channels,
|
| 490 |
+
layers_per_block=unet.config.layers_per_block,
|
| 491 |
+
downsample_padding=unet.config.downsample_padding,
|
| 492 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
| 493 |
+
act_fn=unet.config.act_fn,
|
| 494 |
+
norm_num_groups=unet.config.norm_num_groups,
|
| 495 |
+
norm_eps=unet.config.norm_eps,
|
| 496 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
| 497 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
| 498 |
+
encoder_hid_dim=encoder_hid_dim,
|
| 499 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
| 500 |
+
attention_head_dim=unet.config.attention_head_dim,
|
| 501 |
+
num_attention_heads=unet.config.num_attention_heads,
|
| 502 |
+
use_linear_projection=unet.config.use_linear_projection,
|
| 503 |
+
class_embed_type=unet.config.class_embed_type,
|
| 504 |
+
addition_embed_type=addition_embed_type,
|
| 505 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
| 506 |
+
num_class_embeds=unet.config.num_class_embeds,
|
| 507 |
+
upcast_attention=unet.config.upcast_attention,
|
| 508 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
| 509 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
| 510 |
+
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
| 511 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
if load_weights_from_unet:
|
| 515 |
+
conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
|
| 516 |
+
conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
|
| 517 |
+
conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
|
| 518 |
+
brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
|
| 519 |
+
brushnet.conv_in_condition.bias=unet.conv_in.bias
|
| 520 |
+
|
| 521 |
+
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
| 522 |
+
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
| 523 |
+
|
| 524 |
+
if brushnet.class_embedding:
|
| 525 |
+
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
| 526 |
+
|
| 527 |
+
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
|
| 528 |
+
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
|
| 529 |
+
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
|
| 530 |
+
|
| 531 |
+
return brushnet
|
| 532 |
+
|
| 533 |
+
@property
|
| 534 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 535 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 536 |
+
r"""
|
| 537 |
+
Returns:
|
| 538 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 539 |
+
indexed by its weight name.
|
| 540 |
+
"""
|
| 541 |
+
# set recursively
|
| 542 |
+
processors = {}
|
| 543 |
+
|
| 544 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 545 |
+
if hasattr(module, "get_processor"):
|
| 546 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
| 547 |
+
|
| 548 |
+
for sub_name, child in module.named_children():
|
| 549 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 550 |
+
|
| 551 |
+
return processors
|
| 552 |
+
|
| 553 |
+
for name, module in self.named_children():
|
| 554 |
+
fn_recursive_add_processors(name, module, processors)
|
| 555 |
+
|
| 556 |
+
return processors
|
| 557 |
+
|
| 558 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 559 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 560 |
+
r"""
|
| 561 |
+
Sets the attention processor to use to compute attention.
|
| 562 |
+
|
| 563 |
+
Parameters:
|
| 564 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 565 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 566 |
+
for **all** `Attention` layers.
|
| 567 |
+
|
| 568 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 569 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 570 |
+
|
| 571 |
+
"""
|
| 572 |
+
count = len(self.attn_processors.keys())
|
| 573 |
+
|
| 574 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 575 |
+
raise ValueError(
|
| 576 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 577 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 578 |
+
)
|
| 579 |
+
|
| 580 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 581 |
+
if hasattr(module, "set_processor"):
|
| 582 |
+
if not isinstance(processor, dict):
|
| 583 |
+
module.set_processor(processor)
|
| 584 |
+
else:
|
| 585 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 586 |
+
|
| 587 |
+
for sub_name, child in module.named_children():
|
| 588 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 589 |
+
|
| 590 |
+
for name, module in self.named_children():
|
| 591 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 592 |
+
|
| 593 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
| 594 |
+
def set_default_attn_processor(self):
|
| 595 |
+
"""
|
| 596 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 597 |
+
"""
|
| 598 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 599 |
+
processor = AttnAddedKVProcessor()
|
| 600 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 601 |
+
processor = AttnProcessor()
|
| 602 |
+
else:
|
| 603 |
+
raise ValueError(
|
| 604 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
| 605 |
+
)
|
| 606 |
+
|
| 607 |
+
self.set_attn_processor(processor)
|
| 608 |
+
|
| 609 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
| 610 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
| 611 |
+
r"""
|
| 612 |
+
Enable sliced attention computation.
|
| 613 |
+
|
| 614 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
| 615 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
| 616 |
+
|
| 617 |
+
Args:
|
| 618 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 619 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
| 620 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
| 621 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 622 |
+
must be a multiple of `slice_size`.
|
| 623 |
+
"""
|
| 624 |
+
sliceable_head_dims = []
|
| 625 |
+
|
| 626 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
| 627 |
+
if hasattr(module, "set_attention_slice"):
|
| 628 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 629 |
+
|
| 630 |
+
for child in module.children():
|
| 631 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
| 632 |
+
|
| 633 |
+
# retrieve number of attention layers
|
| 634 |
+
for module in self.children():
|
| 635 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
| 636 |
+
|
| 637 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
| 638 |
+
|
| 639 |
+
if slice_size == "auto":
|
| 640 |
+
# half the attention head size is usually a good trade-off between
|
| 641 |
+
# speed and memory
|
| 642 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 643 |
+
elif slice_size == "max":
|
| 644 |
+
# make smallest slice possible
|
| 645 |
+
slice_size = num_sliceable_layers * [1]
|
| 646 |
+
|
| 647 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
| 648 |
+
|
| 649 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 650 |
+
raise ValueError(
|
| 651 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 652 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
for i in range(len(slice_size)):
|
| 656 |
+
size = slice_size[i]
|
| 657 |
+
dim = sliceable_head_dims[i]
|
| 658 |
+
if size is not None and size > dim:
|
| 659 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 660 |
+
|
| 661 |
+
# Recursively walk through all the children.
|
| 662 |
+
# Any children which exposes the set_attention_slice method
|
| 663 |
+
# gets the message
|
| 664 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
| 665 |
+
if hasattr(module, "set_attention_slice"):
|
| 666 |
+
module.set_attention_slice(slice_size.pop())
|
| 667 |
+
|
| 668 |
+
for child in module.children():
|
| 669 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 670 |
+
|
| 671 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 672 |
+
for module in self.children():
|
| 673 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 674 |
+
|
| 675 |
+
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
| 676 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
| 677 |
+
module.gradient_checkpointing = value
|
| 678 |
+
|
| 679 |
+
def forward(
|
| 680 |
+
self,
|
| 681 |
+
sample: torch.FloatTensor,
|
| 682 |
+
encoder_hidden_states: torch.Tensor,
|
| 683 |
+
brushnet_cond: torch.FloatTensor,
|
| 684 |
+
timestep = None,
|
| 685 |
+
time_emb = None,
|
| 686 |
+
conditioning_scale: float = 1.0,
|
| 687 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 688 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 689 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 690 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 691 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 692 |
+
guess_mode: bool = False,
|
| 693 |
+
return_dict: bool = True,
|
| 694 |
+
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
| 695 |
+
"""
|
| 696 |
+
The [`BrushNetModel`] forward method.
|
| 697 |
+
|
| 698 |
+
Args:
|
| 699 |
+
sample (`torch.FloatTensor`):
|
| 700 |
+
The noisy input tensor.
|
| 701 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
| 702 |
+
The number of timesteps to denoise an input.
|
| 703 |
+
encoder_hidden_states (`torch.Tensor`):
|
| 704 |
+
The encoder hidden states.
|
| 705 |
+
brushnet_cond (`torch.FloatTensor`):
|
| 706 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
| 707 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
| 708 |
+
The scale factor for BrushNet outputs.
|
| 709 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
| 710 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
| 711 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
| 712 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
| 713 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
| 714 |
+
embeddings.
|
| 715 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
| 716 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 717 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 718 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
| 719 |
+
added_cond_kwargs (`dict`):
|
| 720 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
| 721 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
| 722 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
| 723 |
+
guess_mode (`bool`, defaults to `False`):
|
| 724 |
+
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
| 725 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
| 726 |
+
return_dict (`bool`, defaults to `True`):
|
| 727 |
+
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
| 728 |
+
|
| 729 |
+
Returns:
|
| 730 |
+
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
| 731 |
+
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
| 732 |
+
returned where the first element is the sample tensor.
|
| 733 |
+
"""
|
| 734 |
+
|
| 735 |
+
# check channel order
|
| 736 |
+
channel_order = self.config.brushnet_conditioning_channel_order
|
| 737 |
+
|
| 738 |
+
if channel_order == "rgb":
|
| 739 |
+
# in rgb order by default
|
| 740 |
+
...
|
| 741 |
+
elif channel_order == "bgr":
|
| 742 |
+
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
| 743 |
+
else:
|
| 744 |
+
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
| 745 |
+
|
| 746 |
+
# prepare attention_mask
|
| 747 |
+
if attention_mask is not None:
|
| 748 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 749 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 750 |
+
|
| 751 |
+
if timestep is None and time_emb is None:
|
| 752 |
+
raise ValueError(f"`timestep` and `emb` are both None")
|
| 753 |
+
|
| 754 |
+
#print("BN: sample.device", sample.device)
|
| 755 |
+
#print("BN: TE.device", self.time_embedding.linear_1.weight.device)
|
| 756 |
+
|
| 757 |
+
if timestep is not None:
|
| 758 |
+
# 1. time
|
| 759 |
+
timesteps = timestep
|
| 760 |
+
if not torch.is_tensor(timesteps):
|
| 761 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 762 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 763 |
+
is_mps = sample.device.type == "mps"
|
| 764 |
+
if isinstance(timestep, float):
|
| 765 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 766 |
+
else:
|
| 767 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 768 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 769 |
+
elif len(timesteps.shape) == 0:
|
| 770 |
+
timesteps = timesteps[None].to(sample.device)
|
| 771 |
+
|
| 772 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 773 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 774 |
+
|
| 775 |
+
t_emb = self.time_proj(timesteps)
|
| 776 |
+
|
| 777 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 778 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 779 |
+
# there might be better ways to encapsulate this.
|
| 780 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 781 |
+
|
| 782 |
+
#print("t_emb.device =",t_emb.device)
|
| 783 |
+
|
| 784 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 785 |
+
aug_emb = None
|
| 786 |
+
|
| 787 |
+
#print('emb.shape', emb.shape)
|
| 788 |
+
|
| 789 |
+
if self.class_embedding is not None:
|
| 790 |
+
if class_labels is None:
|
| 791 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 792 |
+
|
| 793 |
+
if self.config.class_embed_type == "timestep":
|
| 794 |
+
class_labels = self.time_proj(class_labels)
|
| 795 |
+
|
| 796 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 797 |
+
emb = emb + class_emb
|
| 798 |
+
|
| 799 |
+
if self.config.addition_embed_type is not None:
|
| 800 |
+
if self.config.addition_embed_type == "text":
|
| 801 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 802 |
+
|
| 803 |
+
elif self.config.addition_embed_type == "text_time":
|
| 804 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 805 |
+
raise ValueError(
|
| 806 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 807 |
+
)
|
| 808 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 809 |
+
if "time_ids" not in added_cond_kwargs:
|
| 810 |
+
raise ValueError(
|
| 811 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 812 |
+
)
|
| 813 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 814 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 815 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 816 |
+
|
| 817 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 818 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 819 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 820 |
+
|
| 821 |
+
#print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
|
| 822 |
+
|
| 823 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 824 |
+
else:
|
| 825 |
+
emb = time_emb
|
| 826 |
+
|
| 827 |
+
# 2. pre-process
|
| 828 |
+
|
| 829 |
+
brushnet_cond=torch.concat([sample,brushnet_cond],1)
|
| 830 |
+
sample = self.conv_in_condition(brushnet_cond)
|
| 831 |
+
|
| 832 |
+
# 3. down
|
| 833 |
+
down_block_res_samples = (sample,)
|
| 834 |
+
for downsample_block in self.down_blocks:
|
| 835 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 836 |
+
sample, res_samples = downsample_block(
|
| 837 |
+
hidden_states=sample,
|
| 838 |
+
temb=emb,
|
| 839 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 840 |
+
attention_mask=attention_mask,
|
| 841 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 842 |
+
)
|
| 843 |
+
else:
|
| 844 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 845 |
+
|
| 846 |
+
down_block_res_samples += res_samples
|
| 847 |
+
|
| 848 |
+
# 4. PaintingNet down blocks
|
| 849 |
+
brushnet_down_block_res_samples = ()
|
| 850 |
+
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
| 851 |
+
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
| 852 |
+
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
| 853 |
+
|
| 854 |
+
|
| 855 |
+
# 5. mid
|
| 856 |
+
if self.mid_block is not None:
|
| 857 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
| 858 |
+
sample = self.mid_block(
|
| 859 |
+
sample,
|
| 860 |
+
emb,
|
| 861 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 862 |
+
attention_mask=attention_mask,
|
| 863 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 864 |
+
)
|
| 865 |
+
else:
|
| 866 |
+
sample = self.mid_block(sample, emb)
|
| 867 |
+
|
| 868 |
+
# 6. BrushNet mid blocks
|
| 869 |
+
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
| 870 |
+
|
| 871 |
+
# 7. up
|
| 872 |
+
up_block_res_samples = ()
|
| 873 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 874 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 875 |
+
|
| 876 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 877 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 878 |
+
|
| 879 |
+
# if we have not reached the final block and need to forward the
|
| 880 |
+
# upsample size, we do it here
|
| 881 |
+
if not is_final_block:
|
| 882 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 883 |
+
|
| 884 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 885 |
+
sample, up_res_samples = upsample_block(
|
| 886 |
+
hidden_states=sample,
|
| 887 |
+
temb=emb,
|
| 888 |
+
res_hidden_states_tuple=res_samples,
|
| 889 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 890 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 891 |
+
upsample_size=upsample_size,
|
| 892 |
+
attention_mask=attention_mask,
|
| 893 |
+
return_res_samples=True
|
| 894 |
+
)
|
| 895 |
+
else:
|
| 896 |
+
sample, up_res_samples = upsample_block(
|
| 897 |
+
hidden_states=sample,
|
| 898 |
+
temb=emb,
|
| 899 |
+
res_hidden_states_tuple=res_samples,
|
| 900 |
+
upsample_size=upsample_size,
|
| 901 |
+
return_res_samples=True
|
| 902 |
+
)
|
| 903 |
+
|
| 904 |
+
up_block_res_samples += up_res_samples
|
| 905 |
+
|
| 906 |
+
# 8. BrushNet up blocks
|
| 907 |
+
brushnet_up_block_res_samples = ()
|
| 908 |
+
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
| 909 |
+
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
| 910 |
+
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
| 911 |
+
|
| 912 |
+
# 6. scaling
|
| 913 |
+
if guess_mode and not self.config.global_pool_conditions:
|
| 914 |
+
scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
|
| 915 |
+
scales = scales * conditioning_scale
|
| 916 |
+
|
| 917 |
+
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
|
| 918 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
| 919 |
+
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
|
| 920 |
+
else:
|
| 921 |
+
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
|
| 922 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
| 923 |
+
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
if self.config.global_pool_conditions:
|
| 927 |
+
brushnet_down_block_res_samples = [
|
| 928 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
| 929 |
+
]
|
| 930 |
+
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
| 931 |
+
brushnet_up_block_res_samples = [
|
| 932 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
| 933 |
+
]
|
| 934 |
+
|
| 935 |
+
if not return_dict:
|
| 936 |
+
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
| 937 |
+
|
| 938 |
+
return BrushNetOutput(
|
| 939 |
+
down_block_res_samples=brushnet_down_block_res_samples,
|
| 940 |
+
mid_block_res_sample=brushnet_mid_block_res_sample,
|
| 941 |
+
up_block_res_samples=brushnet_up_block_res_samples
|
| 942 |
+
)
|
| 943 |
+
|
| 944 |
+
|
| 945 |
+
def zero_module(module):
|
| 946 |
+
for p in module.parameters():
|
| 947 |
+
nn.init.zeros_(p)
|
| 948 |
+
return module
|
ComfyUI-BrushNet/brushnet/brushnet_ca.py
ADDED
|
@@ -0,0 +1,956 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.utils import BaseOutput, logging
|
| 9 |
+
from diffusers.models.attention_processor import (
|
| 10 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 11 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 12 |
+
AttentionProcessor,
|
| 13 |
+
AttnAddedKVProcessor,
|
| 14 |
+
AttnProcessor,
|
| 15 |
+
)
|
| 16 |
+
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
|
| 17 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 18 |
+
|
| 19 |
+
from .unet_2d_blocks import (
|
| 20 |
+
CrossAttnDownBlock2D,
|
| 21 |
+
DownBlock2D,
|
| 22 |
+
UNetMidBlock2D,
|
| 23 |
+
UNetMidBlock2DCrossAttn,
|
| 24 |
+
get_down_block,
|
| 25 |
+
get_mid_block,
|
| 26 |
+
get_up_block,
|
| 27 |
+
MidBlock2D
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
from .unet_2d_condition import UNet2DConditionModel
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclass
|
| 37 |
+
class BrushNetOutput(BaseOutput):
|
| 38 |
+
"""
|
| 39 |
+
The output of [`BrushNetModel`].
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
up_block_res_samples (`tuple[torch.Tensor]`):
|
| 43 |
+
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
| 44 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
| 45 |
+
used to condition the original UNet's upsampling activations.
|
| 46 |
+
down_block_res_samples (`tuple[torch.Tensor]`):
|
| 47 |
+
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
| 48 |
+
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
| 49 |
+
used to condition the original UNet's downsampling activations.
|
| 50 |
+
mid_down_block_re_sample (`torch.Tensor`):
|
| 51 |
+
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
| 52 |
+
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
| 53 |
+
Output can be used to condition the original UNet's middle block activation.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
up_block_res_samples: Tuple[torch.Tensor]
|
| 57 |
+
down_block_res_samples: Tuple[torch.Tensor]
|
| 58 |
+
mid_block_res_sample: torch.Tensor
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class BrushNetModel(ModelMixin, ConfigMixin):
|
| 62 |
+
"""
|
| 63 |
+
A BrushNet model.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
in_channels (`int`, defaults to 4):
|
| 67 |
+
The number of channels in the input sample.
|
| 68 |
+
flip_sin_to_cos (`bool`, defaults to `True`):
|
| 69 |
+
Whether to flip the sin to cos in the time embedding.
|
| 70 |
+
freq_shift (`int`, defaults to 0):
|
| 71 |
+
The frequency shift to apply to the time embedding.
|
| 72 |
+
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 73 |
+
The tuple of downsample blocks to use.
|
| 74 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
| 75 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
| 76 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
| 77 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
| 78 |
+
The tuple of upsample blocks to use.
|
| 79 |
+
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
| 80 |
+
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
| 81 |
+
The tuple of output channels for each block.
|
| 82 |
+
layers_per_block (`int`, defaults to 2):
|
| 83 |
+
The number of layers per block.
|
| 84 |
+
downsample_padding (`int`, defaults to 1):
|
| 85 |
+
The padding to use for the downsampling convolution.
|
| 86 |
+
mid_block_scale_factor (`float`, defaults to 1):
|
| 87 |
+
The scale factor to use for the mid block.
|
| 88 |
+
act_fn (`str`, defaults to "silu"):
|
| 89 |
+
The activation function to use.
|
| 90 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 91 |
+
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
| 92 |
+
in post-processing.
|
| 93 |
+
norm_eps (`float`, defaults to 1e-5):
|
| 94 |
+
The epsilon to use for the normalization.
|
| 95 |
+
cross_attention_dim (`int`, defaults to 1280):
|
| 96 |
+
The dimension of the cross attention features.
|
| 97 |
+
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
| 98 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 99 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 100 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 101 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
| 102 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
| 103 |
+
dimension to `cross_attention_dim`.
|
| 104 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
| 105 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
| 106 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
| 107 |
+
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
| 108 |
+
The dimension of the attention heads.
|
| 109 |
+
use_linear_projection (`bool`, defaults to `False`):
|
| 110 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 111 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
| 112 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 113 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
| 114 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
| 115 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
| 116 |
+
num_class_embeds (`int`, *optional*, defaults to 0):
|
| 117 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 118 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 119 |
+
upcast_attention (`bool`, defaults to `False`):
|
| 120 |
+
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
| 121 |
+
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
| 122 |
+
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
| 123 |
+
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
| 124 |
+
`class_embed_type="projection"`.
|
| 125 |
+
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
| 126 |
+
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
| 127 |
+
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
| 128 |
+
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
| 129 |
+
global_pool_conditions (`bool`, defaults to `False`):
|
| 130 |
+
TODO(Patrick) - unused parameter.
|
| 131 |
+
addition_embed_type_num_heads (`int`, defaults to 64):
|
| 132 |
+
The number of heads to use for the `TextTimeEmbedding` layer.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
_supports_gradient_checkpointing = True
|
| 136 |
+
|
| 137 |
+
@register_to_config
|
| 138 |
+
def __init__(
|
| 139 |
+
self,
|
| 140 |
+
in_channels: int = 4,
|
| 141 |
+
conditioning_channels: int = 5,
|
| 142 |
+
flip_sin_to_cos: bool = True,
|
| 143 |
+
freq_shift: int = 0,
|
| 144 |
+
down_block_types: Tuple[str, ...] = (
|
| 145 |
+
"CrossAttnDownBlock2D",
|
| 146 |
+
"CrossAttnDownBlock2D",
|
| 147 |
+
"CrossAttnDownBlock2D",
|
| 148 |
+
"DownBlock2D",
|
| 149 |
+
),
|
| 150 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
| 151 |
+
up_block_types: Tuple[str, ...] = (
|
| 152 |
+
"UpBlock2D",
|
| 153 |
+
"CrossAttnUpBlock2D",
|
| 154 |
+
"CrossAttnUpBlock2D",
|
| 155 |
+
"CrossAttnUpBlock2D",
|
| 156 |
+
),
|
| 157 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 158 |
+
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
| 159 |
+
layers_per_block: int = 2,
|
| 160 |
+
downsample_padding: int = 1,
|
| 161 |
+
mid_block_scale_factor: float = 1,
|
| 162 |
+
act_fn: str = "silu",
|
| 163 |
+
norm_num_groups: Optional[int] = 32,
|
| 164 |
+
norm_eps: float = 1e-5,
|
| 165 |
+
cross_attention_dim: int = 1280,
|
| 166 |
+
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
| 167 |
+
encoder_hid_dim: Optional[int] = None,
|
| 168 |
+
encoder_hid_dim_type: Optional[str] = None,
|
| 169 |
+
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
| 170 |
+
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
| 171 |
+
use_linear_projection: bool = False,
|
| 172 |
+
class_embed_type: Optional[str] = None,
|
| 173 |
+
addition_embed_type: Optional[str] = None,
|
| 174 |
+
addition_time_embed_dim: Optional[int] = None,
|
| 175 |
+
num_class_embeds: Optional[int] = None,
|
| 176 |
+
upcast_attention: bool = False,
|
| 177 |
+
resnet_time_scale_shift: str = "default",
|
| 178 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
| 179 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
| 180 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
| 181 |
+
global_pool_conditions: bool = False,
|
| 182 |
+
addition_embed_type_num_heads: int = 64,
|
| 183 |
+
):
|
| 184 |
+
super().__init__()
|
| 185 |
+
|
| 186 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
| 187 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
| 188 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
| 189 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
| 190 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
| 191 |
+
# which is why we correct for the naming here.
|
| 192 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
| 193 |
+
|
| 194 |
+
# Check inputs
|
| 195 |
+
if len(down_block_types) != len(up_block_types):
|
| 196 |
+
raise ValueError(
|
| 197 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if len(block_out_channels) != len(down_block_types):
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
| 206 |
+
raise ValueError(
|
| 207 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
| 211 |
+
raise ValueError(
|
| 212 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
if isinstance(transformer_layers_per_block, int):
|
| 216 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
| 217 |
+
|
| 218 |
+
# input
|
| 219 |
+
conv_in_kernel = 3
|
| 220 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
| 221 |
+
self.conv_in_condition = nn.Conv2d(
|
| 222 |
+
in_channels + conditioning_channels,
|
| 223 |
+
block_out_channels[0],
|
| 224 |
+
kernel_size=conv_in_kernel,
|
| 225 |
+
padding=conv_in_padding,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
# time
|
| 229 |
+
time_embed_dim = block_out_channels[0] * 4
|
| 230 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 231 |
+
timestep_input_dim = block_out_channels[0]
|
| 232 |
+
self.time_embedding = TimestepEmbedding(
|
| 233 |
+
timestep_input_dim,
|
| 234 |
+
time_embed_dim,
|
| 235 |
+
act_fn=act_fn,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
| 239 |
+
encoder_hid_dim_type = "text_proj"
|
| 240 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
| 241 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
| 242 |
+
|
| 243 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
| 244 |
+
raise ValueError(
|
| 245 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
if encoder_hid_dim_type == "text_proj":
|
| 249 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
| 250 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
| 251 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 252 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 253 |
+
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
| 254 |
+
self.encoder_hid_proj = TextImageProjection(
|
| 255 |
+
text_embed_dim=encoder_hid_dim,
|
| 256 |
+
image_embed_dim=cross_attention_dim,
|
| 257 |
+
cross_attention_dim=cross_attention_dim,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
elif encoder_hid_dim_type is not None:
|
| 261 |
+
raise ValueError(
|
| 262 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
| 263 |
+
)
|
| 264 |
+
else:
|
| 265 |
+
self.encoder_hid_proj = None
|
| 266 |
+
|
| 267 |
+
# class embedding
|
| 268 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 269 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 270 |
+
elif class_embed_type == "timestep":
|
| 271 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
| 272 |
+
elif class_embed_type == "identity":
|
| 273 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 274 |
+
elif class_embed_type == "projection":
|
| 275 |
+
if projection_class_embeddings_input_dim is None:
|
| 276 |
+
raise ValueError(
|
| 277 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
| 278 |
+
)
|
| 279 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
| 280 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
| 281 |
+
# 2. it projects from an arbitrary input dimension.
|
| 282 |
+
#
|
| 283 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
| 284 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
| 285 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
| 286 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 287 |
+
else:
|
| 288 |
+
self.class_embedding = None
|
| 289 |
+
|
| 290 |
+
if addition_embed_type == "text":
|
| 291 |
+
if encoder_hid_dim is not None:
|
| 292 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
| 293 |
+
else:
|
| 294 |
+
text_time_embedding_from_dim = cross_attention_dim
|
| 295 |
+
|
| 296 |
+
self.add_embedding = TextTimeEmbedding(
|
| 297 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
| 298 |
+
)
|
| 299 |
+
elif addition_embed_type == "text_image":
|
| 300 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 301 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 302 |
+
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
| 303 |
+
self.add_embedding = TextImageTimeEmbedding(
|
| 304 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
| 305 |
+
)
|
| 306 |
+
elif addition_embed_type == "text_time":
|
| 307 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
| 308 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 309 |
+
|
| 310 |
+
elif addition_embed_type is not None:
|
| 311 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
| 312 |
+
|
| 313 |
+
self.down_blocks = nn.ModuleList([])
|
| 314 |
+
self.brushnet_down_blocks = nn.ModuleList([])
|
| 315 |
+
|
| 316 |
+
if isinstance(only_cross_attention, bool):
|
| 317 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 318 |
+
|
| 319 |
+
if isinstance(attention_head_dim, int):
|
| 320 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 321 |
+
|
| 322 |
+
if isinstance(num_attention_heads, int):
|
| 323 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
| 324 |
+
|
| 325 |
+
# down
|
| 326 |
+
output_channel = block_out_channels[0]
|
| 327 |
+
|
| 328 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 329 |
+
brushnet_block = zero_module(brushnet_block)
|
| 330 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 331 |
+
|
| 332 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 333 |
+
input_channel = output_channel
|
| 334 |
+
output_channel = block_out_channels[i]
|
| 335 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 336 |
+
|
| 337 |
+
down_block = get_down_block(
|
| 338 |
+
down_block_type,
|
| 339 |
+
num_layers=layers_per_block,
|
| 340 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
| 341 |
+
in_channels=input_channel,
|
| 342 |
+
out_channels=output_channel,
|
| 343 |
+
temb_channels=time_embed_dim,
|
| 344 |
+
add_downsample=not is_final_block,
|
| 345 |
+
resnet_eps=norm_eps,
|
| 346 |
+
resnet_act_fn=act_fn,
|
| 347 |
+
resnet_groups=norm_num_groups,
|
| 348 |
+
cross_attention_dim=cross_attention_dim,
|
| 349 |
+
num_attention_heads=num_attention_heads[i],
|
| 350 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 351 |
+
downsample_padding=downsample_padding,
|
| 352 |
+
use_linear_projection=use_linear_projection,
|
| 353 |
+
only_cross_attention=only_cross_attention[i],
|
| 354 |
+
upcast_attention=upcast_attention,
|
| 355 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 356 |
+
)
|
| 357 |
+
self.down_blocks.append(down_block)
|
| 358 |
+
|
| 359 |
+
for _ in range(layers_per_block):
|
| 360 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 361 |
+
brushnet_block = zero_module(brushnet_block)
|
| 362 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 363 |
+
|
| 364 |
+
if not is_final_block:
|
| 365 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 366 |
+
brushnet_block = zero_module(brushnet_block)
|
| 367 |
+
self.brushnet_down_blocks.append(brushnet_block)
|
| 368 |
+
|
| 369 |
+
# mid
|
| 370 |
+
mid_block_channel = block_out_channels[-1]
|
| 371 |
+
|
| 372 |
+
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
| 373 |
+
brushnet_block = zero_module(brushnet_block)
|
| 374 |
+
self.brushnet_mid_block = brushnet_block
|
| 375 |
+
|
| 376 |
+
self.mid_block = get_mid_block(
|
| 377 |
+
mid_block_type,
|
| 378 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
| 379 |
+
in_channels=mid_block_channel,
|
| 380 |
+
temb_channels=time_embed_dim,
|
| 381 |
+
resnet_eps=norm_eps,
|
| 382 |
+
resnet_act_fn=act_fn,
|
| 383 |
+
output_scale_factor=mid_block_scale_factor,
|
| 384 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 385 |
+
cross_attention_dim=cross_attention_dim,
|
| 386 |
+
num_attention_heads=num_attention_heads[-1],
|
| 387 |
+
resnet_groups=norm_num_groups,
|
| 388 |
+
use_linear_projection=use_linear_projection,
|
| 389 |
+
upcast_attention=upcast_attention,
|
| 390 |
+
)
|
| 391 |
+
|
| 392 |
+
# count how many layers upsample the images
|
| 393 |
+
self.num_upsamplers = 0
|
| 394 |
+
|
| 395 |
+
# up
|
| 396 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 397 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
| 398 |
+
reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
|
| 399 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 400 |
+
|
| 401 |
+
output_channel = reversed_block_out_channels[0]
|
| 402 |
+
|
| 403 |
+
self.up_blocks = nn.ModuleList([])
|
| 404 |
+
self.brushnet_up_blocks = nn.ModuleList([])
|
| 405 |
+
|
| 406 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 407 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 408 |
+
|
| 409 |
+
prev_output_channel = output_channel
|
| 410 |
+
output_channel = reversed_block_out_channels[i]
|
| 411 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 412 |
+
|
| 413 |
+
# add upsample block for all BUT final layer
|
| 414 |
+
if not is_final_block:
|
| 415 |
+
add_upsample = True
|
| 416 |
+
self.num_upsamplers += 1
|
| 417 |
+
else:
|
| 418 |
+
add_upsample = False
|
| 419 |
+
|
| 420 |
+
up_block = get_up_block(
|
| 421 |
+
up_block_type,
|
| 422 |
+
num_layers=layers_per_block + 1,
|
| 423 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
| 424 |
+
in_channels=input_channel,
|
| 425 |
+
out_channels=output_channel,
|
| 426 |
+
prev_output_channel=prev_output_channel,
|
| 427 |
+
temb_channels=time_embed_dim,
|
| 428 |
+
add_upsample=add_upsample,
|
| 429 |
+
resnet_eps=norm_eps,
|
| 430 |
+
resnet_act_fn=act_fn,
|
| 431 |
+
resolution_idx=i,
|
| 432 |
+
resnet_groups=norm_num_groups,
|
| 433 |
+
cross_attention_dim=cross_attention_dim,
|
| 434 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
| 435 |
+
use_linear_projection=use_linear_projection,
|
| 436 |
+
only_cross_attention=only_cross_attention[i],
|
| 437 |
+
upcast_attention=upcast_attention,
|
| 438 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 439 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 440 |
+
)
|
| 441 |
+
self.up_blocks.append(up_block)
|
| 442 |
+
prev_output_channel = output_channel
|
| 443 |
+
|
| 444 |
+
for _ in range(layers_per_block + 1):
|
| 445 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 446 |
+
brushnet_block = zero_module(brushnet_block)
|
| 447 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
| 448 |
+
|
| 449 |
+
if not is_final_block:
|
| 450 |
+
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
| 451 |
+
brushnet_block = zero_module(brushnet_block)
|
| 452 |
+
self.brushnet_up_blocks.append(brushnet_block)
|
| 453 |
+
|
| 454 |
+
@classmethod
|
| 455 |
+
def from_unet(
|
| 456 |
+
cls,
|
| 457 |
+
unet: UNet2DConditionModel,
|
| 458 |
+
brushnet_conditioning_channel_order: str = "rgb",
|
| 459 |
+
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
| 460 |
+
load_weights_from_unet: bool = True,
|
| 461 |
+
conditioning_channels: int = 5,
|
| 462 |
+
):
|
| 463 |
+
r"""
|
| 464 |
+
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
| 465 |
+
|
| 466 |
+
Parameters:
|
| 467 |
+
unet (`UNet2DConditionModel`):
|
| 468 |
+
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
| 469 |
+
where applicable.
|
| 470 |
+
"""
|
| 471 |
+
transformer_layers_per_block = (
|
| 472 |
+
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
| 473 |
+
)
|
| 474 |
+
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
| 475 |
+
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
| 476 |
+
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
| 477 |
+
addition_time_embed_dim = (
|
| 478 |
+
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
brushnet = cls(
|
| 482 |
+
in_channels=unet.config.in_channels,
|
| 483 |
+
conditioning_channels=conditioning_channels,
|
| 484 |
+
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
| 485 |
+
freq_shift=unet.config.freq_shift,
|
| 486 |
+
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
| 487 |
+
down_block_types=[
|
| 488 |
+
"CrossAttnDownBlock2D",
|
| 489 |
+
"CrossAttnDownBlock2D",
|
| 490 |
+
"CrossAttnDownBlock2D",
|
| 491 |
+
"DownBlock2D",
|
| 492 |
+
],
|
| 493 |
+
# mid_block_type='MidBlock2D',
|
| 494 |
+
mid_block_type="UNetMidBlock2DCrossAttn",
|
| 495 |
+
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
| 496 |
+
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
| 497 |
+
only_cross_attention=unet.config.only_cross_attention,
|
| 498 |
+
block_out_channels=unet.config.block_out_channels,
|
| 499 |
+
layers_per_block=unet.config.layers_per_block,
|
| 500 |
+
downsample_padding=unet.config.downsample_padding,
|
| 501 |
+
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
| 502 |
+
act_fn=unet.config.act_fn,
|
| 503 |
+
norm_num_groups=unet.config.norm_num_groups,
|
| 504 |
+
norm_eps=unet.config.norm_eps,
|
| 505 |
+
cross_attention_dim=unet.config.cross_attention_dim,
|
| 506 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
| 507 |
+
encoder_hid_dim=encoder_hid_dim,
|
| 508 |
+
encoder_hid_dim_type=encoder_hid_dim_type,
|
| 509 |
+
attention_head_dim=unet.config.attention_head_dim,
|
| 510 |
+
num_attention_heads=unet.config.num_attention_heads,
|
| 511 |
+
use_linear_projection=unet.config.use_linear_projection,
|
| 512 |
+
class_embed_type=unet.config.class_embed_type,
|
| 513 |
+
addition_embed_type=addition_embed_type,
|
| 514 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
| 515 |
+
num_class_embeds=unet.config.num_class_embeds,
|
| 516 |
+
upcast_attention=unet.config.upcast_attention,
|
| 517 |
+
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
| 518 |
+
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
| 519 |
+
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
| 520 |
+
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if load_weights_from_unet:
|
| 524 |
+
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
| 525 |
+
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
| 526 |
+
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
| 527 |
+
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
| 528 |
+
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
| 529 |
+
|
| 530 |
+
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
| 531 |
+
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
| 532 |
+
|
| 533 |
+
if brushnet.class_embedding:
|
| 534 |
+
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
| 535 |
+
|
| 536 |
+
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
| 537 |
+
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
| 538 |
+
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
| 539 |
+
|
| 540 |
+
return brushnet.to(unet.dtype)
|
| 541 |
+
|
| 542 |
+
@property
|
| 543 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
| 544 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 545 |
+
r"""
|
| 546 |
+
Returns:
|
| 547 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 548 |
+
indexed by its weight name.
|
| 549 |
+
"""
|
| 550 |
+
# set recursively
|
| 551 |
+
processors = {}
|
| 552 |
+
|
| 553 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 554 |
+
if hasattr(module, "get_processor"):
|
| 555 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
| 556 |
+
|
| 557 |
+
for sub_name, child in module.named_children():
|
| 558 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 559 |
+
|
| 560 |
+
return processors
|
| 561 |
+
|
| 562 |
+
for name, module in self.named_children():
|
| 563 |
+
fn_recursive_add_processors(name, module, processors)
|
| 564 |
+
|
| 565 |
+
return processors
|
| 566 |
+
|
| 567 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
| 568 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 569 |
+
r"""
|
| 570 |
+
Sets the attention processor to use to compute attention.
|
| 571 |
+
|
| 572 |
+
Parameters:
|
| 573 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 574 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 575 |
+
for **all** `Attention` layers.
|
| 576 |
+
|
| 577 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 578 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 579 |
+
|
| 580 |
+
"""
|
| 581 |
+
count = len(self.attn_processors.keys())
|
| 582 |
+
|
| 583 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 584 |
+
raise ValueError(
|
| 585 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 586 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 590 |
+
if hasattr(module, "set_processor"):
|
| 591 |
+
if not isinstance(processor, dict):
|
| 592 |
+
module.set_processor(processor)
|
| 593 |
+
else:
|
| 594 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 595 |
+
|
| 596 |
+
for sub_name, child in module.named_children():
|
| 597 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 598 |
+
|
| 599 |
+
for name, module in self.named_children():
|
| 600 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 601 |
+
|
| 602 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
| 603 |
+
def set_default_attn_processor(self):
|
| 604 |
+
"""
|
| 605 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 606 |
+
"""
|
| 607 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 608 |
+
processor = AttnAddedKVProcessor()
|
| 609 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 610 |
+
processor = AttnProcessor()
|
| 611 |
+
else:
|
| 612 |
+
raise ValueError(
|
| 613 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
| 614 |
+
)
|
| 615 |
+
|
| 616 |
+
self.set_attn_processor(processor)
|
| 617 |
+
|
| 618 |
+
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
| 619 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
| 620 |
+
r"""
|
| 621 |
+
Enable sliced attention computation.
|
| 622 |
+
|
| 623 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
| 624 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 628 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
| 629 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
| 630 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 631 |
+
must be a multiple of `slice_size`.
|
| 632 |
+
"""
|
| 633 |
+
sliceable_head_dims = []
|
| 634 |
+
|
| 635 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
| 636 |
+
if hasattr(module, "set_attention_slice"):
|
| 637 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 638 |
+
|
| 639 |
+
for child in module.children():
|
| 640 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
| 641 |
+
|
| 642 |
+
# retrieve number of attention layers
|
| 643 |
+
for module in self.children():
|
| 644 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
| 645 |
+
|
| 646 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
| 647 |
+
|
| 648 |
+
if slice_size == "auto":
|
| 649 |
+
# half the attention head size is usually a good trade-off between
|
| 650 |
+
# speed and memory
|
| 651 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 652 |
+
elif slice_size == "max":
|
| 653 |
+
# make smallest slice possible
|
| 654 |
+
slice_size = num_sliceable_layers * [1]
|
| 655 |
+
|
| 656 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
| 657 |
+
|
| 658 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 659 |
+
raise ValueError(
|
| 660 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 661 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
for i in range(len(slice_size)):
|
| 665 |
+
size = slice_size[i]
|
| 666 |
+
dim = sliceable_head_dims[i]
|
| 667 |
+
if size is not None and size > dim:
|
| 668 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 669 |
+
|
| 670 |
+
# Recursively walk through all the children.
|
| 671 |
+
# Any children which exposes the set_attention_slice method
|
| 672 |
+
# gets the message
|
| 673 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
| 674 |
+
if hasattr(module, "set_attention_slice"):
|
| 675 |
+
module.set_attention_slice(slice_size.pop())
|
| 676 |
+
|
| 677 |
+
for child in module.children():
|
| 678 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 679 |
+
|
| 680 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 681 |
+
for module in self.children():
|
| 682 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 683 |
+
|
| 684 |
+
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
| 685 |
+
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
| 686 |
+
module.gradient_checkpointing = value
|
| 687 |
+
|
| 688 |
+
def forward(
|
| 689 |
+
self,
|
| 690 |
+
sample: torch.FloatTensor,
|
| 691 |
+
timestep: Union[torch.Tensor, float, int],
|
| 692 |
+
encoder_hidden_states: torch.Tensor,
|
| 693 |
+
brushnet_cond: torch.FloatTensor,
|
| 694 |
+
conditioning_scale: float = 1.0,
|
| 695 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 696 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 697 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 698 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 699 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 700 |
+
guess_mode: bool = False,
|
| 701 |
+
return_dict: bool = True,
|
| 702 |
+
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
| 703 |
+
"""
|
| 704 |
+
The [`BrushNetModel`] forward method.
|
| 705 |
+
|
| 706 |
+
Args:
|
| 707 |
+
sample (`torch.FloatTensor`):
|
| 708 |
+
The noisy input tensor.
|
| 709 |
+
timestep (`Union[torch.Tensor, float, int]`):
|
| 710 |
+
The number of timesteps to denoise an input.
|
| 711 |
+
encoder_hidden_states (`torch.Tensor`):
|
| 712 |
+
The encoder hidden states.
|
| 713 |
+
brushnet_cond (`torch.FloatTensor`):
|
| 714 |
+
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
| 715 |
+
conditioning_scale (`float`, defaults to `1.0`):
|
| 716 |
+
The scale factor for BrushNet outputs.
|
| 717 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
| 718 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
| 719 |
+
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
| 720 |
+
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
| 721 |
+
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
| 722 |
+
embeddings.
|
| 723 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
| 724 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 725 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 726 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
| 727 |
+
added_cond_kwargs (`dict`):
|
| 728 |
+
Additional conditions for the Stable Diffusion XL UNet.
|
| 729 |
+
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
| 730 |
+
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
| 731 |
+
guess_mode (`bool`, defaults to `False`):
|
| 732 |
+
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
| 733 |
+
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
| 734 |
+
return_dict (`bool`, defaults to `True`):
|
| 735 |
+
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
| 736 |
+
|
| 737 |
+
Returns:
|
| 738 |
+
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
| 739 |
+
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
| 740 |
+
returned where the first element is the sample tensor.
|
| 741 |
+
"""
|
| 742 |
+
# check channel order
|
| 743 |
+
channel_order = self.config.brushnet_conditioning_channel_order
|
| 744 |
+
|
| 745 |
+
if channel_order == "rgb":
|
| 746 |
+
# in rgb order by default
|
| 747 |
+
...
|
| 748 |
+
elif channel_order == "bgr":
|
| 749 |
+
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
| 750 |
+
else:
|
| 751 |
+
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
| 752 |
+
|
| 753 |
+
# prepare attention_mask
|
| 754 |
+
if attention_mask is not None:
|
| 755 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 756 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 757 |
+
|
| 758 |
+
# 1. time
|
| 759 |
+
timesteps = timestep
|
| 760 |
+
if not torch.is_tensor(timesteps):
|
| 761 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 762 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 763 |
+
is_mps = sample.device.type == "mps"
|
| 764 |
+
if isinstance(timestep, float):
|
| 765 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 766 |
+
else:
|
| 767 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 768 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 769 |
+
elif len(timesteps.shape) == 0:
|
| 770 |
+
timesteps = timesteps[None].to(sample.device)
|
| 771 |
+
|
| 772 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 773 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 774 |
+
|
| 775 |
+
t_emb = self.time_proj(timesteps)
|
| 776 |
+
|
| 777 |
+
# timesteps does not contain any weights and will always return f32 tensors
|
| 778 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 779 |
+
# there might be better ways to encapsulate this.
|
| 780 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 781 |
+
|
| 782 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 783 |
+
aug_emb = None
|
| 784 |
+
|
| 785 |
+
if self.class_embedding is not None:
|
| 786 |
+
if class_labels is None:
|
| 787 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 788 |
+
|
| 789 |
+
if self.config.class_embed_type == "timestep":
|
| 790 |
+
class_labels = self.time_proj(class_labels)
|
| 791 |
+
|
| 792 |
+
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
| 793 |
+
emb = emb + class_emb
|
| 794 |
+
|
| 795 |
+
if self.config.addition_embed_type is not None:
|
| 796 |
+
if self.config.addition_embed_type == "text":
|
| 797 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 798 |
+
|
| 799 |
+
elif self.config.addition_embed_type == "text_time":
|
| 800 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 801 |
+
raise ValueError(
|
| 802 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 803 |
+
)
|
| 804 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 805 |
+
if "time_ids" not in added_cond_kwargs:
|
| 806 |
+
raise ValueError(
|
| 807 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 808 |
+
)
|
| 809 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 810 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 811 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 812 |
+
|
| 813 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 814 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 815 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 816 |
+
|
| 817 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 818 |
+
|
| 819 |
+
# 2. pre-process
|
| 820 |
+
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
| 821 |
+
sample = self.conv_in_condition(brushnet_cond)
|
| 822 |
+
|
| 823 |
+
# 3. down
|
| 824 |
+
down_block_res_samples = (sample,)
|
| 825 |
+
for downsample_block in self.down_blocks:
|
| 826 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 827 |
+
sample, res_samples = downsample_block(
|
| 828 |
+
hidden_states=sample,
|
| 829 |
+
temb=emb,
|
| 830 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 831 |
+
attention_mask=attention_mask,
|
| 832 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 833 |
+
)
|
| 834 |
+
else:
|
| 835 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
| 836 |
+
|
| 837 |
+
down_block_res_samples += res_samples
|
| 838 |
+
|
| 839 |
+
# 4. PaintingNet down blocks
|
| 840 |
+
brushnet_down_block_res_samples = ()
|
| 841 |
+
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
| 842 |
+
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
| 843 |
+
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
| 844 |
+
|
| 845 |
+
# 5. mid
|
| 846 |
+
if self.mid_block is not None:
|
| 847 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
| 848 |
+
sample = self.mid_block(
|
| 849 |
+
sample,
|
| 850 |
+
emb,
|
| 851 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 852 |
+
attention_mask=attention_mask,
|
| 853 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 854 |
+
)
|
| 855 |
+
else:
|
| 856 |
+
sample = self.mid_block(sample, emb)
|
| 857 |
+
|
| 858 |
+
# 6. BrushNet mid blocks
|
| 859 |
+
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
| 860 |
+
|
| 861 |
+
# 7. up
|
| 862 |
+
up_block_res_samples = ()
|
| 863 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 864 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 865 |
+
|
| 866 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 867 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 868 |
+
|
| 869 |
+
# if we have not reached the final block and need to forward the
|
| 870 |
+
# upsample size, we do it here
|
| 871 |
+
if not is_final_block:
|
| 872 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 873 |
+
|
| 874 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 875 |
+
sample, up_res_samples = upsample_block(
|
| 876 |
+
hidden_states=sample,
|
| 877 |
+
temb=emb,
|
| 878 |
+
res_hidden_states_tuple=res_samples,
|
| 879 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 880 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 881 |
+
upsample_size=upsample_size,
|
| 882 |
+
attention_mask=attention_mask,
|
| 883 |
+
return_res_samples=True,
|
| 884 |
+
)
|
| 885 |
+
else:
|
| 886 |
+
sample, up_res_samples = upsample_block(
|
| 887 |
+
hidden_states=sample,
|
| 888 |
+
temb=emb,
|
| 889 |
+
res_hidden_states_tuple=res_samples,
|
| 890 |
+
upsample_size=upsample_size,
|
| 891 |
+
return_res_samples=True,
|
| 892 |
+
)
|
| 893 |
+
|
| 894 |
+
up_block_res_samples += up_res_samples
|
| 895 |
+
|
| 896 |
+
# 8. BrushNet up blocks
|
| 897 |
+
brushnet_up_block_res_samples = ()
|
| 898 |
+
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
| 899 |
+
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
| 900 |
+
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
| 901 |
+
|
| 902 |
+
# 6. scaling
|
| 903 |
+
if guess_mode and not self.config.global_pool_conditions:
|
| 904 |
+
scales = torch.logspace(
|
| 905 |
+
-1,
|
| 906 |
+
0,
|
| 907 |
+
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
| 908 |
+
device=sample.device,
|
| 909 |
+
) # 0.1 to 1.0
|
| 910 |
+
scales = scales * conditioning_scale
|
| 911 |
+
|
| 912 |
+
brushnet_down_block_res_samples = [
|
| 913 |
+
sample * scale
|
| 914 |
+
for sample, scale in zip(
|
| 915 |
+
brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
|
| 916 |
+
)
|
| 917 |
+
]
|
| 918 |
+
brushnet_mid_block_res_sample = (
|
| 919 |
+
brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
| 920 |
+
)
|
| 921 |
+
brushnet_up_block_res_samples = [
|
| 922 |
+
sample * scale
|
| 923 |
+
for sample, scale in zip(
|
| 924 |
+
brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
|
| 925 |
+
)
|
| 926 |
+
]
|
| 927 |
+
else:
|
| 928 |
+
brushnet_down_block_res_samples = [
|
| 929 |
+
sample * conditioning_scale for sample in brushnet_down_block_res_samples
|
| 930 |
+
]
|
| 931 |
+
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
| 932 |
+
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
| 933 |
+
|
| 934 |
+
if self.config.global_pool_conditions:
|
| 935 |
+
brushnet_down_block_res_samples = [
|
| 936 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
| 937 |
+
]
|
| 938 |
+
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
| 939 |
+
brushnet_up_block_res_samples = [
|
| 940 |
+
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
| 941 |
+
]
|
| 942 |
+
|
| 943 |
+
if not return_dict:
|
| 944 |
+
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
| 945 |
+
|
| 946 |
+
return BrushNetOutput(
|
| 947 |
+
down_block_res_samples=brushnet_down_block_res_samples,
|
| 948 |
+
mid_block_res_sample=brushnet_mid_block_res_sample,
|
| 949 |
+
up_block_res_samples=brushnet_up_block_res_samples,
|
| 950 |
+
)
|
| 951 |
+
|
| 952 |
+
|
| 953 |
+
def zero_module(module):
|
| 954 |
+
for p in module.parameters():
|
| 955 |
+
nn.init.zeros_(p)
|
| 956 |
+
return module
|
ComfyUI-BrushNet/brushnet/brushnet_xl.json
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "BrushNetModel",
|
| 3 |
+
"_diffusers_version": "0.27.0.dev0",
|
| 4 |
+
"_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
|
| 5 |
+
"act_fn": "silu",
|
| 6 |
+
"addition_embed_type": "text_time",
|
| 7 |
+
"addition_embed_type_num_heads": 64,
|
| 8 |
+
"addition_time_embed_dim": 256,
|
| 9 |
+
"attention_head_dim": [
|
| 10 |
+
5,
|
| 11 |
+
10,
|
| 12 |
+
20
|
| 13 |
+
],
|
| 14 |
+
"block_out_channels": [
|
| 15 |
+
320,
|
| 16 |
+
640,
|
| 17 |
+
1280
|
| 18 |
+
],
|
| 19 |
+
"brushnet_conditioning_channel_order": "rgb",
|
| 20 |
+
"class_embed_type": null,
|
| 21 |
+
"conditioning_channels": 5,
|
| 22 |
+
"conditioning_embedding_out_channels": [
|
| 23 |
+
16,
|
| 24 |
+
32,
|
| 25 |
+
96,
|
| 26 |
+
256
|
| 27 |
+
],
|
| 28 |
+
"cross_attention_dim": 2048,
|
| 29 |
+
"down_block_types": [
|
| 30 |
+
"DownBlock2D",
|
| 31 |
+
"DownBlock2D",
|
| 32 |
+
"DownBlock2D"
|
| 33 |
+
],
|
| 34 |
+
"downsample_padding": 1,
|
| 35 |
+
"encoder_hid_dim": null,
|
| 36 |
+
"encoder_hid_dim_type": null,
|
| 37 |
+
"flip_sin_to_cos": true,
|
| 38 |
+
"freq_shift": 0,
|
| 39 |
+
"global_pool_conditions": false,
|
| 40 |
+
"in_channels": 4,
|
| 41 |
+
"layers_per_block": 2,
|
| 42 |
+
"mid_block_scale_factor": 1,
|
| 43 |
+
"mid_block_type": "MidBlock2D",
|
| 44 |
+
"norm_eps": 1e-05,
|
| 45 |
+
"norm_num_groups": 32,
|
| 46 |
+
"num_attention_heads": null,
|
| 47 |
+
"num_class_embeds": null,
|
| 48 |
+
"only_cross_attention": false,
|
| 49 |
+
"projection_class_embeddings_input_dim": 2816,
|
| 50 |
+
"resnet_time_scale_shift": "default",
|
| 51 |
+
"transformer_layers_per_block": [
|
| 52 |
+
1,
|
| 53 |
+
2,
|
| 54 |
+
10
|
| 55 |
+
],
|
| 56 |
+
"up_block_types": [
|
| 57 |
+
"UpBlock2D",
|
| 58 |
+
"UpBlock2D",
|
| 59 |
+
"UpBlock2D"
|
| 60 |
+
],
|
| 61 |
+
"upcast_attention": null,
|
| 62 |
+
"use_linear_projection": true
|
| 63 |
+
}
|
ComfyUI-BrushNet/brushnet/powerpaint.json
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "BrushNetModel",
|
| 3 |
+
"_diffusers_version": "0.27.2",
|
| 4 |
+
"act_fn": "silu",
|
| 5 |
+
"addition_embed_type": null,
|
| 6 |
+
"addition_embed_type_num_heads": 64,
|
| 7 |
+
"addition_time_embed_dim": null,
|
| 8 |
+
"attention_head_dim": 8,
|
| 9 |
+
"block_out_channels": [
|
| 10 |
+
320,
|
| 11 |
+
640,
|
| 12 |
+
1280,
|
| 13 |
+
1280
|
| 14 |
+
],
|
| 15 |
+
"brushnet_conditioning_channel_order": "rgb",
|
| 16 |
+
"class_embed_type": null,
|
| 17 |
+
"conditioning_channels": 5,
|
| 18 |
+
"conditioning_embedding_out_channels": [
|
| 19 |
+
16,
|
| 20 |
+
32,
|
| 21 |
+
96,
|
| 22 |
+
256
|
| 23 |
+
],
|
| 24 |
+
"cross_attention_dim": 768,
|
| 25 |
+
"down_block_types": [
|
| 26 |
+
"CrossAttnDownBlock2D",
|
| 27 |
+
"CrossAttnDownBlock2D",
|
| 28 |
+
"CrossAttnDownBlock2D",
|
| 29 |
+
"DownBlock2D"
|
| 30 |
+
],
|
| 31 |
+
"downsample_padding": 1,
|
| 32 |
+
"encoder_hid_dim": null,
|
| 33 |
+
"encoder_hid_dim_type": null,
|
| 34 |
+
"flip_sin_to_cos": true,
|
| 35 |
+
"freq_shift": 0,
|
| 36 |
+
"global_pool_conditions": false,
|
| 37 |
+
"in_channels": 4,
|
| 38 |
+
"layers_per_block": 2,
|
| 39 |
+
"mid_block_scale_factor": 1,
|
| 40 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
| 41 |
+
"norm_eps": 1e-05,
|
| 42 |
+
"norm_num_groups": 32,
|
| 43 |
+
"num_attention_heads": null,
|
| 44 |
+
"num_class_embeds": null,
|
| 45 |
+
"only_cross_attention": false,
|
| 46 |
+
"projection_class_embeddings_input_dim": null,
|
| 47 |
+
"resnet_time_scale_shift": "default",
|
| 48 |
+
"transformer_layers_per_block": 1,
|
| 49 |
+
"up_block_types": [
|
| 50 |
+
"UpBlock2D",
|
| 51 |
+
"CrossAttnUpBlock2D",
|
| 52 |
+
"CrossAttnUpBlock2D",
|
| 53 |
+
"CrossAttnUpBlock2D"
|
| 54 |
+
],
|
| 55 |
+
"upcast_attention": false,
|
| 56 |
+
"use_linear_projection": false
|
| 57 |
+
}
|
ComfyUI-BrushNet/brushnet/powerpaint_utils.py
ADDED
|
@@ -0,0 +1,496 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from transformers import CLIPTokenizer
|
| 7 |
+
from typing import Any, List, Optional, Union
|
| 8 |
+
|
| 9 |
+
class TokenizerWrapper:
|
| 10 |
+
"""Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
|
| 11 |
+
currently. This wrapper is modified from https://github.com/huggingface/dif
|
| 12 |
+
fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
|
| 13 |
+
py#L358 # noqa.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
from_pretrained (Union[str, os.PathLike], optional): The *model id*
|
| 17 |
+
of a pretrained model or a path to a *directory* containing
|
| 18 |
+
model weights and config. Defaults to None.
|
| 19 |
+
from_config (Union[str, os.PathLike], optional): The *model id*
|
| 20 |
+
of a pretrained model or a path to a *directory* containing
|
| 21 |
+
model weights and config. Defaults to None.
|
| 22 |
+
|
| 23 |
+
*args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
|
| 24 |
+
will be passed to `from_pretrained` function. Otherwise, *args
|
| 25 |
+
and **kwargs will be used to initialize the model by
|
| 26 |
+
`self._module_cls(*args, **kwargs)`.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, tokenizer: CLIPTokenizer):
|
| 30 |
+
self.wrapped = tokenizer
|
| 31 |
+
self.token_map = {}
|
| 32 |
+
|
| 33 |
+
def __getattr__(self, name: str) -> Any:
|
| 34 |
+
if name in self.__dict__:
|
| 35 |
+
return getattr(self, name)
|
| 36 |
+
#if name == "wrapped":
|
| 37 |
+
# return getattr(self, 'wrapped')#super().__getattr__("wrapped")
|
| 38 |
+
|
| 39 |
+
try:
|
| 40 |
+
return getattr(self.wrapped, name)
|
| 41 |
+
except AttributeError:
|
| 42 |
+
raise AttributeError(
|
| 43 |
+
"'name' cannot be found in both "
|
| 44 |
+
f"'{self.__class__.__name__}' and "
|
| 45 |
+
f"'{self.__class__.__name__}.tokenizer'."
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
|
| 49 |
+
"""Attempt to add tokens to the tokenizer.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
tokens (Union[str, List[str]]): The tokens to be added.
|
| 53 |
+
"""
|
| 54 |
+
num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
|
| 55 |
+
assert num_added_tokens != 0, (
|
| 56 |
+
f"The tokenizer already contains the token {tokens}. Please pass "
|
| 57 |
+
"a different `placeholder_token` that is not already in the "
|
| 58 |
+
"tokenizer."
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def get_token_info(self, token: str) -> dict:
|
| 62 |
+
"""Get the information of a token, including its start and end index in
|
| 63 |
+
the current tokenizer.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
token (str): The token to be queried.
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
dict: The information of the token, including its start and end
|
| 70 |
+
index in current tokenizer.
|
| 71 |
+
"""
|
| 72 |
+
token_ids = self.__call__(token).input_ids
|
| 73 |
+
start, end = token_ids[1], token_ids[-2] + 1
|
| 74 |
+
return {"name": token, "start": start, "end": end}
|
| 75 |
+
|
| 76 |
+
def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
|
| 77 |
+
"""Add placeholder tokens to the tokenizer.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
placeholder_token (str): The placeholder token to be added.
|
| 81 |
+
num_vec_per_token (int, optional): The number of vectors of
|
| 82 |
+
the added placeholder token.
|
| 83 |
+
*args, **kwargs: The arguments for `self.wrapped.add_tokens`.
|
| 84 |
+
"""
|
| 85 |
+
output = []
|
| 86 |
+
if num_vec_per_token == 1:
|
| 87 |
+
self.try_adding_tokens(placeholder_token, *args, **kwargs)
|
| 88 |
+
output.append(placeholder_token)
|
| 89 |
+
else:
|
| 90 |
+
output = []
|
| 91 |
+
for i in range(num_vec_per_token):
|
| 92 |
+
ith_token = placeholder_token + f"_{i}"
|
| 93 |
+
self.try_adding_tokens(ith_token, *args, **kwargs)
|
| 94 |
+
output.append(ith_token)
|
| 95 |
+
|
| 96 |
+
for token in self.token_map:
|
| 97 |
+
if token in placeholder_token:
|
| 98 |
+
raise ValueError(
|
| 99 |
+
f"The tokenizer already has placeholder token {token} "
|
| 100 |
+
f"that can get confused with {placeholder_token} "
|
| 101 |
+
"keep placeholder tokens independent"
|
| 102 |
+
)
|
| 103 |
+
self.token_map[placeholder_token] = output
|
| 104 |
+
|
| 105 |
+
def replace_placeholder_tokens_in_text(
|
| 106 |
+
self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
|
| 107 |
+
) -> Union[str, List[str]]:
|
| 108 |
+
"""Replace the keywords in text with placeholder tokens. This function
|
| 109 |
+
will be called in `self.__call__` and `self.encode`.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
text (Union[str, List[str]]): The text to be processed.
|
| 113 |
+
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
| 114 |
+
Defaults to False.
|
| 115 |
+
prop_tokens_to_load (float, optional): The proportion of tokens to
|
| 116 |
+
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
Union[str, List[str]]: The processed text.
|
| 120 |
+
"""
|
| 121 |
+
if isinstance(text, list):
|
| 122 |
+
output = []
|
| 123 |
+
for i in range(len(text)):
|
| 124 |
+
output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
|
| 125 |
+
return output
|
| 126 |
+
|
| 127 |
+
for placeholder_token in self.token_map:
|
| 128 |
+
if placeholder_token in text:
|
| 129 |
+
tokens = self.token_map[placeholder_token]
|
| 130 |
+
tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
|
| 131 |
+
if vector_shuffle:
|
| 132 |
+
tokens = copy.copy(tokens)
|
| 133 |
+
random.shuffle(tokens)
|
| 134 |
+
text = text.replace(placeholder_token, " ".join(tokens))
|
| 135 |
+
return text
|
| 136 |
+
|
| 137 |
+
def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
|
| 138 |
+
"""Replace the placeholder tokens in text with the original keywords.
|
| 139 |
+
This function will be called in `self.decode`.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
text (Union[str, List[str]]): The text to be processed.
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
Union[str, List[str]]: The processed text.
|
| 146 |
+
"""
|
| 147 |
+
if isinstance(text, list):
|
| 148 |
+
output = []
|
| 149 |
+
for i in range(len(text)):
|
| 150 |
+
output.append(self.replace_text_with_placeholder_tokens(text[i]))
|
| 151 |
+
return output
|
| 152 |
+
|
| 153 |
+
for placeholder_token, tokens in self.token_map.items():
|
| 154 |
+
merged_tokens = " ".join(tokens)
|
| 155 |
+
if merged_tokens in text:
|
| 156 |
+
text = text.replace(merged_tokens, placeholder_token)
|
| 157 |
+
return text
|
| 158 |
+
|
| 159 |
+
def __call__(
|
| 160 |
+
self,
|
| 161 |
+
text: Union[str, List[str]],
|
| 162 |
+
*args,
|
| 163 |
+
vector_shuffle: bool = False,
|
| 164 |
+
prop_tokens_to_load: float = 1.0,
|
| 165 |
+
**kwargs,
|
| 166 |
+
):
|
| 167 |
+
"""The call function of the wrapper.
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
text (Union[str, List[str]]): The text to be tokenized.
|
| 171 |
+
vector_shuffle (bool, optional): Whether to shuffle the vectors.
|
| 172 |
+
Defaults to False.
|
| 173 |
+
prop_tokens_to_load (float, optional): The proportion of tokens to
|
| 174 |
+
be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
|
| 175 |
+
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
| 176 |
+
"""
|
| 177 |
+
replaced_text = self.replace_placeholder_tokens_in_text(
|
| 178 |
+
text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
return self.wrapped.__call__(replaced_text, *args, **kwargs)
|
| 182 |
+
|
| 183 |
+
def encode(self, text: Union[str, List[str]], *args, **kwargs):
|
| 184 |
+
"""Encode the passed text to token index.
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
text (Union[str, List[str]]): The text to be encode.
|
| 188 |
+
*args, **kwargs: The arguments for `self.wrapped.__call__`.
|
| 189 |
+
"""
|
| 190 |
+
replaced_text = self.replace_placeholder_tokens_in_text(text)
|
| 191 |
+
return self.wrapped(replaced_text, *args, **kwargs)
|
| 192 |
+
|
| 193 |
+
def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
|
| 194 |
+
"""Decode the token index to text.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
token_ids: The token index to be decoded.
|
| 198 |
+
return_raw: Whether keep the placeholder token in the text.
|
| 199 |
+
Defaults to False.
|
| 200 |
+
*args, **kwargs: The arguments for `self.wrapped.decode`.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Union[str, List[str]]: The decoded text.
|
| 204 |
+
"""
|
| 205 |
+
text = self.wrapped.decode(token_ids, *args, **kwargs)
|
| 206 |
+
if return_raw:
|
| 207 |
+
return text
|
| 208 |
+
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
| 209 |
+
return replaced_text
|
| 210 |
+
|
| 211 |
+
def __repr__(self):
|
| 212 |
+
"""The representation of the wrapper."""
|
| 213 |
+
s = super().__repr__()
|
| 214 |
+
prefix = f"Wrapped Module Class: {self._module_cls}\n"
|
| 215 |
+
prefix += f"Wrapped Module Name: {self._module_name}\n"
|
| 216 |
+
if self._from_pretrained:
|
| 217 |
+
prefix += f"From Pretrained: {self._from_pretrained}\n"
|
| 218 |
+
s = prefix + s
|
| 219 |
+
return s
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
class EmbeddingLayerWithFixes(nn.Module):
|
| 223 |
+
"""The revised embedding layer to support external embeddings. This design
|
| 224 |
+
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
|
| 225 |
+
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
|
| 226 |
+
jack.py#L224 # noqa.
|
| 227 |
+
|
| 228 |
+
Args:
|
| 229 |
+
wrapped (nn.Emebdding): The embedding layer to be wrapped.
|
| 230 |
+
external_embeddings (Union[dict, List[dict]], optional): The external
|
| 231 |
+
embeddings added to this layer. Defaults to None.
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
|
| 235 |
+
super().__init__()
|
| 236 |
+
self.wrapped = wrapped
|
| 237 |
+
self.num_embeddings = wrapped.weight.shape[0]
|
| 238 |
+
|
| 239 |
+
self.external_embeddings = []
|
| 240 |
+
if external_embeddings:
|
| 241 |
+
self.add_embeddings(external_embeddings)
|
| 242 |
+
|
| 243 |
+
self.trainable_embeddings = nn.ParameterDict()
|
| 244 |
+
|
| 245 |
+
@property
|
| 246 |
+
def weight(self):
|
| 247 |
+
"""Get the weight of wrapped embedding layer."""
|
| 248 |
+
return self.wrapped.weight
|
| 249 |
+
|
| 250 |
+
def check_duplicate_names(self, embeddings: List[dict]):
|
| 251 |
+
"""Check whether duplicate names exist in list of 'external
|
| 252 |
+
embeddings'.
|
| 253 |
+
|
| 254 |
+
Args:
|
| 255 |
+
embeddings (List[dict]): A list of embedding to be check.
|
| 256 |
+
"""
|
| 257 |
+
names = [emb["name"] for emb in embeddings]
|
| 258 |
+
assert len(names) == len(set(names)), (
|
| 259 |
+
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
|
| 260 |
+
)
|
| 261 |
+
|
| 262 |
+
def check_ids_overlap(self, embeddings):
|
| 263 |
+
"""Check whether overlap exist in token ids of 'external_embeddings'.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
embeddings (List[dict]): A list of embedding to be check.
|
| 267 |
+
"""
|
| 268 |
+
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
|
| 269 |
+
ids_range.sort() # sort by 'start'
|
| 270 |
+
# check if 'end' has overlapping
|
| 271 |
+
for idx in range(len(ids_range) - 1):
|
| 272 |
+
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
|
| 273 |
+
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
|
| 274 |
+
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
|
| 278 |
+
"""Add external embeddings to this layer.
|
| 279 |
+
|
| 280 |
+
Use case:
|
| 281 |
+
|
| 282 |
+
>>> 1. Add token to tokenizer and get the token id.
|
| 283 |
+
>>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
|
| 284 |
+
>>> # 'how much' in kiswahili
|
| 285 |
+
>>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
|
| 286 |
+
>>>
|
| 287 |
+
>>> 2. Add external embeddings to the model.
|
| 288 |
+
>>> new_embedding = {
|
| 289 |
+
>>> 'name': 'ngapi', # 'how much' in kiswahili
|
| 290 |
+
>>> 'embedding': torch.ones(1, 15) * 4,
|
| 291 |
+
>>> 'start': tokenizer.get_token_info('kwaheri')['start'],
|
| 292 |
+
>>> 'end': tokenizer.get_token_info('kwaheri')['end'],
|
| 293 |
+
>>> 'trainable': False # if True, will registry as a parameter
|
| 294 |
+
>>> }
|
| 295 |
+
>>> embedding_layer = nn.Embedding(10, 15)
|
| 296 |
+
>>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
|
| 297 |
+
>>> embedding_layer_wrapper.add_embeddings(new_embedding)
|
| 298 |
+
>>>
|
| 299 |
+
>>> 3. Forward tokenizer and embedding layer!
|
| 300 |
+
>>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
|
| 301 |
+
>>> input_ids = tokenizer(
|
| 302 |
+
>>> input_text, padding='max_length', truncation=True,
|
| 303 |
+
>>> return_tensors='pt')['input_ids']
|
| 304 |
+
>>> out_feat = embedding_layer_wrapper(input_ids)
|
| 305 |
+
>>>
|
| 306 |
+
>>> 4. Let's validate the result!
|
| 307 |
+
>>> assert (out_feat[0, 3: 7] == 2.3).all()
|
| 308 |
+
>>> assert (out_feat[2, 5: 9] == 2.3).all()
|
| 309 |
+
|
| 310 |
+
Args:
|
| 311 |
+
embeddings (Union[dict, list[dict]]): The external embeddings to
|
| 312 |
+
be added. Each dict must contain the following 4 fields: 'name'
|
| 313 |
+
(the name of this embedding), 'embedding' (the embedding
|
| 314 |
+
tensor), 'start' (the start token id of this embedding), 'end'
|
| 315 |
+
(the end token id of this embedding). For example:
|
| 316 |
+
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
|
| 317 |
+
"""
|
| 318 |
+
if isinstance(embeddings, dict):
|
| 319 |
+
embeddings = [embeddings]
|
| 320 |
+
|
| 321 |
+
self.external_embeddings += embeddings
|
| 322 |
+
self.check_duplicate_names(self.external_embeddings)
|
| 323 |
+
self.check_ids_overlap(self.external_embeddings)
|
| 324 |
+
|
| 325 |
+
# set for trainable
|
| 326 |
+
added_trainable_emb_info = []
|
| 327 |
+
for embedding in embeddings:
|
| 328 |
+
trainable = embedding.get("trainable", False)
|
| 329 |
+
if trainable:
|
| 330 |
+
name = embedding["name"]
|
| 331 |
+
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
|
| 332 |
+
self.trainable_embeddings[name] = embedding["embedding"]
|
| 333 |
+
added_trainable_emb_info.append(name)
|
| 334 |
+
|
| 335 |
+
added_emb_info = [emb["name"] for emb in embeddings]
|
| 336 |
+
added_emb_info = ", ".join(added_emb_info)
|
| 337 |
+
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
|
| 338 |
+
|
| 339 |
+
if added_trainable_emb_info:
|
| 340 |
+
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
|
| 341 |
+
print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
|
| 342 |
+
|
| 343 |
+
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
| 344 |
+
"""Replace external input ids to 0.
|
| 345 |
+
|
| 346 |
+
Args:
|
| 347 |
+
input_ids (torch.Tensor): The input ids to be replaced.
|
| 348 |
+
|
| 349 |
+
Returns:
|
| 350 |
+
torch.Tensor: The replaced input ids.
|
| 351 |
+
"""
|
| 352 |
+
input_ids_fwd = input_ids.clone()
|
| 353 |
+
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
|
| 354 |
+
return input_ids_fwd
|
| 355 |
+
|
| 356 |
+
def replace_embeddings(
|
| 357 |
+
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
|
| 358 |
+
) -> torch.Tensor:
|
| 359 |
+
"""Replace external embedding to the embedding layer. Noted that, in
|
| 360 |
+
this function we use `torch.cat` to avoid inplace modification.
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
input_ids (torch.Tensor): The original token ids. Shape like
|
| 364 |
+
[LENGTH, ].
|
| 365 |
+
embedding (torch.Tensor): The embedding of token ids after
|
| 366 |
+
`replace_input_ids` function.
|
| 367 |
+
external_embedding (dict): The external embedding to be replaced.
|
| 368 |
+
|
| 369 |
+
Returns:
|
| 370 |
+
torch.Tensor: The replaced embedding.
|
| 371 |
+
"""
|
| 372 |
+
new_embedding = []
|
| 373 |
+
|
| 374 |
+
name = external_embedding["name"]
|
| 375 |
+
start = external_embedding["start"]
|
| 376 |
+
end = external_embedding["end"]
|
| 377 |
+
target_ids_to_replace = [i for i in range(start, end)]
|
| 378 |
+
ext_emb = external_embedding["embedding"]
|
| 379 |
+
|
| 380 |
+
# do not need to replace
|
| 381 |
+
if not (input_ids == start).any():
|
| 382 |
+
return embedding
|
| 383 |
+
|
| 384 |
+
# start replace
|
| 385 |
+
s_idx, e_idx = 0, 0
|
| 386 |
+
while e_idx < len(input_ids):
|
| 387 |
+
if input_ids[e_idx] == start:
|
| 388 |
+
if e_idx != 0:
|
| 389 |
+
# add embedding do not need to replace
|
| 390 |
+
new_embedding.append(embedding[s_idx:e_idx])
|
| 391 |
+
|
| 392 |
+
# check if the next embedding need to replace is valid
|
| 393 |
+
actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
|
| 394 |
+
assert actually_ids_to_replace == target_ids_to_replace, (
|
| 395 |
+
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
|
| 396 |
+
f"Expect '{target_ids_to_replace}' for embedding "
|
| 397 |
+
f"'{name}' but found '{actually_ids_to_replace}'."
|
| 398 |
+
)
|
| 399 |
+
|
| 400 |
+
new_embedding.append(ext_emb)
|
| 401 |
+
|
| 402 |
+
s_idx = e_idx + end - start
|
| 403 |
+
e_idx = s_idx + 1
|
| 404 |
+
else:
|
| 405 |
+
e_idx += 1
|
| 406 |
+
|
| 407 |
+
if e_idx == len(input_ids):
|
| 408 |
+
new_embedding.append(embedding[s_idx:e_idx])
|
| 409 |
+
|
| 410 |
+
return torch.cat(new_embedding, dim=0)
|
| 411 |
+
|
| 412 |
+
def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
|
| 413 |
+
"""The forward function.
|
| 414 |
+
|
| 415 |
+
Args:
|
| 416 |
+
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
|
| 417 |
+
[LENGTH, ].
|
| 418 |
+
external_embeddings (Optional[List[dict]]): The external
|
| 419 |
+
embeddings. If not passed, only `self.external_embeddings`
|
| 420 |
+
will be used. Defaults to None.
|
| 421 |
+
|
| 422 |
+
input_ids: shape like [bz, LENGTH] or [LENGTH].
|
| 423 |
+
"""
|
| 424 |
+
assert input_ids.ndim in [1, 2]
|
| 425 |
+
if input_ids.ndim == 1:
|
| 426 |
+
input_ids = input_ids.unsqueeze(0)
|
| 427 |
+
|
| 428 |
+
if external_embeddings is None and not self.external_embeddings:
|
| 429 |
+
return self.wrapped(input_ids)
|
| 430 |
+
|
| 431 |
+
input_ids_fwd = self.replace_input_ids(input_ids)
|
| 432 |
+
inputs_embeds = self.wrapped(input_ids_fwd)
|
| 433 |
+
|
| 434 |
+
vecs = []
|
| 435 |
+
|
| 436 |
+
if external_embeddings is None:
|
| 437 |
+
external_embeddings = []
|
| 438 |
+
elif isinstance(external_embeddings, dict):
|
| 439 |
+
external_embeddings = [external_embeddings]
|
| 440 |
+
embeddings = self.external_embeddings + external_embeddings
|
| 441 |
+
|
| 442 |
+
for input_id, embedding in zip(input_ids, inputs_embeds):
|
| 443 |
+
new_embedding = embedding
|
| 444 |
+
for external_embedding in embeddings:
|
| 445 |
+
new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
|
| 446 |
+
vecs.append(new_embedding)
|
| 447 |
+
|
| 448 |
+
return torch.stack(vecs)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def add_tokens(
|
| 453 |
+
tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
|
| 454 |
+
):
|
| 455 |
+
"""Add token for training.
|
| 456 |
+
|
| 457 |
+
# TODO: support add tokens as dict, then we can load pretrained tokens.
|
| 458 |
+
"""
|
| 459 |
+
if initialize_tokens is not None:
|
| 460 |
+
assert len(initialize_tokens) == len(
|
| 461 |
+
placeholder_tokens
|
| 462 |
+
), "placeholder_token should be the same length as initialize_token"
|
| 463 |
+
for ii in range(len(placeholder_tokens)):
|
| 464 |
+
tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
|
| 465 |
+
|
| 466 |
+
# text_encoder.set_embedding_layer()
|
| 467 |
+
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
| 468 |
+
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
|
| 469 |
+
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
| 470 |
+
|
| 471 |
+
assert embedding_layer is not None, (
|
| 472 |
+
"Do not support get embedding layer for current text encoder. " "Please check your configuration."
|
| 473 |
+
)
|
| 474 |
+
initialize_embedding = []
|
| 475 |
+
if initialize_tokens is not None:
|
| 476 |
+
for ii in range(len(placeholder_tokens)):
|
| 477 |
+
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
|
| 478 |
+
temp_embedding = embedding_layer.weight[init_id]
|
| 479 |
+
initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
|
| 480 |
+
else:
|
| 481 |
+
for ii in range(len(placeholder_tokens)):
|
| 482 |
+
init_id = tokenizer("a").input_ids[1]
|
| 483 |
+
temp_embedding = embedding_layer.weight[init_id]
|
| 484 |
+
len_emb = temp_embedding.shape[0]
|
| 485 |
+
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
|
| 486 |
+
initialize_embedding.append(init_weight)
|
| 487 |
+
|
| 488 |
+
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
|
| 489 |
+
|
| 490 |
+
token_info_all = []
|
| 491 |
+
for ii in range(len(placeholder_tokens)):
|
| 492 |
+
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
|
| 493 |
+
token_info["embedding"] = initialize_embedding[ii]
|
| 494 |
+
token_info["trainable"] = True
|
| 495 |
+
token_info_all.append(token_info)
|
| 496 |
+
embedding_layer.add_embeddings(token_info_all)
|
ComfyUI-BrushNet/brushnet/unet_2d_blocks.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ComfyUI-BrushNet/brushnet/unet_2d_condition.py
ADDED
|
@@ -0,0 +1,1355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from dataclasses import dataclass
|
| 15 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.utils.checkpoint
|
| 20 |
+
|
| 21 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 22 |
+
from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
| 23 |
+
from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
| 24 |
+
from diffusers.models.activations import get_activation
|
| 25 |
+
from diffusers.models.attention_processor import (
|
| 26 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 27 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 28 |
+
Attention,
|
| 29 |
+
AttentionProcessor,
|
| 30 |
+
AttnAddedKVProcessor,
|
| 31 |
+
AttnProcessor,
|
| 32 |
+
)
|
| 33 |
+
from diffusers.models.embeddings import (
|
| 34 |
+
GaussianFourierProjection,
|
| 35 |
+
GLIGENTextBoundingboxProjection,
|
| 36 |
+
ImageHintTimeEmbedding,
|
| 37 |
+
ImageProjection,
|
| 38 |
+
ImageTimeEmbedding,
|
| 39 |
+
TextImageProjection,
|
| 40 |
+
TextImageTimeEmbedding,
|
| 41 |
+
TextTimeEmbedding,
|
| 42 |
+
TimestepEmbedding,
|
| 43 |
+
Timesteps,
|
| 44 |
+
)
|
| 45 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 46 |
+
from .unet_2d_blocks import (
|
| 47 |
+
get_down_block,
|
| 48 |
+
get_mid_block,
|
| 49 |
+
get_up_block,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass
|
| 57 |
+
class UNet2DConditionOutput(BaseOutput):
|
| 58 |
+
"""
|
| 59 |
+
The output of [`UNet2DConditionModel`].
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
|
| 63 |
+
The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
sample: torch.FloatTensor = None
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
|
| 70 |
+
r"""
|
| 71 |
+
A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
|
| 72 |
+
shaped output.
|
| 73 |
+
|
| 74 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 75 |
+
for all models (such as downloading or saving).
|
| 76 |
+
|
| 77 |
+
Parameters:
|
| 78 |
+
sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
|
| 79 |
+
Height and width of input/output sample.
|
| 80 |
+
in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
|
| 81 |
+
out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
|
| 82 |
+
center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
|
| 83 |
+
flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
|
| 84 |
+
Whether to flip the sin to cos in the time embedding.
|
| 85 |
+
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
| 86 |
+
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
| 87 |
+
The tuple of downsample blocks to use.
|
| 88 |
+
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
| 89 |
+
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
| 90 |
+
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
| 91 |
+
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
| 92 |
+
The tuple of upsample blocks to use.
|
| 93 |
+
only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
|
| 94 |
+
Whether to include self-attention in the basic transformer blocks, see
|
| 95 |
+
[`~models.attention.BasicTransformerBlock`].
|
| 96 |
+
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
| 97 |
+
The tuple of output channels for each block.
|
| 98 |
+
layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
|
| 99 |
+
downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
|
| 100 |
+
mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
|
| 101 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 102 |
+
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
|
| 103 |
+
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
|
| 104 |
+
If `None`, normalization and activation layers is skipped in post-processing.
|
| 105 |
+
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
| 106 |
+
cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
|
| 107 |
+
The dimension of the cross attention features.
|
| 108 |
+
transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
|
| 109 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
| 110 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 111 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 112 |
+
reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
|
| 113 |
+
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
|
| 114 |
+
blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
|
| 115 |
+
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
| 116 |
+
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
| 117 |
+
encoder_hid_dim (`int`, *optional*, defaults to None):
|
| 118 |
+
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
| 119 |
+
dimension to `cross_attention_dim`.
|
| 120 |
+
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
| 121 |
+
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
| 122 |
+
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
| 123 |
+
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
| 124 |
+
num_attention_heads (`int`, *optional*):
|
| 125 |
+
The number of attention heads. If not defined, defaults to `attention_head_dim`
|
| 126 |
+
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
| 127 |
+
for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
|
| 128 |
+
class_embed_type (`str`, *optional*, defaults to `None`):
|
| 129 |
+
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
|
| 130 |
+
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
| 131 |
+
addition_embed_type (`str`, *optional*, defaults to `None`):
|
| 132 |
+
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
| 133 |
+
"text". "text" will use the `TextTimeEmbedding` layer.
|
| 134 |
+
addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
|
| 135 |
+
Dimension for the timestep embeddings.
|
| 136 |
+
num_class_embeds (`int`, *optional*, defaults to `None`):
|
| 137 |
+
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
| 138 |
+
class conditioning with `class_embed_type` equal to `None`.
|
| 139 |
+
time_embedding_type (`str`, *optional*, defaults to `positional`):
|
| 140 |
+
The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
|
| 141 |
+
time_embedding_dim (`int`, *optional*, defaults to `None`):
|
| 142 |
+
An optional override for the dimension of the projected time embedding.
|
| 143 |
+
time_embedding_act_fn (`str`, *optional*, defaults to `None`):
|
| 144 |
+
Optional activation function to use only once on the time embeddings before they are passed to the rest of
|
| 145 |
+
the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
|
| 146 |
+
timestep_post_act (`str`, *optional*, defaults to `None`):
|
| 147 |
+
The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
|
| 148 |
+
time_cond_proj_dim (`int`, *optional*, defaults to `None`):
|
| 149 |
+
The dimension of `cond_proj` layer in the timestep embedding.
|
| 150 |
+
conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
|
| 151 |
+
conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
|
| 152 |
+
projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
|
| 153 |
+
`class_embed_type="projection"`. Required when `class_embed_type="projection"`.
|
| 154 |
+
class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
|
| 155 |
+
embeddings with the class embeddings.
|
| 156 |
+
mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
|
| 157 |
+
Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
|
| 158 |
+
`only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
|
| 159 |
+
`only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
|
| 160 |
+
otherwise.
|
| 161 |
+
"""
|
| 162 |
+
|
| 163 |
+
_supports_gradient_checkpointing = True
|
| 164 |
+
|
| 165 |
+
@register_to_config
|
| 166 |
+
def __init__(
|
| 167 |
+
self,
|
| 168 |
+
sample_size: Optional[int] = None,
|
| 169 |
+
in_channels: int = 4,
|
| 170 |
+
out_channels: int = 4,
|
| 171 |
+
center_input_sample: bool = False,
|
| 172 |
+
flip_sin_to_cos: bool = True,
|
| 173 |
+
freq_shift: int = 0,
|
| 174 |
+
down_block_types: Tuple[str] = (
|
| 175 |
+
"CrossAttnDownBlock2D",
|
| 176 |
+
"CrossAttnDownBlock2D",
|
| 177 |
+
"CrossAttnDownBlock2D",
|
| 178 |
+
"DownBlock2D",
|
| 179 |
+
),
|
| 180 |
+
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
| 181 |
+
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
| 182 |
+
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
| 183 |
+
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
| 184 |
+
layers_per_block: Union[int, Tuple[int]] = 2,
|
| 185 |
+
downsample_padding: int = 1,
|
| 186 |
+
mid_block_scale_factor: float = 1,
|
| 187 |
+
dropout: float = 0.0,
|
| 188 |
+
act_fn: str = "silu",
|
| 189 |
+
norm_num_groups: Optional[int] = 32,
|
| 190 |
+
norm_eps: float = 1e-5,
|
| 191 |
+
cross_attention_dim: Union[int, Tuple[int]] = 1280,
|
| 192 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
|
| 193 |
+
reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
|
| 194 |
+
encoder_hid_dim: Optional[int] = None,
|
| 195 |
+
encoder_hid_dim_type: Optional[str] = None,
|
| 196 |
+
attention_head_dim: Union[int, Tuple[int]] = 8,
|
| 197 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
|
| 198 |
+
dual_cross_attention: bool = False,
|
| 199 |
+
use_linear_projection: bool = False,
|
| 200 |
+
class_embed_type: Optional[str] = None,
|
| 201 |
+
addition_embed_type: Optional[str] = None,
|
| 202 |
+
addition_time_embed_dim: Optional[int] = None,
|
| 203 |
+
num_class_embeds: Optional[int] = None,
|
| 204 |
+
upcast_attention: bool = False,
|
| 205 |
+
resnet_time_scale_shift: str = "default",
|
| 206 |
+
resnet_skip_time_act: bool = False,
|
| 207 |
+
resnet_out_scale_factor: float = 1.0,
|
| 208 |
+
time_embedding_type: str = "positional",
|
| 209 |
+
time_embedding_dim: Optional[int] = None,
|
| 210 |
+
time_embedding_act_fn: Optional[str] = None,
|
| 211 |
+
timestep_post_act: Optional[str] = None,
|
| 212 |
+
time_cond_proj_dim: Optional[int] = None,
|
| 213 |
+
conv_in_kernel: int = 3,
|
| 214 |
+
conv_out_kernel: int = 3,
|
| 215 |
+
projection_class_embeddings_input_dim: Optional[int] = None,
|
| 216 |
+
attention_type: str = "default",
|
| 217 |
+
class_embeddings_concat: bool = False,
|
| 218 |
+
mid_block_only_cross_attention: Optional[bool] = None,
|
| 219 |
+
cross_attention_norm: Optional[str] = None,
|
| 220 |
+
addition_embed_type_num_heads: int = 64,
|
| 221 |
+
):
|
| 222 |
+
super().__init__()
|
| 223 |
+
|
| 224 |
+
self.sample_size = sample_size
|
| 225 |
+
|
| 226 |
+
if num_attention_heads is not None:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
# If `num_attention_heads` is not defined (which is the case for most models)
|
| 232 |
+
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
| 233 |
+
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
| 234 |
+
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
| 235 |
+
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
| 236 |
+
# which is why we correct for the naming here.
|
| 237 |
+
num_attention_heads = num_attention_heads or attention_head_dim
|
| 238 |
+
|
| 239 |
+
# Check inputs
|
| 240 |
+
self._check_config(
|
| 241 |
+
down_block_types=down_block_types,
|
| 242 |
+
up_block_types=up_block_types,
|
| 243 |
+
only_cross_attention=only_cross_attention,
|
| 244 |
+
block_out_channels=block_out_channels,
|
| 245 |
+
layers_per_block=layers_per_block,
|
| 246 |
+
cross_attention_dim=cross_attention_dim,
|
| 247 |
+
transformer_layers_per_block=transformer_layers_per_block,
|
| 248 |
+
reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
|
| 249 |
+
attention_head_dim=attention_head_dim,
|
| 250 |
+
num_attention_heads=num_attention_heads,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# input
|
| 254 |
+
conv_in_padding = (conv_in_kernel - 1) // 2
|
| 255 |
+
self.conv_in = nn.Conv2d(
|
| 256 |
+
in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
# time
|
| 260 |
+
time_embed_dim, timestep_input_dim = self._set_time_proj(
|
| 261 |
+
time_embedding_type,
|
| 262 |
+
block_out_channels=block_out_channels,
|
| 263 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
| 264 |
+
freq_shift=freq_shift,
|
| 265 |
+
time_embedding_dim=time_embedding_dim,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
self.time_embedding = TimestepEmbedding(
|
| 269 |
+
timestep_input_dim,
|
| 270 |
+
time_embed_dim,
|
| 271 |
+
act_fn=act_fn,
|
| 272 |
+
post_act_fn=timestep_post_act,
|
| 273 |
+
cond_proj_dim=time_cond_proj_dim,
|
| 274 |
+
)
|
| 275 |
+
|
| 276 |
+
self._set_encoder_hid_proj(
|
| 277 |
+
encoder_hid_dim_type,
|
| 278 |
+
cross_attention_dim=cross_attention_dim,
|
| 279 |
+
encoder_hid_dim=encoder_hid_dim,
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# class embedding
|
| 283 |
+
self._set_class_embedding(
|
| 284 |
+
class_embed_type,
|
| 285 |
+
act_fn=act_fn,
|
| 286 |
+
num_class_embeds=num_class_embeds,
|
| 287 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
| 288 |
+
time_embed_dim=time_embed_dim,
|
| 289 |
+
timestep_input_dim=timestep_input_dim,
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
self._set_add_embedding(
|
| 293 |
+
addition_embed_type,
|
| 294 |
+
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
| 295 |
+
addition_time_embed_dim=addition_time_embed_dim,
|
| 296 |
+
cross_attention_dim=cross_attention_dim,
|
| 297 |
+
encoder_hid_dim=encoder_hid_dim,
|
| 298 |
+
flip_sin_to_cos=flip_sin_to_cos,
|
| 299 |
+
freq_shift=freq_shift,
|
| 300 |
+
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
| 301 |
+
time_embed_dim=time_embed_dim,
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
if time_embedding_act_fn is None:
|
| 305 |
+
self.time_embed_act = None
|
| 306 |
+
else:
|
| 307 |
+
self.time_embed_act = get_activation(time_embedding_act_fn)
|
| 308 |
+
|
| 309 |
+
self.down_blocks = nn.ModuleList([])
|
| 310 |
+
self.up_blocks = nn.ModuleList([])
|
| 311 |
+
|
| 312 |
+
if isinstance(only_cross_attention, bool):
|
| 313 |
+
if mid_block_only_cross_attention is None:
|
| 314 |
+
mid_block_only_cross_attention = only_cross_attention
|
| 315 |
+
|
| 316 |
+
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
| 317 |
+
|
| 318 |
+
if mid_block_only_cross_attention is None:
|
| 319 |
+
mid_block_only_cross_attention = False
|
| 320 |
+
|
| 321 |
+
if isinstance(num_attention_heads, int):
|
| 322 |
+
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
| 323 |
+
|
| 324 |
+
if isinstance(attention_head_dim, int):
|
| 325 |
+
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
| 326 |
+
|
| 327 |
+
if isinstance(cross_attention_dim, int):
|
| 328 |
+
cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
|
| 329 |
+
|
| 330 |
+
if isinstance(layers_per_block, int):
|
| 331 |
+
layers_per_block = [layers_per_block] * len(down_block_types)
|
| 332 |
+
|
| 333 |
+
if isinstance(transformer_layers_per_block, int):
|
| 334 |
+
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
| 335 |
+
|
| 336 |
+
if class_embeddings_concat:
|
| 337 |
+
# The time embeddings are concatenated with the class embeddings. The dimension of the
|
| 338 |
+
# time embeddings passed to the down, middle, and up blocks is twice the dimension of the
|
| 339 |
+
# regular time embeddings
|
| 340 |
+
blocks_time_embed_dim = time_embed_dim * 2
|
| 341 |
+
else:
|
| 342 |
+
blocks_time_embed_dim = time_embed_dim
|
| 343 |
+
|
| 344 |
+
# down
|
| 345 |
+
output_channel = block_out_channels[0]
|
| 346 |
+
for i, down_block_type in enumerate(down_block_types):
|
| 347 |
+
input_channel = output_channel
|
| 348 |
+
output_channel = block_out_channels[i]
|
| 349 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 350 |
+
|
| 351 |
+
down_block = get_down_block(
|
| 352 |
+
down_block_type,
|
| 353 |
+
num_layers=layers_per_block[i],
|
| 354 |
+
transformer_layers_per_block=transformer_layers_per_block[i],
|
| 355 |
+
in_channels=input_channel,
|
| 356 |
+
out_channels=output_channel,
|
| 357 |
+
temb_channels=blocks_time_embed_dim,
|
| 358 |
+
add_downsample=not is_final_block,
|
| 359 |
+
resnet_eps=norm_eps,
|
| 360 |
+
resnet_act_fn=act_fn,
|
| 361 |
+
resnet_groups=norm_num_groups,
|
| 362 |
+
cross_attention_dim=cross_attention_dim[i],
|
| 363 |
+
num_attention_heads=num_attention_heads[i],
|
| 364 |
+
downsample_padding=downsample_padding,
|
| 365 |
+
dual_cross_attention=dual_cross_attention,
|
| 366 |
+
use_linear_projection=use_linear_projection,
|
| 367 |
+
only_cross_attention=only_cross_attention[i],
|
| 368 |
+
upcast_attention=upcast_attention,
|
| 369 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 370 |
+
attention_type=attention_type,
|
| 371 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
| 372 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
| 373 |
+
cross_attention_norm=cross_attention_norm,
|
| 374 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 375 |
+
dropout=dropout,
|
| 376 |
+
)
|
| 377 |
+
self.down_blocks.append(down_block)
|
| 378 |
+
|
| 379 |
+
# mid
|
| 380 |
+
self.mid_block = get_mid_block(
|
| 381 |
+
mid_block_type,
|
| 382 |
+
temb_channels=blocks_time_embed_dim,
|
| 383 |
+
in_channels=block_out_channels[-1],
|
| 384 |
+
resnet_eps=norm_eps,
|
| 385 |
+
resnet_act_fn=act_fn,
|
| 386 |
+
resnet_groups=norm_num_groups,
|
| 387 |
+
output_scale_factor=mid_block_scale_factor,
|
| 388 |
+
transformer_layers_per_block=transformer_layers_per_block[-1],
|
| 389 |
+
num_attention_heads=num_attention_heads[-1],
|
| 390 |
+
cross_attention_dim=cross_attention_dim[-1],
|
| 391 |
+
dual_cross_attention=dual_cross_attention,
|
| 392 |
+
use_linear_projection=use_linear_projection,
|
| 393 |
+
mid_block_only_cross_attention=mid_block_only_cross_attention,
|
| 394 |
+
upcast_attention=upcast_attention,
|
| 395 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 396 |
+
attention_type=attention_type,
|
| 397 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
| 398 |
+
cross_attention_norm=cross_attention_norm,
|
| 399 |
+
attention_head_dim=attention_head_dim[-1],
|
| 400 |
+
dropout=dropout,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# count how many layers upsample the images
|
| 404 |
+
self.num_upsamplers = 0
|
| 405 |
+
|
| 406 |
+
# up
|
| 407 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 408 |
+
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
| 409 |
+
reversed_layers_per_block = list(reversed(layers_per_block))
|
| 410 |
+
reversed_cross_attention_dim = list(reversed(cross_attention_dim))
|
| 411 |
+
reversed_transformer_layers_per_block = (
|
| 412 |
+
list(reversed(transformer_layers_per_block))
|
| 413 |
+
if reverse_transformer_layers_per_block is None
|
| 414 |
+
else reverse_transformer_layers_per_block
|
| 415 |
+
)
|
| 416 |
+
only_cross_attention = list(reversed(only_cross_attention))
|
| 417 |
+
|
| 418 |
+
output_channel = reversed_block_out_channels[0]
|
| 419 |
+
for i, up_block_type in enumerate(up_block_types):
|
| 420 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 421 |
+
|
| 422 |
+
prev_output_channel = output_channel
|
| 423 |
+
output_channel = reversed_block_out_channels[i]
|
| 424 |
+
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
| 425 |
+
|
| 426 |
+
# add upsample block for all BUT final layer
|
| 427 |
+
if not is_final_block:
|
| 428 |
+
add_upsample = True
|
| 429 |
+
self.num_upsamplers += 1
|
| 430 |
+
else:
|
| 431 |
+
add_upsample = False
|
| 432 |
+
|
| 433 |
+
up_block = get_up_block(
|
| 434 |
+
up_block_type,
|
| 435 |
+
num_layers=reversed_layers_per_block[i] + 1,
|
| 436 |
+
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
| 437 |
+
in_channels=input_channel,
|
| 438 |
+
out_channels=output_channel,
|
| 439 |
+
prev_output_channel=prev_output_channel,
|
| 440 |
+
temb_channels=blocks_time_embed_dim,
|
| 441 |
+
add_upsample=add_upsample,
|
| 442 |
+
resnet_eps=norm_eps,
|
| 443 |
+
resnet_act_fn=act_fn,
|
| 444 |
+
resolution_idx=i,
|
| 445 |
+
resnet_groups=norm_num_groups,
|
| 446 |
+
cross_attention_dim=reversed_cross_attention_dim[i],
|
| 447 |
+
num_attention_heads=reversed_num_attention_heads[i],
|
| 448 |
+
dual_cross_attention=dual_cross_attention,
|
| 449 |
+
use_linear_projection=use_linear_projection,
|
| 450 |
+
only_cross_attention=only_cross_attention[i],
|
| 451 |
+
upcast_attention=upcast_attention,
|
| 452 |
+
resnet_time_scale_shift=resnet_time_scale_shift,
|
| 453 |
+
attention_type=attention_type,
|
| 454 |
+
resnet_skip_time_act=resnet_skip_time_act,
|
| 455 |
+
resnet_out_scale_factor=resnet_out_scale_factor,
|
| 456 |
+
cross_attention_norm=cross_attention_norm,
|
| 457 |
+
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
| 458 |
+
dropout=dropout,
|
| 459 |
+
)
|
| 460 |
+
self.up_blocks.append(up_block)
|
| 461 |
+
prev_output_channel = output_channel
|
| 462 |
+
|
| 463 |
+
# out
|
| 464 |
+
if norm_num_groups is not None:
|
| 465 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 466 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
self.conv_act = get_activation(act_fn)
|
| 470 |
+
|
| 471 |
+
else:
|
| 472 |
+
self.conv_norm_out = None
|
| 473 |
+
self.conv_act = None
|
| 474 |
+
|
| 475 |
+
conv_out_padding = (conv_out_kernel - 1) // 2
|
| 476 |
+
self.conv_out = nn.Conv2d(
|
| 477 |
+
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
|
| 481 |
+
|
| 482 |
+
def _check_config(
|
| 483 |
+
self,
|
| 484 |
+
down_block_types: Tuple[str],
|
| 485 |
+
up_block_types: Tuple[str],
|
| 486 |
+
only_cross_attention: Union[bool, Tuple[bool]],
|
| 487 |
+
block_out_channels: Tuple[int],
|
| 488 |
+
layers_per_block: Union[int, Tuple[int]],
|
| 489 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
| 490 |
+
transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
|
| 491 |
+
reverse_transformer_layers_per_block: bool,
|
| 492 |
+
attention_head_dim: int,
|
| 493 |
+
num_attention_heads: Optional[Union[int, Tuple[int]]],
|
| 494 |
+
):
|
| 495 |
+
if len(down_block_types) != len(up_block_types):
|
| 496 |
+
raise ValueError(
|
| 497 |
+
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if len(block_out_channels) != len(down_block_types):
|
| 501 |
+
raise ValueError(
|
| 502 |
+
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
| 506 |
+
raise ValueError(
|
| 507 |
+
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
| 511 |
+
raise ValueError(
|
| 512 |
+
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
|
| 516 |
+
raise ValueError(
|
| 517 |
+
f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
|
| 521 |
+
raise ValueError(
|
| 522 |
+
f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
|
| 523 |
+
)
|
| 524 |
+
|
| 525 |
+
if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
|
| 526 |
+
raise ValueError(
|
| 527 |
+
f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
|
| 528 |
+
)
|
| 529 |
+
if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
|
| 530 |
+
for layer_number_per_block in transformer_layers_per_block:
|
| 531 |
+
if isinstance(layer_number_per_block, list):
|
| 532 |
+
raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
|
| 533 |
+
|
| 534 |
+
def _set_time_proj(
|
| 535 |
+
self,
|
| 536 |
+
time_embedding_type: str,
|
| 537 |
+
block_out_channels: int,
|
| 538 |
+
flip_sin_to_cos: bool,
|
| 539 |
+
freq_shift: float,
|
| 540 |
+
time_embedding_dim: int,
|
| 541 |
+
) -> Tuple[int, int]:
|
| 542 |
+
if time_embedding_type == "fourier":
|
| 543 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
|
| 544 |
+
if time_embed_dim % 2 != 0:
|
| 545 |
+
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
|
| 546 |
+
self.time_proj = GaussianFourierProjection(
|
| 547 |
+
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
|
| 548 |
+
)
|
| 549 |
+
timestep_input_dim = time_embed_dim
|
| 550 |
+
elif time_embedding_type == "positional":
|
| 551 |
+
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
|
| 552 |
+
|
| 553 |
+
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
| 554 |
+
timestep_input_dim = block_out_channels[0]
|
| 555 |
+
else:
|
| 556 |
+
raise ValueError(
|
| 557 |
+
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
return time_embed_dim, timestep_input_dim
|
| 561 |
+
|
| 562 |
+
def _set_encoder_hid_proj(
|
| 563 |
+
self,
|
| 564 |
+
encoder_hid_dim_type: Optional[str],
|
| 565 |
+
cross_attention_dim: Union[int, Tuple[int]],
|
| 566 |
+
encoder_hid_dim: Optional[int],
|
| 567 |
+
):
|
| 568 |
+
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
| 569 |
+
encoder_hid_dim_type = "text_proj"
|
| 570 |
+
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
| 571 |
+
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
| 572 |
+
|
| 573 |
+
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
| 574 |
+
raise ValueError(
|
| 575 |
+
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
| 576 |
+
)
|
| 577 |
+
|
| 578 |
+
if encoder_hid_dim_type == "text_proj":
|
| 579 |
+
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
| 580 |
+
elif encoder_hid_dim_type == "text_image_proj":
|
| 581 |
+
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 582 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 583 |
+
# case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
|
| 584 |
+
self.encoder_hid_proj = TextImageProjection(
|
| 585 |
+
text_embed_dim=encoder_hid_dim,
|
| 586 |
+
image_embed_dim=cross_attention_dim,
|
| 587 |
+
cross_attention_dim=cross_attention_dim,
|
| 588 |
+
)
|
| 589 |
+
elif encoder_hid_dim_type == "image_proj":
|
| 590 |
+
# Kandinsky 2.2
|
| 591 |
+
self.encoder_hid_proj = ImageProjection(
|
| 592 |
+
image_embed_dim=encoder_hid_dim,
|
| 593 |
+
cross_attention_dim=cross_attention_dim,
|
| 594 |
+
)
|
| 595 |
+
elif encoder_hid_dim_type is not None:
|
| 596 |
+
raise ValueError(
|
| 597 |
+
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
| 598 |
+
)
|
| 599 |
+
else:
|
| 600 |
+
self.encoder_hid_proj = None
|
| 601 |
+
|
| 602 |
+
def _set_class_embedding(
|
| 603 |
+
self,
|
| 604 |
+
class_embed_type: Optional[str],
|
| 605 |
+
act_fn: str,
|
| 606 |
+
num_class_embeds: Optional[int],
|
| 607 |
+
projection_class_embeddings_input_dim: Optional[int],
|
| 608 |
+
time_embed_dim: int,
|
| 609 |
+
timestep_input_dim: int,
|
| 610 |
+
):
|
| 611 |
+
if class_embed_type is None and num_class_embeds is not None:
|
| 612 |
+
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
| 613 |
+
elif class_embed_type == "timestep":
|
| 614 |
+
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
|
| 615 |
+
elif class_embed_type == "identity":
|
| 616 |
+
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
| 617 |
+
elif class_embed_type == "projection":
|
| 618 |
+
if projection_class_embeddings_input_dim is None:
|
| 619 |
+
raise ValueError(
|
| 620 |
+
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
| 621 |
+
)
|
| 622 |
+
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
| 623 |
+
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
| 624 |
+
# 2. it projects from an arbitrary input dimension.
|
| 625 |
+
#
|
| 626 |
+
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
| 627 |
+
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
| 628 |
+
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
| 629 |
+
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 630 |
+
elif class_embed_type == "simple_projection":
|
| 631 |
+
if projection_class_embeddings_input_dim is None:
|
| 632 |
+
raise ValueError(
|
| 633 |
+
"`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
|
| 634 |
+
)
|
| 635 |
+
self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
|
| 636 |
+
else:
|
| 637 |
+
self.class_embedding = None
|
| 638 |
+
|
| 639 |
+
def _set_add_embedding(
|
| 640 |
+
self,
|
| 641 |
+
addition_embed_type: str,
|
| 642 |
+
addition_embed_type_num_heads: int,
|
| 643 |
+
addition_time_embed_dim: Optional[int],
|
| 644 |
+
flip_sin_to_cos: bool,
|
| 645 |
+
freq_shift: float,
|
| 646 |
+
cross_attention_dim: Optional[int],
|
| 647 |
+
encoder_hid_dim: Optional[int],
|
| 648 |
+
projection_class_embeddings_input_dim: Optional[int],
|
| 649 |
+
time_embed_dim: int,
|
| 650 |
+
):
|
| 651 |
+
if addition_embed_type == "text":
|
| 652 |
+
if encoder_hid_dim is not None:
|
| 653 |
+
text_time_embedding_from_dim = encoder_hid_dim
|
| 654 |
+
else:
|
| 655 |
+
text_time_embedding_from_dim = cross_attention_dim
|
| 656 |
+
|
| 657 |
+
self.add_embedding = TextTimeEmbedding(
|
| 658 |
+
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
| 659 |
+
)
|
| 660 |
+
elif addition_embed_type == "text_image":
|
| 661 |
+
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
| 662 |
+
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
| 663 |
+
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
|
| 664 |
+
self.add_embedding = TextImageTimeEmbedding(
|
| 665 |
+
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
| 666 |
+
)
|
| 667 |
+
elif addition_embed_type == "text_time":
|
| 668 |
+
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
| 669 |
+
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
| 670 |
+
elif addition_embed_type == "image":
|
| 671 |
+
# Kandinsky 2.2
|
| 672 |
+
self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
| 673 |
+
elif addition_embed_type == "image_hint":
|
| 674 |
+
# Kandinsky 2.2 ControlNet
|
| 675 |
+
self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
|
| 676 |
+
elif addition_embed_type is not None:
|
| 677 |
+
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
| 678 |
+
|
| 679 |
+
def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
|
| 680 |
+
if attention_type in ["gated", "gated-text-image"]:
|
| 681 |
+
positive_len = 768
|
| 682 |
+
if isinstance(cross_attention_dim, int):
|
| 683 |
+
positive_len = cross_attention_dim
|
| 684 |
+
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
|
| 685 |
+
positive_len = cross_attention_dim[0]
|
| 686 |
+
|
| 687 |
+
feature_type = "text-only" if attention_type == "gated" else "text-image"
|
| 688 |
+
self.position_net = GLIGENTextBoundingboxProjection(
|
| 689 |
+
positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
@property
|
| 693 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 694 |
+
r"""
|
| 695 |
+
Returns:
|
| 696 |
+
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
| 697 |
+
indexed by its weight name.
|
| 698 |
+
"""
|
| 699 |
+
# set recursively
|
| 700 |
+
processors = {}
|
| 701 |
+
|
| 702 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 703 |
+
if hasattr(module, "get_processor"):
|
| 704 |
+
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
| 705 |
+
|
| 706 |
+
for sub_name, child in module.named_children():
|
| 707 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 708 |
+
|
| 709 |
+
return processors
|
| 710 |
+
|
| 711 |
+
for name, module in self.named_children():
|
| 712 |
+
fn_recursive_add_processors(name, module, processors)
|
| 713 |
+
|
| 714 |
+
return processors
|
| 715 |
+
|
| 716 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 717 |
+
r"""
|
| 718 |
+
Sets the attention processor to use to compute attention.
|
| 719 |
+
|
| 720 |
+
Parameters:
|
| 721 |
+
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
| 722 |
+
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
| 723 |
+
for **all** `Attention` layers.
|
| 724 |
+
|
| 725 |
+
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
| 726 |
+
processor. This is strongly recommended when setting trainable attention processors.
|
| 727 |
+
|
| 728 |
+
"""
|
| 729 |
+
count = len(self.attn_processors.keys())
|
| 730 |
+
|
| 731 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 732 |
+
raise ValueError(
|
| 733 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 734 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 735 |
+
)
|
| 736 |
+
|
| 737 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 738 |
+
if hasattr(module, "set_processor"):
|
| 739 |
+
if not isinstance(processor, dict):
|
| 740 |
+
module.set_processor(processor)
|
| 741 |
+
else:
|
| 742 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 743 |
+
|
| 744 |
+
for sub_name, child in module.named_children():
|
| 745 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 746 |
+
|
| 747 |
+
for name, module in self.named_children():
|
| 748 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 749 |
+
|
| 750 |
+
def set_default_attn_processor(self):
|
| 751 |
+
"""
|
| 752 |
+
Disables custom attention processors and sets the default attention implementation.
|
| 753 |
+
"""
|
| 754 |
+
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 755 |
+
processor = AttnAddedKVProcessor()
|
| 756 |
+
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
| 757 |
+
processor = AttnProcessor()
|
| 758 |
+
else:
|
| 759 |
+
raise ValueError(
|
| 760 |
+
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
| 761 |
+
)
|
| 762 |
+
|
| 763 |
+
self.set_attn_processor(processor)
|
| 764 |
+
|
| 765 |
+
def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
|
| 766 |
+
r"""
|
| 767 |
+
Enable sliced attention computation.
|
| 768 |
+
|
| 769 |
+
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
| 770 |
+
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
| 771 |
+
|
| 772 |
+
Args:
|
| 773 |
+
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
| 774 |
+
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
| 775 |
+
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
| 776 |
+
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
| 777 |
+
must be a multiple of `slice_size`.
|
| 778 |
+
"""
|
| 779 |
+
sliceable_head_dims = []
|
| 780 |
+
|
| 781 |
+
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
| 782 |
+
if hasattr(module, "set_attention_slice"):
|
| 783 |
+
sliceable_head_dims.append(module.sliceable_head_dim)
|
| 784 |
+
|
| 785 |
+
for child in module.children():
|
| 786 |
+
fn_recursive_retrieve_sliceable_dims(child)
|
| 787 |
+
|
| 788 |
+
# retrieve number of attention layers
|
| 789 |
+
for module in self.children():
|
| 790 |
+
fn_recursive_retrieve_sliceable_dims(module)
|
| 791 |
+
|
| 792 |
+
num_sliceable_layers = len(sliceable_head_dims)
|
| 793 |
+
|
| 794 |
+
if slice_size == "auto":
|
| 795 |
+
# half the attention head size is usually a good trade-off between
|
| 796 |
+
# speed and memory
|
| 797 |
+
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
| 798 |
+
elif slice_size == "max":
|
| 799 |
+
# make smallest slice possible
|
| 800 |
+
slice_size = num_sliceable_layers * [1]
|
| 801 |
+
|
| 802 |
+
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
| 803 |
+
|
| 804 |
+
if len(slice_size) != len(sliceable_head_dims):
|
| 805 |
+
raise ValueError(
|
| 806 |
+
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
| 807 |
+
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
for i in range(len(slice_size)):
|
| 811 |
+
size = slice_size[i]
|
| 812 |
+
dim = sliceable_head_dims[i]
|
| 813 |
+
if size is not None and size > dim:
|
| 814 |
+
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
| 815 |
+
|
| 816 |
+
# Recursively walk through all the children.
|
| 817 |
+
# Any children which exposes the set_attention_slice method
|
| 818 |
+
# gets the message
|
| 819 |
+
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
| 820 |
+
if hasattr(module, "set_attention_slice"):
|
| 821 |
+
module.set_attention_slice(slice_size.pop())
|
| 822 |
+
|
| 823 |
+
for child in module.children():
|
| 824 |
+
fn_recursive_set_attention_slice(child, slice_size)
|
| 825 |
+
|
| 826 |
+
reversed_slice_size = list(reversed(slice_size))
|
| 827 |
+
for module in self.children():
|
| 828 |
+
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
| 829 |
+
|
| 830 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 831 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 832 |
+
module.gradient_checkpointing = value
|
| 833 |
+
|
| 834 |
+
def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
|
| 835 |
+
r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
|
| 836 |
+
|
| 837 |
+
The suffixes after the scaling factors represent the stage blocks where they are being applied.
|
| 838 |
+
|
| 839 |
+
Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
|
| 840 |
+
are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
|
| 841 |
+
|
| 842 |
+
Args:
|
| 843 |
+
s1 (`float`):
|
| 844 |
+
Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
|
| 845 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
| 846 |
+
s2 (`float`):
|
| 847 |
+
Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
|
| 848 |
+
mitigate the "oversmoothing effect" in the enhanced denoising process.
|
| 849 |
+
b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
|
| 850 |
+
b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
|
| 851 |
+
"""
|
| 852 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 853 |
+
setattr(upsample_block, "s1", s1)
|
| 854 |
+
setattr(upsample_block, "s2", s2)
|
| 855 |
+
setattr(upsample_block, "b1", b1)
|
| 856 |
+
setattr(upsample_block, "b2", b2)
|
| 857 |
+
|
| 858 |
+
def disable_freeu(self):
|
| 859 |
+
"""Disables the FreeU mechanism."""
|
| 860 |
+
freeu_keys = {"s1", "s2", "b1", "b2"}
|
| 861 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 862 |
+
for k in freeu_keys:
|
| 863 |
+
if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
|
| 864 |
+
setattr(upsample_block, k, None)
|
| 865 |
+
|
| 866 |
+
def fuse_qkv_projections(self):
|
| 867 |
+
"""
|
| 868 |
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
| 869 |
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
| 870 |
+
|
| 871 |
+
<Tip warning={true}>
|
| 872 |
+
|
| 873 |
+
This API is 🧪 experimental.
|
| 874 |
+
|
| 875 |
+
</Tip>
|
| 876 |
+
"""
|
| 877 |
+
self.original_attn_processors = None
|
| 878 |
+
|
| 879 |
+
for _, attn_processor in self.attn_processors.items():
|
| 880 |
+
if "Added" in str(attn_processor.__class__.__name__):
|
| 881 |
+
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
|
| 882 |
+
|
| 883 |
+
self.original_attn_processors = self.attn_processors
|
| 884 |
+
|
| 885 |
+
for module in self.modules():
|
| 886 |
+
if isinstance(module, Attention):
|
| 887 |
+
module.fuse_projections(fuse=True)
|
| 888 |
+
|
| 889 |
+
def unfuse_qkv_projections(self):
|
| 890 |
+
"""Disables the fused QKV projection if enabled.
|
| 891 |
+
|
| 892 |
+
<Tip warning={true}>
|
| 893 |
+
|
| 894 |
+
This API is 🧪 experimental.
|
| 895 |
+
|
| 896 |
+
</Tip>
|
| 897 |
+
|
| 898 |
+
"""
|
| 899 |
+
if self.original_attn_processors is not None:
|
| 900 |
+
self.set_attn_processor(self.original_attn_processors)
|
| 901 |
+
|
| 902 |
+
def unload_lora(self):
|
| 903 |
+
"""Unloads LoRA weights."""
|
| 904 |
+
deprecate(
|
| 905 |
+
"unload_lora",
|
| 906 |
+
"0.28.0",
|
| 907 |
+
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
|
| 908 |
+
)
|
| 909 |
+
for module in self.modules():
|
| 910 |
+
if hasattr(module, "set_lora_layer"):
|
| 911 |
+
module.set_lora_layer(None)
|
| 912 |
+
|
| 913 |
+
def get_time_embed(
|
| 914 |
+
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
|
| 915 |
+
) -> Optional[torch.Tensor]:
|
| 916 |
+
timesteps = timestep
|
| 917 |
+
if not torch.is_tensor(timesteps):
|
| 918 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 919 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 920 |
+
is_mps = sample.device.type == "mps"
|
| 921 |
+
if isinstance(timestep, float):
|
| 922 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 923 |
+
else:
|
| 924 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 925 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 926 |
+
elif len(timesteps.shape) == 0:
|
| 927 |
+
timesteps = timesteps[None].to(sample.device)
|
| 928 |
+
|
| 929 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 930 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 931 |
+
|
| 932 |
+
t_emb = self.time_proj(timesteps)
|
| 933 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 934 |
+
# but time_embedding might actually be running in fp16. so we need to cast here.
|
| 935 |
+
# there might be better ways to encapsulate this.
|
| 936 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 937 |
+
return t_emb
|
| 938 |
+
|
| 939 |
+
def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
|
| 940 |
+
class_emb = None
|
| 941 |
+
if self.class_embedding is not None:
|
| 942 |
+
if class_labels is None:
|
| 943 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 944 |
+
|
| 945 |
+
if self.config.class_embed_type == "timestep":
|
| 946 |
+
class_labels = self.time_proj(class_labels)
|
| 947 |
+
|
| 948 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 949 |
+
# there might be better ways to encapsulate this.
|
| 950 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
| 951 |
+
|
| 952 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
| 953 |
+
return class_emb
|
| 954 |
+
|
| 955 |
+
def get_aug_embed(
|
| 956 |
+
self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
| 957 |
+
) -> Optional[torch.Tensor]:
|
| 958 |
+
aug_emb = None
|
| 959 |
+
if self.config.addition_embed_type == "text":
|
| 960 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 961 |
+
elif self.config.addition_embed_type == "text_image":
|
| 962 |
+
# Kandinsky 2.1 - style
|
| 963 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 964 |
+
raise ValueError(
|
| 965 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
| 966 |
+
)
|
| 967 |
+
|
| 968 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 969 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
| 970 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
| 971 |
+
elif self.config.addition_embed_type == "text_time":
|
| 972 |
+
# SDXL - style
|
| 973 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 974 |
+
raise ValueError(
|
| 975 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 976 |
+
)
|
| 977 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 978 |
+
if "time_ids" not in added_cond_kwargs:
|
| 979 |
+
raise ValueError(
|
| 980 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 981 |
+
)
|
| 982 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 983 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 984 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 985 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 986 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 987 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 988 |
+
elif self.config.addition_embed_type == "image":
|
| 989 |
+
# Kandinsky 2.2 - style
|
| 990 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 991 |
+
raise ValueError(
|
| 992 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
| 993 |
+
)
|
| 994 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 995 |
+
aug_emb = self.add_embedding(image_embs)
|
| 996 |
+
elif self.config.addition_embed_type == "image_hint":
|
| 997 |
+
# Kandinsky 2.2 - style
|
| 998 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
| 999 |
+
raise ValueError(
|
| 1000 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
| 1001 |
+
)
|
| 1002 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 1003 |
+
hint = added_cond_kwargs.get("hint")
|
| 1004 |
+
aug_emb = self.add_embedding(image_embs, hint)
|
| 1005 |
+
return aug_emb
|
| 1006 |
+
|
| 1007 |
+
def process_encoder_hidden_states(
|
| 1008 |
+
self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
|
| 1009 |
+
) -> torch.Tensor:
|
| 1010 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
| 1011 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
| 1012 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
| 1013 |
+
# Kandinsky 2.1 - style
|
| 1014 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 1015 |
+
raise ValueError(
|
| 1016 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 1017 |
+
)
|
| 1018 |
+
|
| 1019 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 1020 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
| 1021 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
| 1022 |
+
# Kandinsky 2.2 - style
|
| 1023 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 1024 |
+
raise ValueError(
|
| 1025 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 1026 |
+
)
|
| 1027 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 1028 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
| 1029 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
|
| 1030 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 1031 |
+
raise ValueError(
|
| 1032 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 1033 |
+
)
|
| 1034 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 1035 |
+
image_embeds = self.encoder_hid_proj(image_embeds)
|
| 1036 |
+
encoder_hidden_states = (encoder_hidden_states, image_embeds)
|
| 1037 |
+
return encoder_hidden_states
|
| 1038 |
+
|
| 1039 |
+
def forward(
|
| 1040 |
+
self,
|
| 1041 |
+
sample: torch.FloatTensor,
|
| 1042 |
+
timestep: Union[torch.Tensor, float, int],
|
| 1043 |
+
encoder_hidden_states: torch.Tensor,
|
| 1044 |
+
class_labels: Optional[torch.Tensor] = None,
|
| 1045 |
+
timestep_cond: Optional[torch.Tensor] = None,
|
| 1046 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1047 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 1048 |
+
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
| 1049 |
+
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 1050 |
+
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
| 1051 |
+
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
| 1052 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 1053 |
+
return_dict: bool = True,
|
| 1054 |
+
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
| 1055 |
+
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
| 1056 |
+
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
| 1057 |
+
) -> Union[UNet2DConditionOutput, Tuple]:
|
| 1058 |
+
r"""
|
| 1059 |
+
The [`UNet2DConditionModel`] forward method.
|
| 1060 |
+
|
| 1061 |
+
Args:
|
| 1062 |
+
sample (`torch.FloatTensor`):
|
| 1063 |
+
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
| 1064 |
+
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
| 1065 |
+
encoder_hidden_states (`torch.FloatTensor`):
|
| 1066 |
+
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
| 1067 |
+
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
| 1068 |
+
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
| 1069 |
+
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
| 1070 |
+
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
| 1071 |
+
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
| 1072 |
+
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
| 1073 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 1074 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 1075 |
+
negative values to the attention scores corresponding to "discard" tokens.
|
| 1076 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 1077 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 1078 |
+
`self.processor` in
|
| 1079 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 1080 |
+
added_cond_kwargs: (`dict`, *optional*):
|
| 1081 |
+
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
| 1082 |
+
are passed along to the UNet blocks.
|
| 1083 |
+
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
| 1084 |
+
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
| 1085 |
+
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
| 1086 |
+
A tensor that if specified is added to the residual of the middle unet block.
|
| 1087 |
+
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
| 1088 |
+
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
| 1089 |
+
encoder_attention_mask (`torch.Tensor`):
|
| 1090 |
+
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
| 1091 |
+
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
| 1092 |
+
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
| 1093 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1094 |
+
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
| 1095 |
+
tuple.
|
| 1096 |
+
|
| 1097 |
+
Returns:
|
| 1098 |
+
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
| 1099 |
+
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
|
| 1100 |
+
otherwise a `tuple` is returned where the first element is the sample tensor.
|
| 1101 |
+
"""
|
| 1102 |
+
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
| 1103 |
+
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
| 1104 |
+
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
| 1105 |
+
# on the fly if necessary.
|
| 1106 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 1107 |
+
|
| 1108 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 1109 |
+
forward_upsample_size = False
|
| 1110 |
+
upsample_size = None
|
| 1111 |
+
|
| 1112 |
+
for dim in sample.shape[-2:]:
|
| 1113 |
+
if dim % default_overall_up_factor != 0:
|
| 1114 |
+
# Forward upsample size to force interpolation output size.
|
| 1115 |
+
forward_upsample_size = True
|
| 1116 |
+
break
|
| 1117 |
+
|
| 1118 |
+
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
| 1119 |
+
# expects mask of shape:
|
| 1120 |
+
# [batch, key_tokens]
|
| 1121 |
+
# adds singleton query_tokens dimension:
|
| 1122 |
+
# [batch, 1, key_tokens]
|
| 1123 |
+
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
| 1124 |
+
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
| 1125 |
+
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
| 1126 |
+
if attention_mask is not None:
|
| 1127 |
+
# assume that mask is expressed as:
|
| 1128 |
+
# (1 = keep, 0 = discard)
|
| 1129 |
+
# convert mask into a bias that can be added to attention scores:
|
| 1130 |
+
# (keep = +0, discard = -10000.0)
|
| 1131 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 1132 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 1133 |
+
|
| 1134 |
+
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
| 1135 |
+
if encoder_attention_mask is not None:
|
| 1136 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
| 1137 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 1138 |
+
|
| 1139 |
+
# 0. center input if necessary
|
| 1140 |
+
if self.config.center_input_sample:
|
| 1141 |
+
sample = 2 * sample - 1.0
|
| 1142 |
+
|
| 1143 |
+
# 1. time
|
| 1144 |
+
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
| 1145 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 1146 |
+
aug_emb = None
|
| 1147 |
+
|
| 1148 |
+
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
| 1149 |
+
if class_emb is not None:
|
| 1150 |
+
if self.config.class_embeddings_concat:
|
| 1151 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
| 1152 |
+
else:
|
| 1153 |
+
emb = emb + class_emb
|
| 1154 |
+
|
| 1155 |
+
aug_emb = self.get_aug_embed(
|
| 1156 |
+
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
| 1157 |
+
)
|
| 1158 |
+
if self.config.addition_embed_type == "image_hint":
|
| 1159 |
+
aug_emb, hint = aug_emb
|
| 1160 |
+
sample = torch.cat([sample, hint], dim=1)
|
| 1161 |
+
|
| 1162 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 1163 |
+
|
| 1164 |
+
if self.time_embed_act is not None:
|
| 1165 |
+
emb = self.time_embed_act(emb)
|
| 1166 |
+
|
| 1167 |
+
encoder_hidden_states = self.process_encoder_hidden_states(
|
| 1168 |
+
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
| 1169 |
+
)
|
| 1170 |
+
|
| 1171 |
+
# 2. pre-process
|
| 1172 |
+
sample = self.conv_in(sample)
|
| 1173 |
+
|
| 1174 |
+
# 2.5 GLIGEN position net
|
| 1175 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
| 1176 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
| 1177 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
| 1178 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
| 1179 |
+
|
| 1180 |
+
# 3. down
|
| 1181 |
+
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
|
| 1182 |
+
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
|
| 1183 |
+
if cross_attention_kwargs is not None:
|
| 1184 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
| 1185 |
+
lora_scale = cross_attention_kwargs.pop("scale", 1.0)
|
| 1186 |
+
else:
|
| 1187 |
+
lora_scale = 1.0
|
| 1188 |
+
|
| 1189 |
+
if USE_PEFT_BACKEND:
|
| 1190 |
+
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
| 1191 |
+
scale_lora_layers(self, lora_scale)
|
| 1192 |
+
|
| 1193 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
| 1194 |
+
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
| 1195 |
+
is_adapter = down_intrablock_additional_residuals is not None
|
| 1196 |
+
# maintain backward compatibility for legacy usage, where
|
| 1197 |
+
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
| 1198 |
+
# but can only use one or the other
|
| 1199 |
+
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
| 1200 |
+
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
| 1201 |
+
deprecate(
|
| 1202 |
+
"T2I should not use down_block_additional_residuals",
|
| 1203 |
+
"1.3.0",
|
| 1204 |
+
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
| 1205 |
+
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
| 1206 |
+
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
| 1207 |
+
standard_warn=False,
|
| 1208 |
+
)
|
| 1209 |
+
down_intrablock_additional_residuals = down_block_additional_residuals
|
| 1210 |
+
is_adapter = True
|
| 1211 |
+
|
| 1212 |
+
down_block_res_samples = (sample,)
|
| 1213 |
+
|
| 1214 |
+
if is_brushnet:
|
| 1215 |
+
sample = sample + down_block_add_samples.pop(0)
|
| 1216 |
+
|
| 1217 |
+
for downsample_block in self.down_blocks:
|
| 1218 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 1219 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
| 1220 |
+
additional_residuals = {}
|
| 1221 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
| 1222 |
+
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
| 1223 |
+
|
| 1224 |
+
i = len(down_block_add_samples)
|
| 1225 |
+
|
| 1226 |
+
if is_brushnet and len(down_block_add_samples)>0:
|
| 1227 |
+
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
| 1228 |
+
for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
|
| 1229 |
+
|
| 1230 |
+
sample, res_samples = downsample_block(
|
| 1231 |
+
hidden_states=sample,
|
| 1232 |
+
temb=emb,
|
| 1233 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1234 |
+
attention_mask=attention_mask,
|
| 1235 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 1236 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1237 |
+
**additional_residuals,
|
| 1238 |
+
)
|
| 1239 |
+
else:
|
| 1240 |
+
additional_residuals = {}
|
| 1241 |
+
|
| 1242 |
+
i = len(down_block_add_samples)
|
| 1243 |
+
|
| 1244 |
+
if is_brushnet and len(down_block_add_samples)>0:
|
| 1245 |
+
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
| 1246 |
+
for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
|
| 1247 |
+
|
| 1248 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
|
| 1249 |
+
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
| 1250 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
| 1251 |
+
|
| 1252 |
+
down_block_res_samples += res_samples
|
| 1253 |
+
|
| 1254 |
+
if is_controlnet:
|
| 1255 |
+
new_down_block_res_samples = ()
|
| 1256 |
+
|
| 1257 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 1258 |
+
down_block_res_samples, down_block_additional_residuals
|
| 1259 |
+
):
|
| 1260 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 1261 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
| 1262 |
+
|
| 1263 |
+
down_block_res_samples = new_down_block_res_samples
|
| 1264 |
+
|
| 1265 |
+
# 4. mid
|
| 1266 |
+
if self.mid_block is not None:
|
| 1267 |
+
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
| 1268 |
+
sample = self.mid_block(
|
| 1269 |
+
sample,
|
| 1270 |
+
emb,
|
| 1271 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1272 |
+
attention_mask=attention_mask,
|
| 1273 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 1274 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1275 |
+
)
|
| 1276 |
+
else:
|
| 1277 |
+
sample = self.mid_block(sample, emb)
|
| 1278 |
+
|
| 1279 |
+
# To support T2I-Adapter-XL
|
| 1280 |
+
if (
|
| 1281 |
+
is_adapter
|
| 1282 |
+
and len(down_intrablock_additional_residuals) > 0
|
| 1283 |
+
and sample.shape == down_intrablock_additional_residuals[0].shape
|
| 1284 |
+
):
|
| 1285 |
+
sample += down_intrablock_additional_residuals.pop(0)
|
| 1286 |
+
|
| 1287 |
+
if is_controlnet:
|
| 1288 |
+
sample = sample + mid_block_additional_residual
|
| 1289 |
+
|
| 1290 |
+
if is_brushnet:
|
| 1291 |
+
sample = sample + mid_block_add_sample
|
| 1292 |
+
|
| 1293 |
+
# 5. up
|
| 1294 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 1295 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 1296 |
+
|
| 1297 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 1298 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 1299 |
+
|
| 1300 |
+
# if we have not reached the final block and need to forward the
|
| 1301 |
+
# upsample size, we do it here
|
| 1302 |
+
if not is_final_block and forward_upsample_size:
|
| 1303 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 1304 |
+
|
| 1305 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 1306 |
+
additional_residuals = {}
|
| 1307 |
+
|
| 1308 |
+
i = len(up_block_add_samples)
|
| 1309 |
+
|
| 1310 |
+
if is_brushnet and len(up_block_add_samples)>0:
|
| 1311 |
+
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
| 1312 |
+
for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
|
| 1313 |
+
|
| 1314 |
+
sample = upsample_block(
|
| 1315 |
+
hidden_states=sample,
|
| 1316 |
+
temb=emb,
|
| 1317 |
+
res_hidden_states_tuple=res_samples,
|
| 1318 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 1319 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 1320 |
+
upsample_size=upsample_size,
|
| 1321 |
+
attention_mask=attention_mask,
|
| 1322 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 1323 |
+
**additional_residuals,
|
| 1324 |
+
)
|
| 1325 |
+
else:
|
| 1326 |
+
additional_residuals = {}
|
| 1327 |
+
|
| 1328 |
+
i = len(up_block_add_samples)
|
| 1329 |
+
|
| 1330 |
+
if is_brushnet and len(up_block_add_samples)>0:
|
| 1331 |
+
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
| 1332 |
+
for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
|
| 1333 |
+
|
| 1334 |
+
sample = upsample_block(
|
| 1335 |
+
hidden_states=sample,
|
| 1336 |
+
temb=emb,
|
| 1337 |
+
res_hidden_states_tuple=res_samples,
|
| 1338 |
+
upsample_size=upsample_size,
|
| 1339 |
+
**additional_residuals,
|
| 1340 |
+
)
|
| 1341 |
+
|
| 1342 |
+
# 6. post-process
|
| 1343 |
+
if self.conv_norm_out:
|
| 1344 |
+
sample = self.conv_norm_out(sample)
|
| 1345 |
+
sample = self.conv_act(sample)
|
| 1346 |
+
sample = self.conv_out(sample)
|
| 1347 |
+
|
| 1348 |
+
if USE_PEFT_BACKEND:
|
| 1349 |
+
# remove `lora_scale` from each PEFT layer
|
| 1350 |
+
unscale_lora_layers(self, lora_scale)
|
| 1351 |
+
|
| 1352 |
+
if not return_dict:
|
| 1353 |
+
return (sample,)
|
| 1354 |
+
|
| 1355 |
+
return UNet2DConditionOutput(sample=sample)
|
ComfyUI-BrushNet/brushnet_nodes.py
ADDED
|
@@ -0,0 +1,1080 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import types
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torchvision.transforms as T
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
| 9 |
+
|
| 10 |
+
#import sys
|
| 11 |
+
#from sys import platform
|
| 12 |
+
# Get the parent directory of 'comfy' and add it to the Python path
|
| 13 |
+
#comfy_parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
| 14 |
+
#sys.path.append(comfy_parent_dir)
|
| 15 |
+
|
| 16 |
+
import comfy
|
| 17 |
+
import folder_paths
|
| 18 |
+
|
| 19 |
+
from .model_patch import add_model_patch_option, patch_model_function_wrapper
|
| 20 |
+
|
| 21 |
+
from .brushnet.brushnet import BrushNetModel
|
| 22 |
+
from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
|
| 23 |
+
|
| 24 |
+
from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
|
| 25 |
+
|
| 26 |
+
current_directory = os.path.dirname(os.path.abspath(__file__))
|
| 27 |
+
brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
|
| 28 |
+
brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
|
| 29 |
+
powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
|
| 30 |
+
|
| 31 |
+
sd15_scaling_factor = 0.18215
|
| 32 |
+
sdxl_scaling_factor = 0.13025
|
| 33 |
+
|
| 34 |
+
ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
|
| 35 |
+
comfy.ldm.models.autoencoder.AutoencoderKL
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class BrushNetLoader:
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def INPUT_TYPES(s):
|
| 43 |
+
files, inpaint_path = get_files_with_extension('inpaint')
|
| 44 |
+
s.inpaint_path = inpaint_path
|
| 45 |
+
return {"required":
|
| 46 |
+
{
|
| 47 |
+
"brushnet": (files, ),
|
| 48 |
+
"dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
|
| 49 |
+
},
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
CATEGORY = "inpaint"
|
| 53 |
+
RETURN_TYPES = ("BRMODEL",)
|
| 54 |
+
RETURN_NAMES = ("brushnet",)
|
| 55 |
+
|
| 56 |
+
FUNCTION = "brushnet_loading"
|
| 57 |
+
|
| 58 |
+
def brushnet_loading(self, brushnet, dtype):
|
| 59 |
+
brushnet_file = os.path.join(self.inpaint_path, brushnet)
|
| 60 |
+
is_SDXL = False
|
| 61 |
+
is_PP = False
|
| 62 |
+
sd = comfy.utils.load_torch_file(brushnet_file)
|
| 63 |
+
brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
|
| 64 |
+
del sd
|
| 65 |
+
if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
|
| 66 |
+
is_SDXL = False
|
| 67 |
+
if keys == 322:
|
| 68 |
+
is_PP = False
|
| 69 |
+
print('BrushNet model type: SD1.5')
|
| 70 |
+
else:
|
| 71 |
+
is_PP = True
|
| 72 |
+
print('PowerPaint model type: SD1.5')
|
| 73 |
+
elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
|
| 74 |
+
print('BrushNet model type: Loading SDXL')
|
| 75 |
+
is_SDXL = True
|
| 76 |
+
is_PP = False
|
| 77 |
+
else:
|
| 78 |
+
raise Exception("Unknown BrushNet model")
|
| 79 |
+
|
| 80 |
+
with init_empty_weights():
|
| 81 |
+
if is_SDXL:
|
| 82 |
+
brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
|
| 83 |
+
brushnet_model = BrushNetModel.from_config(brushnet_config)
|
| 84 |
+
elif is_PP:
|
| 85 |
+
brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
|
| 86 |
+
brushnet_model = PowerPaintModel.from_config(brushnet_config)
|
| 87 |
+
else:
|
| 88 |
+
brushnet_config = BrushNetModel.load_config(brushnet_config_file)
|
| 89 |
+
brushnet_model = BrushNetModel.from_config(brushnet_config)
|
| 90 |
+
|
| 91 |
+
if is_PP:
|
| 92 |
+
print("PowerPaint model file:", brushnet_file)
|
| 93 |
+
else:
|
| 94 |
+
print("BrushNet model file:", brushnet_file)
|
| 95 |
+
|
| 96 |
+
if dtype == 'float16':
|
| 97 |
+
torch_dtype = torch.float16
|
| 98 |
+
elif dtype == 'bfloat16':
|
| 99 |
+
torch_dtype = torch.bfloat16
|
| 100 |
+
elif dtype == 'float32':
|
| 101 |
+
torch_dtype = torch.float32
|
| 102 |
+
else:
|
| 103 |
+
torch_dtype = torch.float64
|
| 104 |
+
|
| 105 |
+
brushnet_model = load_checkpoint_and_dispatch(
|
| 106 |
+
brushnet_model,
|
| 107 |
+
brushnet_file,
|
| 108 |
+
device_map="sequential",
|
| 109 |
+
max_memory=None,
|
| 110 |
+
offload_folder=None,
|
| 111 |
+
offload_state_dict=False,
|
| 112 |
+
dtype=torch_dtype,
|
| 113 |
+
force_hooks=False,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if is_PP:
|
| 117 |
+
print("PowerPaint model is loaded")
|
| 118 |
+
elif is_SDXL:
|
| 119 |
+
print("BrushNet SDXL model is loaded")
|
| 120 |
+
else:
|
| 121 |
+
print("BrushNet SD1.5 model is loaded")
|
| 122 |
+
|
| 123 |
+
return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class PowerPaintCLIPLoader:
|
| 127 |
+
|
| 128 |
+
@classmethod
|
| 129 |
+
def INPUT_TYPES(s):
|
| 130 |
+
inpaint_files, inpaint_path = get_files_with_extension('inpaint', ['bin'])
|
| 131 |
+
s.inpaint_path = inpaint_path
|
| 132 |
+
clip_files, clip_path = get_files_with_extension('clip')
|
| 133 |
+
s.clip_path = clip_path
|
| 134 |
+
return {"required":
|
| 135 |
+
{
|
| 136 |
+
"base": (clip_files, ),
|
| 137 |
+
"powerpaint": (inpaint_files, ),
|
| 138 |
+
},
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
CATEGORY = "inpaint"
|
| 142 |
+
RETURN_TYPES = ("CLIP",)
|
| 143 |
+
RETURN_NAMES = ("clip",)
|
| 144 |
+
|
| 145 |
+
FUNCTION = "ppclip_loading"
|
| 146 |
+
|
| 147 |
+
def ppclip_loading(self, base, powerpaint):
|
| 148 |
+
base_CLIP_file = os.path.join(self.clip_path, base)
|
| 149 |
+
pp_CLIP_file = os.path.join(self.inpaint_path, powerpaint)
|
| 150 |
+
|
| 151 |
+
pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
|
| 152 |
+
|
| 153 |
+
print('PowerPaint base CLIP file: ', base_CLIP_file)
|
| 154 |
+
|
| 155 |
+
pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
|
| 156 |
+
pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
|
| 157 |
+
|
| 158 |
+
add_tokens(
|
| 159 |
+
tokenizer = pp_tokenizer,
|
| 160 |
+
text_encoder = pp_text_encoder,
|
| 161 |
+
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
|
| 162 |
+
initialize_tokens = ["a", "a", "a"],
|
| 163 |
+
num_vectors_per_token = 10,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
|
| 167 |
+
|
| 168 |
+
print('PowerPaint CLIP file: ', pp_CLIP_file)
|
| 169 |
+
|
| 170 |
+
pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
|
| 171 |
+
pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
|
| 172 |
+
|
| 173 |
+
return (pp_clip,)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class PowerPaint:
|
| 177 |
+
|
| 178 |
+
@classmethod
|
| 179 |
+
def INPUT_TYPES(s):
|
| 180 |
+
return {"required":
|
| 181 |
+
{
|
| 182 |
+
"model": ("MODEL",),
|
| 183 |
+
"vae": ("VAE", ),
|
| 184 |
+
"image": ("IMAGE",),
|
| 185 |
+
"mask": ("MASK",),
|
| 186 |
+
"powerpaint": ("BRMODEL", ),
|
| 187 |
+
"clip": ("CLIP", ),
|
| 188 |
+
"positive": ("CONDITIONING", ),
|
| 189 |
+
"negative": ("CONDITIONING", ),
|
| 190 |
+
"fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
|
| 191 |
+
"function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
|
| 192 |
+
"scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
| 193 |
+
"start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
| 194 |
+
"end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
| 195 |
+
},
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
CATEGORY = "inpaint"
|
| 199 |
+
RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
|
| 200 |
+
RETURN_NAMES = ("model","positive","negative","latent",)
|
| 201 |
+
|
| 202 |
+
FUNCTION = "model_update"
|
| 203 |
+
|
| 204 |
+
def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at):
|
| 205 |
+
|
| 206 |
+
is_SDXL, is_PP = check_compatibilty(model, powerpaint)
|
| 207 |
+
if not is_PP:
|
| 208 |
+
raise Exception("BrushNet model was loaded, please use BrushNet node")
|
| 209 |
+
|
| 210 |
+
# Make a copy of the model so that we're not patching it everywhere in the workflow.
|
| 211 |
+
model = model.clone()
|
| 212 |
+
|
| 213 |
+
# prepare image and mask
|
| 214 |
+
# no batches for original image and mask
|
| 215 |
+
masked_image, mask = prepare_image(image, mask)
|
| 216 |
+
|
| 217 |
+
batch = masked_image.shape[0]
|
| 218 |
+
#width = masked_image.shape[2]
|
| 219 |
+
#height = masked_image.shape[1]
|
| 220 |
+
|
| 221 |
+
if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
|
| 222 |
+
scaling_factor = model.model.model_config.latent_format.scale_factor
|
| 223 |
+
else:
|
| 224 |
+
scaling_factor = sd15_scaling_factor
|
| 225 |
+
|
| 226 |
+
torch_dtype = powerpaint['dtype']
|
| 227 |
+
|
| 228 |
+
# prepare conditioning latents
|
| 229 |
+
conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
|
| 230 |
+
conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
|
| 231 |
+
conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
|
| 232 |
+
|
| 233 |
+
# prepare embeddings
|
| 234 |
+
|
| 235 |
+
if function == "object removal":
|
| 236 |
+
promptA = "P_ctxt"
|
| 237 |
+
promptB = "P_ctxt"
|
| 238 |
+
negative_promptA = "P_obj"
|
| 239 |
+
negative_promptB = "P_obj"
|
| 240 |
+
print('You should add to positive prompt: "empty scene blur"')
|
| 241 |
+
#positive = positive + " empty scene blur"
|
| 242 |
+
elif function == "context aware":
|
| 243 |
+
promptA = "P_ctxt"
|
| 244 |
+
promptB = "P_ctxt"
|
| 245 |
+
negative_promptA = ""
|
| 246 |
+
negative_promptB = ""
|
| 247 |
+
#positive = positive + " empty scene"
|
| 248 |
+
print('You should add to positive prompt: "empty scene"')
|
| 249 |
+
elif function == "shape guided":
|
| 250 |
+
promptA = "P_shape"
|
| 251 |
+
promptB = "P_ctxt"
|
| 252 |
+
negative_promptA = "P_shape"
|
| 253 |
+
negative_promptB = "P_ctxt"
|
| 254 |
+
elif function == "image outpainting":
|
| 255 |
+
promptA = "P_ctxt"
|
| 256 |
+
promptB = "P_ctxt"
|
| 257 |
+
negative_promptA = "P_obj"
|
| 258 |
+
negative_promptB = "P_obj"
|
| 259 |
+
#positive = positive + " empty scene"
|
| 260 |
+
print('You should add to positive prompt: "empty scene"')
|
| 261 |
+
else:
|
| 262 |
+
promptA = "P_obj"
|
| 263 |
+
promptB = "P_obj"
|
| 264 |
+
negative_promptA = "P_obj"
|
| 265 |
+
negative_promptB = "P_obj"
|
| 266 |
+
|
| 267 |
+
tokens = clip.tokenize(promptA)
|
| 268 |
+
prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
|
| 269 |
+
|
| 270 |
+
tokens = clip.tokenize(negative_promptA)
|
| 271 |
+
negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
|
| 272 |
+
|
| 273 |
+
tokens = clip.tokenize(promptB)
|
| 274 |
+
prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
|
| 275 |
+
|
| 276 |
+
tokens = clip.tokenize(negative_promptB)
|
| 277 |
+
negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
|
| 278 |
+
|
| 279 |
+
prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
|
| 280 |
+
negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
|
| 281 |
+
|
| 282 |
+
# unload vae and CLIPs
|
| 283 |
+
del vae
|
| 284 |
+
del clip
|
| 285 |
+
for loaded_model in comfy.model_management.current_loaded_models:
|
| 286 |
+
if type(loaded_model.model.model) in ModelsToUnload:
|
| 287 |
+
comfy.model_management.current_loaded_models.remove(loaded_model)
|
| 288 |
+
loaded_model.model_unload()
|
| 289 |
+
del loaded_model
|
| 290 |
+
|
| 291 |
+
# apply patch to model
|
| 292 |
+
|
| 293 |
+
brushnet_conditioning_scale = scale
|
| 294 |
+
control_guidance_start = start_at
|
| 295 |
+
control_guidance_end = end_at
|
| 296 |
+
|
| 297 |
+
add_brushnet_patch(model,
|
| 298 |
+
powerpaint['brushnet'],
|
| 299 |
+
torch_dtype,
|
| 300 |
+
conditioning_latents,
|
| 301 |
+
(brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
|
| 302 |
+
negative_prompt_embeds_pp, prompt_embeds_pp,
|
| 303 |
+
None, None, None)
|
| 304 |
+
|
| 305 |
+
latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
|
| 306 |
+
|
| 307 |
+
return (model, positive, negative, {"samples":latent},)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class BrushNet:
|
| 311 |
+
|
| 312 |
+
@classmethod
|
| 313 |
+
def INPUT_TYPES(s):
|
| 314 |
+
return {"required":
|
| 315 |
+
{
|
| 316 |
+
"model": ("MODEL",),
|
| 317 |
+
"vae": ("VAE", ),
|
| 318 |
+
"image": ("IMAGE",),
|
| 319 |
+
"mask": ("MASK",),
|
| 320 |
+
"brushnet": ("BRMODEL", ),
|
| 321 |
+
"positive": ("CONDITIONING", ),
|
| 322 |
+
"negative": ("CONDITIONING", ),
|
| 323 |
+
"scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
| 324 |
+
"start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
| 325 |
+
"end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
|
| 326 |
+
},
|
| 327 |
+
}
|
| 328 |
+
|
| 329 |
+
CATEGORY = "inpaint"
|
| 330 |
+
RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
|
| 331 |
+
RETURN_NAMES = ("model","positive","negative","latent",)
|
| 332 |
+
|
| 333 |
+
FUNCTION = "model_update"
|
| 334 |
+
|
| 335 |
+
def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
|
| 336 |
+
|
| 337 |
+
is_SDXL, is_PP = check_compatibilty(model, brushnet)
|
| 338 |
+
|
| 339 |
+
if is_PP:
|
| 340 |
+
raise Exception("PowerPaint model was loaded, please use PowerPaint node")
|
| 341 |
+
|
| 342 |
+
# Make a copy of the model so that we're not patching it everywhere in the workflow.
|
| 343 |
+
model = model.clone()
|
| 344 |
+
|
| 345 |
+
# prepare image and mask
|
| 346 |
+
# no batches for original image and mask
|
| 347 |
+
masked_image, mask = prepare_image(image, mask)
|
| 348 |
+
|
| 349 |
+
batch = masked_image.shape[0]
|
| 350 |
+
width = masked_image.shape[2]
|
| 351 |
+
height = masked_image.shape[1]
|
| 352 |
+
|
| 353 |
+
if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
|
| 354 |
+
scaling_factor = model.model.model_config.latent_format.scale_factor
|
| 355 |
+
elif is_SDXL:
|
| 356 |
+
scaling_factor = sdxl_scaling_factor
|
| 357 |
+
else:
|
| 358 |
+
scaling_factor = sd15_scaling_factor
|
| 359 |
+
|
| 360 |
+
torch_dtype = brushnet['dtype']
|
| 361 |
+
|
| 362 |
+
# prepare conditioning latents
|
| 363 |
+
conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
|
| 364 |
+
conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 365 |
+
conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 366 |
+
|
| 367 |
+
# unload vae
|
| 368 |
+
del vae
|
| 369 |
+
for loaded_model in comfy.model_management.current_loaded_models:
|
| 370 |
+
if type(loaded_model.model.model) in ModelsToUnload:
|
| 371 |
+
comfy.model_management.current_loaded_models.remove(loaded_model)
|
| 372 |
+
loaded_model.model_unload()
|
| 373 |
+
del loaded_model
|
| 374 |
+
|
| 375 |
+
# prepare embeddings
|
| 376 |
+
|
| 377 |
+
prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 378 |
+
negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 379 |
+
|
| 380 |
+
max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
|
| 381 |
+
if prompt_embeds.shape[1] < max_tokens:
|
| 382 |
+
multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
|
| 383 |
+
prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
|
| 384 |
+
print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
|
| 385 |
+
if negative_prompt_embeds.shape[1] < max_tokens:
|
| 386 |
+
multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
|
| 387 |
+
negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
|
| 388 |
+
print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
|
| 389 |
+
|
| 390 |
+
if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
|
| 391 |
+
pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 392 |
+
else:
|
| 393 |
+
print('BrushNet: positive conditioning has not pooled_output')
|
| 394 |
+
if is_SDXL:
|
| 395 |
+
print('BrushNet will not produce correct results')
|
| 396 |
+
pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
|
| 397 |
+
|
| 398 |
+
if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
|
| 399 |
+
negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 400 |
+
else:
|
| 401 |
+
print('BrushNet: negative conditioning has not pooled_output')
|
| 402 |
+
if is_SDXL:
|
| 403 |
+
print('BrushNet will not produce correct results')
|
| 404 |
+
negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
|
| 405 |
+
|
| 406 |
+
time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
|
| 407 |
+
|
| 408 |
+
if not is_SDXL:
|
| 409 |
+
pooled_prompt_embeds = None
|
| 410 |
+
negative_pooled_prompt_embeds = None
|
| 411 |
+
time_ids = None
|
| 412 |
+
|
| 413 |
+
# apply patch to model
|
| 414 |
+
|
| 415 |
+
brushnet_conditioning_scale = scale
|
| 416 |
+
control_guidance_start = start_at
|
| 417 |
+
control_guidance_end = end_at
|
| 418 |
+
|
| 419 |
+
add_brushnet_patch(model,
|
| 420 |
+
brushnet['brushnet'],
|
| 421 |
+
torch_dtype,
|
| 422 |
+
conditioning_latents,
|
| 423 |
+
(brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
|
| 424 |
+
prompt_embeds, negative_prompt_embeds,
|
| 425 |
+
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
|
| 426 |
+
|
| 427 |
+
latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
|
| 428 |
+
|
| 429 |
+
return (model, positive, negative, {"samples":latent},)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
class BlendInpaint:
|
| 433 |
+
|
| 434 |
+
@classmethod
|
| 435 |
+
def INPUT_TYPES(s):
|
| 436 |
+
return {"required":
|
| 437 |
+
{
|
| 438 |
+
"inpaint": ("IMAGE",),
|
| 439 |
+
"original": ("IMAGE",),
|
| 440 |
+
"mask": ("MASK",),
|
| 441 |
+
"kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
|
| 442 |
+
"sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
|
| 443 |
+
},
|
| 444 |
+
"optional":
|
| 445 |
+
{
|
| 446 |
+
"origin": ("VECTOR",),
|
| 447 |
+
},
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
CATEGORY = "inpaint"
|
| 451 |
+
RETURN_TYPES = ("IMAGE","MASK",)
|
| 452 |
+
RETURN_NAMES = ("image","MASK",)
|
| 453 |
+
|
| 454 |
+
FUNCTION = "blend_inpaint"
|
| 455 |
+
|
| 456 |
+
def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
|
| 457 |
+
|
| 458 |
+
original, mask = check_image_mask(original, mask, 'Blend Inpaint')
|
| 459 |
+
|
| 460 |
+
if len(inpaint.shape) < 4:
|
| 461 |
+
# image tensor shape should be [B, H, W, C], but batch somehow is missing
|
| 462 |
+
inpaint = inpaint[None,:,:,:]
|
| 463 |
+
|
| 464 |
+
if inpaint.shape[0] < original.shape[0]:
|
| 465 |
+
print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
|
| 466 |
+
original= original[:inpaint.shape[0],:,:]
|
| 467 |
+
mask = mask[:inpaint.shape[0],:,:]
|
| 468 |
+
|
| 469 |
+
if inpaint.shape[0] > original.shape[0]:
|
| 470 |
+
# batch over inpaint
|
| 471 |
+
count = 0
|
| 472 |
+
original_list = []
|
| 473 |
+
mask_list = []
|
| 474 |
+
origin_list = []
|
| 475 |
+
while (count < inpaint.shape[0]):
|
| 476 |
+
for i in range(original.shape[0]):
|
| 477 |
+
original_list.append(original[i][None,:,:,:])
|
| 478 |
+
mask_list.append(mask[i][None,:,:])
|
| 479 |
+
if origin is not None:
|
| 480 |
+
origin_list.append(origin[i][None,:])
|
| 481 |
+
count += 1
|
| 482 |
+
if count >= inpaint.shape[0]:
|
| 483 |
+
break
|
| 484 |
+
original = torch.concat(original_list, dim=0)
|
| 485 |
+
mask = torch.concat(mask_list, dim=0)
|
| 486 |
+
if origin is not None:
|
| 487 |
+
origin = torch.concat(origin_list, dim=0)
|
| 488 |
+
|
| 489 |
+
if kernel % 2 == 0:
|
| 490 |
+
kernel += 1
|
| 491 |
+
transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
|
| 492 |
+
|
| 493 |
+
ret = []
|
| 494 |
+
blurred = []
|
| 495 |
+
for i in range(inpaint.shape[0]):
|
| 496 |
+
if origin is None:
|
| 497 |
+
blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
|
| 498 |
+
blurred.append(blurred_mask[0])
|
| 499 |
+
|
| 500 |
+
result = torch.nn.functional.interpolate(
|
| 501 |
+
inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
|
| 502 |
+
size=(
|
| 503 |
+
original[i].shape[0],
|
| 504 |
+
original[i].shape[1],
|
| 505 |
+
)
|
| 506 |
+
).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
|
| 507 |
+
else:
|
| 508 |
+
# got mask from CutForInpaint
|
| 509 |
+
height, width, _ = original[i].shape
|
| 510 |
+
x0 = origin[i][0].item()
|
| 511 |
+
y0 = origin[i][1].item()
|
| 512 |
+
|
| 513 |
+
if mask[i].shape[0] < height or mask[i].shape[1] < width:
|
| 514 |
+
padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
|
| 515 |
+
y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
|
| 516 |
+
else:
|
| 517 |
+
padded_mask = mask[i]
|
| 518 |
+
blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
|
| 519 |
+
blurred.append(blurred_mask[0][0])
|
| 520 |
+
|
| 521 |
+
result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
|
| 522 |
+
y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
|
| 523 |
+
result = result[None,:,:,:].to(original.device).to(original.dtype)
|
| 524 |
+
|
| 525 |
+
ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
|
| 526 |
+
|
| 527 |
+
return (torch.stack(ret), torch.stack(blurred), )
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
class CutForInpaint:
|
| 531 |
+
|
| 532 |
+
@classmethod
|
| 533 |
+
def INPUT_TYPES(s):
|
| 534 |
+
return {"required":
|
| 535 |
+
{
|
| 536 |
+
"image": ("IMAGE",),
|
| 537 |
+
"mask": ("MASK",),
|
| 538 |
+
"width": ("INT", {"default": 512, "min": 64, "max": 2048}),
|
| 539 |
+
"height": ("INT", {"default": 512, "min": 64, "max": 2048}),
|
| 540 |
+
},
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
CATEGORY = "inpaint"
|
| 544 |
+
RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
|
| 545 |
+
RETURN_NAMES = ("image","mask","origin",)
|
| 546 |
+
|
| 547 |
+
FUNCTION = "cut_for_inpaint"
|
| 548 |
+
|
| 549 |
+
def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
|
| 550 |
+
|
| 551 |
+
image, mask = check_image_mask(image, mask, 'BrushNet')
|
| 552 |
+
|
| 553 |
+
ret = []
|
| 554 |
+
msk = []
|
| 555 |
+
org = []
|
| 556 |
+
for i in range(image.shape[0]):
|
| 557 |
+
x0, y0, w, h = cut_with_mask(mask[i], width, height)
|
| 558 |
+
ret.append((image[i][y0:y0+h,x0:x0+w,:]))
|
| 559 |
+
msk.append((mask[i][y0:y0+h,x0:x0+w]))
|
| 560 |
+
org.append(torch.IntTensor([x0,y0]))
|
| 561 |
+
|
| 562 |
+
return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
#### Utility function
|
| 566 |
+
|
| 567 |
+
def get_files_with_extension(folder_name, extension=['safetensors']):
|
| 568 |
+
|
| 569 |
+
try:
|
| 570 |
+
inpaint_path = folder_paths.get_folder_paths(folder_name)[0]
|
| 571 |
+
except:
|
| 572 |
+
inpaint_path = os.path.join(folder_paths.models_dir, folder_name)
|
| 573 |
+
|
| 574 |
+
if not os.path.isdir(inpaint_path):
|
| 575 |
+
inpaint_path = os.path.join(folder_paths.base_path, inpaint_path)
|
| 576 |
+
if not os.path.isdir(inpaint_path):
|
| 577 |
+
return ([], '')
|
| 578 |
+
#raise Exception("Can't find", folder_name, " path")
|
| 579 |
+
|
| 580 |
+
while not inpaint_path[-1].isalpha():
|
| 581 |
+
inpaint_path = inpaint_path[:-1]
|
| 582 |
+
|
| 583 |
+
abs_list = []
|
| 584 |
+
for x in os.walk(inpaint_path):
|
| 585 |
+
for name in x[2]:
|
| 586 |
+
for ext in extension:
|
| 587 |
+
if ext in name:
|
| 588 |
+
abs_list.append(os.path.join(x[0], name))
|
| 589 |
+
|
| 590 |
+
abs_list = sorted(list(set(abs_list)))
|
| 591 |
+
|
| 592 |
+
names = []
|
| 593 |
+
for x in abs_list:
|
| 594 |
+
remain = x
|
| 595 |
+
y = ''
|
| 596 |
+
while remain != inpaint_path:
|
| 597 |
+
remain, folder = os.path.split(remain)
|
| 598 |
+
if len(y) > 0:
|
| 599 |
+
y = os.path.join(folder, y)
|
| 600 |
+
else:
|
| 601 |
+
y = folder
|
| 602 |
+
names.append(y)
|
| 603 |
+
return names, inpaint_path
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def brushnet_blocks(sd):
|
| 607 |
+
brushnet_down_block = 0
|
| 608 |
+
brushnet_mid_block = 0
|
| 609 |
+
brushnet_up_block = 0
|
| 610 |
+
for key in sd:
|
| 611 |
+
if 'brushnet_down_block' in key:
|
| 612 |
+
brushnet_down_block += 1
|
| 613 |
+
if 'brushnet_mid_block' in key:
|
| 614 |
+
brushnet_mid_block += 1
|
| 615 |
+
if 'brushnet_up_block' in key:
|
| 616 |
+
brushnet_up_block += 1
|
| 617 |
+
return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
# Check models compatibility
|
| 621 |
+
def check_compatibilty(model, brushnet):
|
| 622 |
+
is_SDXL = False
|
| 623 |
+
is_PP = False
|
| 624 |
+
if isinstance(model.model.model_config, comfy.supported_models.SD15):
|
| 625 |
+
print('Base model type: SD1.5')
|
| 626 |
+
is_SDXL = False
|
| 627 |
+
if brushnet["SDXL"]:
|
| 628 |
+
raise Exception("Base model is SD15, but BrushNet is SDXL type")
|
| 629 |
+
if brushnet["PP"]:
|
| 630 |
+
is_PP = True
|
| 631 |
+
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
|
| 632 |
+
print('Base model type: SDXL')
|
| 633 |
+
is_SDXL = True
|
| 634 |
+
if not brushnet["SDXL"]:
|
| 635 |
+
raise Exception("Base model is SDXL, but BrushNet is SD15 type")
|
| 636 |
+
else:
|
| 637 |
+
print('Base model type: ', type(model.model.model_config))
|
| 638 |
+
raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
|
| 639 |
+
|
| 640 |
+
return (is_SDXL, is_PP)
|
| 641 |
+
|
| 642 |
+
|
| 643 |
+
def check_image_mask(image, mask, name):
|
| 644 |
+
if len(image.shape) < 4:
|
| 645 |
+
# image tensor shape should be [B, H, W, C], but batch somehow is missing
|
| 646 |
+
image = image[None,:,:,:]
|
| 647 |
+
|
| 648 |
+
if len(mask.shape) > 3:
|
| 649 |
+
# mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
|
| 650 |
+
# take first mask, red channel
|
| 651 |
+
mask = (mask[:,:,:,0])[:,:,:]
|
| 652 |
+
elif len(mask.shape) < 3:
|
| 653 |
+
# mask tensor shape should be [B, H, W] but batch somehow is missing
|
| 654 |
+
mask = mask[None,:,:]
|
| 655 |
+
|
| 656 |
+
if image.shape[0] > mask.shape[0]:
|
| 657 |
+
print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
|
| 658 |
+
if mask.shape[0] == 1:
|
| 659 |
+
print(name, "will copy the mask to fill batch")
|
| 660 |
+
mask = torch.cat([mask] * image.shape[0], dim=0)
|
| 661 |
+
else:
|
| 662 |
+
print(name, "will add empty masks to fill batch")
|
| 663 |
+
empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
|
| 664 |
+
mask = torch.cat([mask, empty_mask], dim=0)
|
| 665 |
+
elif image.shape[0] < mask.shape[0]:
|
| 666 |
+
print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
|
| 667 |
+
mask = mask[:image.shape[0],:,:]
|
| 668 |
+
|
| 669 |
+
return (image, mask)
|
| 670 |
+
|
| 671 |
+
# Prepare image and mask
|
| 672 |
+
def prepare_image(image, mask):
|
| 673 |
+
|
| 674 |
+
image, mask = check_image_mask(image, mask, 'BrushNet')
|
| 675 |
+
|
| 676 |
+
print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
|
| 677 |
+
|
| 678 |
+
if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
|
| 679 |
+
raise Exception("Image and mask should be the same size")
|
| 680 |
+
|
| 681 |
+
# As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
|
| 682 |
+
mask = mask.round()
|
| 683 |
+
|
| 684 |
+
masked_image = image * (1.0 - mask[:,:,:,None])
|
| 685 |
+
|
| 686 |
+
return (masked_image, mask)
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
def cut_with_mask(mask, width, height):
|
| 690 |
+
iy, ix = (mask == 1).nonzero(as_tuple=True)
|
| 691 |
+
|
| 692 |
+
h0, w0 = mask.shape
|
| 693 |
+
|
| 694 |
+
if iy.numel() == 0:
|
| 695 |
+
x_c = w0 / 2.0
|
| 696 |
+
y_c = h0 / 2.0
|
| 697 |
+
else:
|
| 698 |
+
x_min = ix.min().item()
|
| 699 |
+
x_max = ix.max().item()
|
| 700 |
+
y_min = iy.min().item()
|
| 701 |
+
y_max = iy.max().item()
|
| 702 |
+
|
| 703 |
+
if x_max - x_min > width or y_max - y_min > height:
|
| 704 |
+
raise Exception("Mask is bigger than provided dimensions")
|
| 705 |
+
|
| 706 |
+
x_c = (x_min + x_max) / 2.0
|
| 707 |
+
y_c = (y_min + y_max) / 2.0
|
| 708 |
+
|
| 709 |
+
width2 = width / 2.0
|
| 710 |
+
height2 = height / 2.0
|
| 711 |
+
|
| 712 |
+
if w0 <= width:
|
| 713 |
+
x0 = 0
|
| 714 |
+
w = w0
|
| 715 |
+
else:
|
| 716 |
+
x0 = max(0, x_c - width2)
|
| 717 |
+
w = width
|
| 718 |
+
if x0 + width > w0:
|
| 719 |
+
x0 = w0 - width
|
| 720 |
+
|
| 721 |
+
if h0 <= height:
|
| 722 |
+
y0 = 0
|
| 723 |
+
h = h0
|
| 724 |
+
else:
|
| 725 |
+
y0 = max(0, y_c - height2)
|
| 726 |
+
h = height
|
| 727 |
+
if y0 + height > h0:
|
| 728 |
+
y0 = h0 - height
|
| 729 |
+
|
| 730 |
+
return (int(x0), int(y0), int(w), int(h))
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
# Prepare conditioning_latents
|
| 734 |
+
@torch.inference_mode()
|
| 735 |
+
def get_image_latents(masked_image, mask, vae, scaling_factor):
|
| 736 |
+
processed_image = masked_image.to(vae.device)
|
| 737 |
+
image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
|
| 738 |
+
processed_mask = 1. - mask[:,None,:,:]
|
| 739 |
+
interpolated_mask = torch.nn.functional.interpolate(
|
| 740 |
+
processed_mask,
|
| 741 |
+
size=(
|
| 742 |
+
image_latents.shape[-2],
|
| 743 |
+
image_latents.shape[-1]
|
| 744 |
+
)
|
| 745 |
+
)
|
| 746 |
+
interpolated_mask = interpolated_mask.to(image_latents.device)
|
| 747 |
+
|
| 748 |
+
conditioning_latents = [image_latents, interpolated_mask]
|
| 749 |
+
|
| 750 |
+
print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
|
| 751 |
+
|
| 752 |
+
return conditioning_latents
|
| 753 |
+
|
| 754 |
+
|
| 755 |
+
# Main function where magic happens
|
| 756 |
+
@torch.inference_mode()
|
| 757 |
+
def brushnet_inference(x, timesteps, transformer_options):
|
| 758 |
+
if 'model_patch' not in transformer_options:
|
| 759 |
+
print('BrushNet inference: there is no model_patch key in transformer_options')
|
| 760 |
+
return ([], 0, [])
|
| 761 |
+
mp = transformer_options['model_patch']
|
| 762 |
+
if 'brushnet' not in mp:
|
| 763 |
+
print('BrushNet inference: there is no brushnet key in mdel_patch')
|
| 764 |
+
return ([], 0, [])
|
| 765 |
+
bo = mp['brushnet']
|
| 766 |
+
if 'model' not in bo:
|
| 767 |
+
print('BrushNet inference: there is no model key in brushnet')
|
| 768 |
+
return ([], 0, [])
|
| 769 |
+
brushnet = bo['model']
|
| 770 |
+
if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
|
| 771 |
+
print('BrushNet model is not a BrushNetModel class')
|
| 772 |
+
return ([], 0, [])
|
| 773 |
+
|
| 774 |
+
torch_dtype = bo['dtype']
|
| 775 |
+
cl_list = bo['latents']
|
| 776 |
+
brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
|
| 777 |
+
pe = bo['prompt_embeds']
|
| 778 |
+
npe = bo['negative_prompt_embeds']
|
| 779 |
+
ppe, nppe, time_ids = bo['add_embeds']
|
| 780 |
+
|
| 781 |
+
#do_classifier_free_guidance = mp['free_guidance']
|
| 782 |
+
do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
|
| 783 |
+
|
| 784 |
+
x = x.detach().clone()
|
| 785 |
+
x = x.to(torch_dtype).to(brushnet.device)
|
| 786 |
+
|
| 787 |
+
timesteps = timesteps.detach().clone()
|
| 788 |
+
timesteps = timesteps.to(torch_dtype).to(brushnet.device)
|
| 789 |
+
|
| 790 |
+
total_steps = mp['total_steps']
|
| 791 |
+
step = mp['step']
|
| 792 |
+
|
| 793 |
+
added_cond_kwargs = {}
|
| 794 |
+
|
| 795 |
+
if do_classifier_free_guidance and step == 0:
|
| 796 |
+
print('BrushNet inference: do_classifier_free_guidance is True')
|
| 797 |
+
|
| 798 |
+
sub_idx = None
|
| 799 |
+
if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
|
| 800 |
+
sub_idx = transformer_options['ad_params']['sub_idxs']
|
| 801 |
+
|
| 802 |
+
# we have batch input images
|
| 803 |
+
batch = cl_list[0].shape[0]
|
| 804 |
+
# we have incoming latents
|
| 805 |
+
latents_incoming = x.shape[0]
|
| 806 |
+
# and we already got some
|
| 807 |
+
latents_got = bo['latent_id']
|
| 808 |
+
if step == 0 or batch > 1:
|
| 809 |
+
print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
|
| 810 |
+
% (step, batch, latents_incoming, latents_got))
|
| 811 |
+
|
| 812 |
+
image_latents = []
|
| 813 |
+
masks = []
|
| 814 |
+
prompt_embeds = []
|
| 815 |
+
negative_prompt_embeds = []
|
| 816 |
+
pooled_prompt_embeds = []
|
| 817 |
+
negative_pooled_prompt_embeds = []
|
| 818 |
+
if sub_idx:
|
| 819 |
+
# AnimateDiff indexes detected
|
| 820 |
+
if step == 0:
|
| 821 |
+
print('BrushNet inference: AnimateDiff indexes detected and applied')
|
| 822 |
+
|
| 823 |
+
batch = len(sub_idx)
|
| 824 |
+
|
| 825 |
+
if do_classifier_free_guidance:
|
| 826 |
+
for i in sub_idx:
|
| 827 |
+
image_latents.append(cl_list[0][i][None,:,:,:])
|
| 828 |
+
masks.append(cl_list[1][i][None,:,:,:])
|
| 829 |
+
prompt_embeds.append(pe)
|
| 830 |
+
negative_prompt_embeds.append(npe)
|
| 831 |
+
pooled_prompt_embeds.append(ppe)
|
| 832 |
+
negative_pooled_prompt_embeds.append(nppe)
|
| 833 |
+
for i in sub_idx:
|
| 834 |
+
image_latents.append(cl_list[0][i][None,:,:,:])
|
| 835 |
+
masks.append(cl_list[1][i][None,:,:,:])
|
| 836 |
+
else:
|
| 837 |
+
for i in sub_idx:
|
| 838 |
+
image_latents.append(cl_list[0][i][None,:,:,:])
|
| 839 |
+
masks.append(cl_list[1][i][None,:,:,:])
|
| 840 |
+
prompt_embeds.append(pe)
|
| 841 |
+
pooled_prompt_embeds.append(ppe)
|
| 842 |
+
else:
|
| 843 |
+
# do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
|
| 844 |
+
continue_batch = True
|
| 845 |
+
for i in range(latents_incoming):
|
| 846 |
+
number = latents_got + i
|
| 847 |
+
if number < batch:
|
| 848 |
+
# 1st pass, cond
|
| 849 |
+
image_latents.append(cl_list[0][number][None,:,:,:])
|
| 850 |
+
masks.append(cl_list[1][number][None,:,:,:])
|
| 851 |
+
prompt_embeds.append(pe)
|
| 852 |
+
pooled_prompt_embeds.append(ppe)
|
| 853 |
+
elif do_classifier_free_guidance and number < batch * 2:
|
| 854 |
+
# 2nd pass, uncond
|
| 855 |
+
image_latents.append(cl_list[0][number-batch][None,:,:,:])
|
| 856 |
+
masks.append(cl_list[1][number-batch][None,:,:,:])
|
| 857 |
+
negative_prompt_embeds.append(npe)
|
| 858 |
+
negative_pooled_prompt_embeds.append(nppe)
|
| 859 |
+
else:
|
| 860 |
+
# latent batch
|
| 861 |
+
image_latents.append(cl_list[0][0][None,:,:,:])
|
| 862 |
+
masks.append(cl_list[1][0][None,:,:,:])
|
| 863 |
+
prompt_embeds.append(pe)
|
| 864 |
+
pooled_prompt_embeds.append(ppe)
|
| 865 |
+
latents_got = -i
|
| 866 |
+
continue_batch = False
|
| 867 |
+
|
| 868 |
+
if continue_batch:
|
| 869 |
+
# we don't have full batch yet
|
| 870 |
+
if do_classifier_free_guidance:
|
| 871 |
+
if number < batch * 2 - 1:
|
| 872 |
+
bo['latent_id'] = number + 1
|
| 873 |
+
else:
|
| 874 |
+
bo['latent_id'] = 0
|
| 875 |
+
else:
|
| 876 |
+
if number < batch - 1:
|
| 877 |
+
bo['latent_id'] = number + 1
|
| 878 |
+
else:
|
| 879 |
+
bo['latent_id'] = 0
|
| 880 |
+
else:
|
| 881 |
+
bo['latent_id'] = 0
|
| 882 |
+
|
| 883 |
+
cl = []
|
| 884 |
+
for il, m in zip(image_latents, masks):
|
| 885 |
+
cl.append(torch.concat([il, m], dim=1))
|
| 886 |
+
cl2apply = torch.concat(cl, dim=0)
|
| 887 |
+
|
| 888 |
+
conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
|
| 889 |
+
|
| 890 |
+
prompt_embeds.extend(negative_prompt_embeds)
|
| 891 |
+
prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
|
| 892 |
+
|
| 893 |
+
if ppe is not None:
|
| 894 |
+
added_cond_kwargs = {}
|
| 895 |
+
added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
|
| 896 |
+
|
| 897 |
+
pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
|
| 898 |
+
pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
|
| 899 |
+
added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
|
| 900 |
+
else:
|
| 901 |
+
added_cond_kwargs = None
|
| 902 |
+
|
| 903 |
+
if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
|
| 904 |
+
if step == 0:
|
| 905 |
+
print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
|
| 906 |
+
conditioning_latents = torch.nn.functional.interpolate(
|
| 907 |
+
conditioning_latents, size=(
|
| 908 |
+
x.shape[2],
|
| 909 |
+
x.shape[3],
|
| 910 |
+
), mode='bicubic',
|
| 911 |
+
).to(torch_dtype).to(brushnet.device)
|
| 912 |
+
|
| 913 |
+
if step == 0:
|
| 914 |
+
print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape)
|
| 915 |
+
|
| 916 |
+
if step < control_guidance_start or step > control_guidance_end:
|
| 917 |
+
cond_scale = 0.0
|
| 918 |
+
else:
|
| 919 |
+
cond_scale = brushnet_conditioning_scale
|
| 920 |
+
|
| 921 |
+
return brushnet(x,
|
| 922 |
+
encoder_hidden_states=prompt_embeds,
|
| 923 |
+
brushnet_cond=conditioning_latents,
|
| 924 |
+
timestep = timesteps,
|
| 925 |
+
conditioning_scale=cond_scale,
|
| 926 |
+
guess_mode=False,
|
| 927 |
+
added_cond_kwargs=added_cond_kwargs,
|
| 928 |
+
return_dict=False,
|
| 929 |
+
)
|
| 930 |
+
|
| 931 |
+
|
| 932 |
+
# This is main patch function
|
| 933 |
+
def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
|
| 934 |
+
controls,
|
| 935 |
+
prompt_embeds, negative_prompt_embeds,
|
| 936 |
+
pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids):
|
| 937 |
+
|
| 938 |
+
is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
|
| 939 |
+
|
| 940 |
+
if is_SDXL:
|
| 941 |
+
input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
|
| 942 |
+
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 943 |
+
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 944 |
+
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
| 945 |
+
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
| 946 |
+
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
| 947 |
+
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
| 948 |
+
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
| 949 |
+
[8, comfy.ldm.modules.attention.SpatialTransformer]]
|
| 950 |
+
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
| 951 |
+
output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
|
| 952 |
+
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
| 953 |
+
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
| 954 |
+
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
| 955 |
+
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
| 956 |
+
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
| 957 |
+
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
| 958 |
+
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
| 959 |
+
[6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 960 |
+
[7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 961 |
+
[8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
| 962 |
+
else:
|
| 963 |
+
input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
|
| 964 |
+
[1, comfy.ldm.modules.attention.SpatialTransformer],
|
| 965 |
+
[2, comfy.ldm.modules.attention.SpatialTransformer],
|
| 966 |
+
[3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
| 967 |
+
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
| 968 |
+
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
| 969 |
+
[6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
| 970 |
+
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
| 971 |
+
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
| 972 |
+
[9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
|
| 973 |
+
[10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 974 |
+
[11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
|
| 975 |
+
middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
|
| 976 |
+
output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 977 |
+
[1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 978 |
+
[2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
|
| 979 |
+
[2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
| 980 |
+
[3, comfy.ldm.modules.attention.SpatialTransformer],
|
| 981 |
+
[4, comfy.ldm.modules.attention.SpatialTransformer],
|
| 982 |
+
[5, comfy.ldm.modules.attention.SpatialTransformer],
|
| 983 |
+
[5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
| 984 |
+
[6, comfy.ldm.modules.attention.SpatialTransformer],
|
| 985 |
+
[7, comfy.ldm.modules.attention.SpatialTransformer],
|
| 986 |
+
[8, comfy.ldm.modules.attention.SpatialTransformer],
|
| 987 |
+
[8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
|
| 988 |
+
[9, comfy.ldm.modules.attention.SpatialTransformer],
|
| 989 |
+
[10, comfy.ldm.modules.attention.SpatialTransformer],
|
| 990 |
+
[11, comfy.ldm.modules.attention.SpatialTransformer]]
|
| 991 |
+
|
| 992 |
+
def last_layer_index(block, tp):
|
| 993 |
+
layer_list = []
|
| 994 |
+
for layer in block:
|
| 995 |
+
layer_list.append(type(layer))
|
| 996 |
+
layer_list.reverse()
|
| 997 |
+
if tp not in layer_list:
|
| 998 |
+
return -1, layer_list.reverse()
|
| 999 |
+
return len(layer_list) - 1 - layer_list.index(tp), layer_list
|
| 1000 |
+
|
| 1001 |
+
def brushnet_forward(model, x, timesteps, transformer_options, control):
|
| 1002 |
+
if 'brushnet' not in transformer_options['model_patch']:
|
| 1003 |
+
input_samples = []
|
| 1004 |
+
mid_sample = 0
|
| 1005 |
+
output_samples = []
|
| 1006 |
+
else:
|
| 1007 |
+
# brushnet inference
|
| 1008 |
+
input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options)
|
| 1009 |
+
|
| 1010 |
+
# give additional samples to blocks
|
| 1011 |
+
for i, tp in input_blocks:
|
| 1012 |
+
idx, layer_list = last_layer_index(model.input_blocks[i], tp)
|
| 1013 |
+
if idx < 0:
|
| 1014 |
+
print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
|
| 1015 |
+
continue
|
| 1016 |
+
model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
|
| 1017 |
+
|
| 1018 |
+
idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
|
| 1019 |
+
if idx < 0:
|
| 1020 |
+
print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
|
| 1021 |
+
model.middle_block[idx].add_sample_after = mid_sample
|
| 1022 |
+
|
| 1023 |
+
for i, tp in output_blocks:
|
| 1024 |
+
idx, layer_list = last_layer_index(model.output_blocks[i], tp)
|
| 1025 |
+
if idx < 0:
|
| 1026 |
+
print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
|
| 1027 |
+
continue
|
| 1028 |
+
model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
|
| 1029 |
+
|
| 1030 |
+
patch_model_function_wrapper(model, brushnet_forward)
|
| 1031 |
+
|
| 1032 |
+
to = add_model_patch_option(model)
|
| 1033 |
+
mp = to['model_patch']
|
| 1034 |
+
if 'brushnet' not in mp:
|
| 1035 |
+
mp['brushnet'] = {}
|
| 1036 |
+
bo = mp['brushnet']
|
| 1037 |
+
|
| 1038 |
+
bo['model'] = brushnet
|
| 1039 |
+
bo['dtype'] = torch_dtype
|
| 1040 |
+
bo['latents'] = conditioning_latents
|
| 1041 |
+
bo['controls'] = controls
|
| 1042 |
+
bo['prompt_embeds'] = prompt_embeds
|
| 1043 |
+
bo['negative_prompt_embeds'] = negative_prompt_embeds
|
| 1044 |
+
bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
|
| 1045 |
+
bo['latent_id'] = 0
|
| 1046 |
+
|
| 1047 |
+
# patch layers `forward` so we can apply brushnet
|
| 1048 |
+
def forward_patched_by_brushnet(self, x, *args, **kwargs):
|
| 1049 |
+
h = self.original_forward(x, *args, **kwargs)
|
| 1050 |
+
if hasattr(self, 'add_sample_after') and type(self):
|
| 1051 |
+
to_add = self.add_sample_after
|
| 1052 |
+
if torch.is_tensor(to_add):
|
| 1053 |
+
# interpolate due to RAUNet
|
| 1054 |
+
if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
|
| 1055 |
+
to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
|
| 1056 |
+
h += to_add.to(h.dtype).to(h.device)
|
| 1057 |
+
else:
|
| 1058 |
+
h += self.add_sample_after
|
| 1059 |
+
self.add_sample_after = 0
|
| 1060 |
+
return h
|
| 1061 |
+
|
| 1062 |
+
for i, block in enumerate(model.model.diffusion_model.input_blocks):
|
| 1063 |
+
for j, layer in enumerate(block):
|
| 1064 |
+
if not hasattr(layer, 'original_forward'):
|
| 1065 |
+
layer.original_forward = layer.forward
|
| 1066 |
+
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
| 1067 |
+
layer.add_sample_after = 0
|
| 1068 |
+
|
| 1069 |
+
for j, layer in enumerate(model.model.diffusion_model.middle_block):
|
| 1070 |
+
if not hasattr(layer, 'original_forward'):
|
| 1071 |
+
layer.original_forward = layer.forward
|
| 1072 |
+
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
| 1073 |
+
layer.add_sample_after = 0
|
| 1074 |
+
|
| 1075 |
+
for i, block in enumerate(model.model.diffusion_model.output_blocks):
|
| 1076 |
+
for j, layer in enumerate(block):
|
| 1077 |
+
if not hasattr(layer, 'original_forward'):
|
| 1078 |
+
layer.original_forward = layer.forward
|
| 1079 |
+
layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
|
| 1080 |
+
layer.add_sample_after = 0
|
ComfyUI-BrushNet/model_patch.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import comfy
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Check and add 'model_patch' to model.model_options['transformer_options']
|
| 6 |
+
def add_model_patch_option(model):
|
| 7 |
+
if 'transformer_options' not in model.model_options:
|
| 8 |
+
model.model_options['transformer_options'] = {}
|
| 9 |
+
to = model.model_options['transformer_options']
|
| 10 |
+
if "model_patch" not in to:
|
| 11 |
+
to["model_patch"] = {}
|
| 12 |
+
return to
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Patch model with model_function_wrapper
|
| 16 |
+
def patch_model_function_wrapper(model, forward_patch):
|
| 17 |
+
|
| 18 |
+
def brushnet_model_function_wrapper(apply_model_method, options_dict):
|
| 19 |
+
to = options_dict['c']['transformer_options']
|
| 20 |
+
|
| 21 |
+
control = None
|
| 22 |
+
if 'control' in options_dict['c']:
|
| 23 |
+
control = options_dict['c']['control']
|
| 24 |
+
|
| 25 |
+
x = options_dict['input']
|
| 26 |
+
timestep = options_dict['timestep']
|
| 27 |
+
|
| 28 |
+
# check if there are patches to execute
|
| 29 |
+
if 'model_patch' not in to or 'forward' not in to['model_patch']:
|
| 30 |
+
return apply_model_method(x, timestep, **options_dict['c'])
|
| 31 |
+
|
| 32 |
+
mp = to['model_patch']
|
| 33 |
+
unet = mp['unet']
|
| 34 |
+
|
| 35 |
+
all_sigmas = mp['all_sigmas']
|
| 36 |
+
sigma = to['sigmas'][0].item()
|
| 37 |
+
total_steps = all_sigmas.shape[0] - 1
|
| 38 |
+
step = torch.argmin((all_sigmas - sigma).abs()).item()
|
| 39 |
+
|
| 40 |
+
mp['step'] = step
|
| 41 |
+
mp['total_steps'] = total_steps
|
| 42 |
+
|
| 43 |
+
# comfy.model_base.apply_model
|
| 44 |
+
xc = model.model.model_sampling.calculate_input(timestep, x)
|
| 45 |
+
if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
|
| 46 |
+
xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
|
| 47 |
+
t = model.model.model_sampling.timestep(timestep).float()
|
| 48 |
+
# execute all patches
|
| 49 |
+
for method in mp['forward']:
|
| 50 |
+
method(unet, xc, t, to, control)
|
| 51 |
+
|
| 52 |
+
return apply_model_method(x, timestep, **options_dict['c'])
|
| 53 |
+
|
| 54 |
+
if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
|
| 55 |
+
print('BrushNet is going to replace existing model_function_wrapper:', model.model_options["model_function_wrapper"])
|
| 56 |
+
model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
|
| 57 |
+
|
| 58 |
+
to = add_model_patch_option(model)
|
| 59 |
+
mp = to['model_patch']
|
| 60 |
+
|
| 61 |
+
if isinstance(model.model.model_config, comfy.supported_models.SD15):
|
| 62 |
+
mp['SDXL'] = False
|
| 63 |
+
elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
|
| 64 |
+
mp['SDXL'] = True
|
| 65 |
+
else:
|
| 66 |
+
print('Base model type: ', type(model.model.model_config))
|
| 67 |
+
raise Exception("Unsupported model type: ", type(model.model.model_config))
|
| 68 |
+
|
| 69 |
+
if 'forward' not in mp:
|
| 70 |
+
mp['forward'] = [forward_patch]
|
| 71 |
+
else:
|
| 72 |
+
mp['forward'].append(forward_patch)
|
| 73 |
+
|
| 74 |
+
mp['unet'] = model.model.diffusion_model
|
| 75 |
+
mp['step'] = 0
|
| 76 |
+
mp['total_steps'] = 1
|
| 77 |
+
|
| 78 |
+
# apply patches to code
|
| 79 |
+
if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
|
| 80 |
+
comfy.samplers.original_sample = comfy.samplers.sample
|
| 81 |
+
comfy.samplers.sample = modified_sample
|
| 82 |
+
|
| 83 |
+
if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
|
| 84 |
+
'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
|
| 85 |
+
comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
|
| 86 |
+
comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
|
| 90 |
+
# The first versions had modified_common_ksampler, but it broke custom KSampler nodes
|
| 91 |
+
def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
|
| 92 |
+
latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
|
| 93 |
+
'''
|
| 94 |
+
Modified by BrushNet nodes
|
| 95 |
+
'''
|
| 96 |
+
cfg_guider = comfy.samplers.CFGGuider(model)
|
| 97 |
+
cfg_guider.set_conds(positive, negative)
|
| 98 |
+
cfg_guider.set_cfg(cfg)
|
| 99 |
+
|
| 100 |
+
### Modified part ######################################################################
|
| 101 |
+
#
|
| 102 |
+
to = add_model_patch_option(model)
|
| 103 |
+
to['model_patch']['all_sigmas'] = sigmas
|
| 104 |
+
#
|
| 105 |
+
#sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
|
| 106 |
+
#sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
|
| 107 |
+
#
|
| 108 |
+
#
|
| 109 |
+
#if math.isclose(cfg, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
|
| 110 |
+
# to['model_patch']['free_guidance'] = False
|
| 111 |
+
#else:
|
| 112 |
+
# to['model_patch']['free_guidance'] = True
|
| 113 |
+
#
|
| 114 |
+
#######################################################################################
|
| 115 |
+
|
| 116 |
+
return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
# To use Controlnet with RAUNet it is much easier to modify apply_control a little
|
| 120 |
+
def modified_apply_control(h, control, name):
|
| 121 |
+
'''
|
| 122 |
+
Modified by BrushNet nodes
|
| 123 |
+
'''
|
| 124 |
+
if control is not None and name in control and len(control[name]) > 0:
|
| 125 |
+
ctrl = control[name].pop()
|
| 126 |
+
if ctrl is not None:
|
| 127 |
+
if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
|
| 128 |
+
ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(h.dtype).to(h.device)
|
| 129 |
+
try:
|
| 130 |
+
h += ctrl
|
| 131 |
+
except:
|
| 132 |
+
print.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
|
| 133 |
+
return h
|
| 134 |
+
|
ComfyUI-BrushNet/raunet_nodes.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn.functional as F
|
| 2 |
+
import comfy
|
| 3 |
+
|
| 4 |
+
from .model_patch import add_model_patch_option, patch_model_function_wrapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class RAUNet:
|
| 9 |
+
|
| 10 |
+
@classmethod
|
| 11 |
+
def INPUT_TYPES(s):
|
| 12 |
+
return {"required":
|
| 13 |
+
{
|
| 14 |
+
"model": ("MODEL",),
|
| 15 |
+
"du_start": ("INT", {"default": 0, "min": 0, "max": 10000}),
|
| 16 |
+
"du_end": ("INT", {"default": 4, "min": 0, "max": 10000}),
|
| 17 |
+
"xa_start": ("INT", {"default": 4, "min": 0, "max": 10000}),
|
| 18 |
+
"xa_end": ("INT", {"default": 10, "min": 0, "max": 10000}),
|
| 19 |
+
},
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
CATEGORY = "inpaint"
|
| 23 |
+
RETURN_TYPES = ("MODEL",)
|
| 24 |
+
RETURN_NAMES = ("model",)
|
| 25 |
+
|
| 26 |
+
FUNCTION = "model_update"
|
| 27 |
+
|
| 28 |
+
def model_update(self, model, du_start, du_end, xa_start, xa_end):
|
| 29 |
+
|
| 30 |
+
model = model.clone()
|
| 31 |
+
|
| 32 |
+
add_raunet_patch(model,
|
| 33 |
+
du_start,
|
| 34 |
+
du_end,
|
| 35 |
+
xa_start,
|
| 36 |
+
xa_end)
|
| 37 |
+
|
| 38 |
+
return (model,)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# This is main patch function
|
| 42 |
+
def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
|
| 43 |
+
|
| 44 |
+
def raunet_forward(model, x, timesteps, transformer_options, control):
|
| 45 |
+
if 'model_patch' not in transformer_options:
|
| 46 |
+
print("RAUNet: 'model_patch' not in transformer_options, skip")
|
| 47 |
+
return
|
| 48 |
+
|
| 49 |
+
mp = transformer_options['model_patch']
|
| 50 |
+
is_SDXL = mp['SDXL']
|
| 51 |
+
|
| 52 |
+
if is_SDXL and type(model.input_blocks[6][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
|
| 53 |
+
print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
|
| 54 |
+
return
|
| 55 |
+
|
| 56 |
+
if not is_SDXL and type(model.input_blocks[3][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
|
| 57 |
+
print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
if 'raunet' not in mp:
|
| 61 |
+
print('RAUNet: "raunet" not in model_patch options, skip')
|
| 62 |
+
return
|
| 63 |
+
|
| 64 |
+
if is_SDXL:
|
| 65 |
+
block = model.input_blocks[6][0]
|
| 66 |
+
else:
|
| 67 |
+
block = model.input_blocks[3][0]
|
| 68 |
+
|
| 69 |
+
total_steps = mp['total_steps']
|
| 70 |
+
step = mp['step']
|
| 71 |
+
|
| 72 |
+
ro = mp['raunet']
|
| 73 |
+
du_start = ro['du_start']
|
| 74 |
+
du_end = ro['du_end']
|
| 75 |
+
|
| 76 |
+
if step >= du_start and step < du_end:
|
| 77 |
+
block.op.stride = (4, 4)
|
| 78 |
+
block.op.padding = (2, 2)
|
| 79 |
+
block.op.dilation = (2, 2)
|
| 80 |
+
else:
|
| 81 |
+
block.op.stride = (2, 2)
|
| 82 |
+
block.op.padding = (1, 1)
|
| 83 |
+
block.op.dilation = (1, 1)
|
| 84 |
+
|
| 85 |
+
patch_model_function_wrapper(model, raunet_forward)
|
| 86 |
+
model.set_model_input_block_patch(in_xattn_patch)
|
| 87 |
+
model.set_model_output_block_patch(out_xattn_patch)
|
| 88 |
+
|
| 89 |
+
to = add_model_patch_option(model)
|
| 90 |
+
mp = to['model_patch']
|
| 91 |
+
if 'raunet' not in mp:
|
| 92 |
+
mp['raunet'] = {}
|
| 93 |
+
ro = mp['raunet']
|
| 94 |
+
|
| 95 |
+
ro['du_start'] = du_start
|
| 96 |
+
ro['du_end'] = du_end
|
| 97 |
+
ro['xa_start'] = xa_start
|
| 98 |
+
ro['xa_end'] = xa_end
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def in_xattn_patch(h, transformer_options):
|
| 102 |
+
# both SDXL and SD15 = (input,4)
|
| 103 |
+
if transformer_options["block"] != ("input", 4):
|
| 104 |
+
# wrong block
|
| 105 |
+
return h
|
| 106 |
+
if 'model_patch' not in transformer_options:
|
| 107 |
+
print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
|
| 108 |
+
return h
|
| 109 |
+
mp = transformer_options['model_patch']
|
| 110 |
+
if 'raunet' not in mp:
|
| 111 |
+
print("RAUNet (i-x-p): 'raunet' not in model_patch options")
|
| 112 |
+
return h
|
| 113 |
+
|
| 114 |
+
step = mp['step']
|
| 115 |
+
ro = mp['raunet']
|
| 116 |
+
xa_start = ro['xa_start']
|
| 117 |
+
xa_end = ro['xa_end']
|
| 118 |
+
|
| 119 |
+
if step < xa_start or step >= xa_end:
|
| 120 |
+
return h
|
| 121 |
+
h = F.avg_pool2d(h, kernel_size=(2,2))
|
| 122 |
+
return h
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def out_xattn_patch(h, hsp, transformer_options):
|
| 126 |
+
if 'model_patch' not in transformer_options:
|
| 127 |
+
print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
|
| 128 |
+
return h, hsp
|
| 129 |
+
mp = transformer_options['model_patch']
|
| 130 |
+
if 'raunet' not in mp:
|
| 131 |
+
print("RAUNet (o-x-p): 'raunet' not in model_patch options")
|
| 132 |
+
return h
|
| 133 |
+
|
| 134 |
+
step = mp['step']
|
| 135 |
+
is_SDXL = mp['SDXL']
|
| 136 |
+
ro = mp['raunet']
|
| 137 |
+
xa_start = ro['xa_start']
|
| 138 |
+
xa_end = ro['xa_end']
|
| 139 |
+
|
| 140 |
+
if is_SDXL:
|
| 141 |
+
if transformer_options["block"] != ("output", 5):
|
| 142 |
+
# wrong block
|
| 143 |
+
return h, hsp
|
| 144 |
+
else:
|
| 145 |
+
if transformer_options["block"] != ("output", 8):
|
| 146 |
+
# wrong block
|
| 147 |
+
return h, hsp
|
| 148 |
+
|
| 149 |
+
if step < xa_start or step >= xa_end:
|
| 150 |
+
return h, hsp
|
| 151 |
+
#error in hidiffusion codebase, size * 2 for particular sizes only
|
| 152 |
+
#re_size = (int(h.shape[-2] * 2), int(h.shape[-1] * 2))
|
| 153 |
+
re_size = (hsp.shape[-2], hsp.shape[-1])
|
| 154 |
+
h = F.interpolate(h, size=re_size, mode='bicubic')
|
| 155 |
+
|
| 156 |
+
return h, hsp
|
| 157 |
+
|
| 158 |
+
|
ComfyUI-BrushNet/requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
diffusers>=0.27.0
|
| 2 |
+
accelerate>=0.29.0
|
| 3 |
+
peft>=0.7.0
|
ComfyUI-Easy-Use/LICENSE
ADDED
|
@@ -0,0 +1,674 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
GNU GENERAL PUBLIC LICENSE
|
| 2 |
+
Version 3, 29 June 2007
|
| 3 |
+
|
| 4 |
+
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
| 5 |
+
Everyone is permitted to copy and distribute verbatim copies
|
| 6 |
+
of this license document, but changing it is not allowed.
|
| 7 |
+
|
| 8 |
+
Preamble
|
| 9 |
+
|
| 10 |
+
The GNU General Public License is a free, copyleft license for
|
| 11 |
+
software and other kinds of works.
|
| 12 |
+
|
| 13 |
+
The licenses for most software and other practical works are designed
|
| 14 |
+
to take away your freedom to share and change the works. By contrast,
|
| 15 |
+
the GNU General Public License is intended to guarantee your freedom to
|
| 16 |
+
share and change all versions of a program--to make sure it remains free
|
| 17 |
+
software for all its users. We, the Free Software Foundation, use the
|
| 18 |
+
GNU General Public License for most of our software; it applies also to
|
| 19 |
+
any other work released this way by its authors. You can apply it to
|
| 20 |
+
your programs, too.
|
| 21 |
+
|
| 22 |
+
When we speak of free software, we are referring to freedom, not
|
| 23 |
+
price. Our General Public Licenses are designed to make sure that you
|
| 24 |
+
have the freedom to distribute copies of free software (and charge for
|
| 25 |
+
them if you wish), that you receive source code or can get it if you
|
| 26 |
+
want it, that you can change the software or use pieces of it in new
|
| 27 |
+
free programs, and that you know you can do these things.
|
| 28 |
+
|
| 29 |
+
To protect your rights, we need to prevent others from denying you
|
| 30 |
+
these rights or asking you to surrender the rights. Therefore, you have
|
| 31 |
+
certain responsibilities if you distribute copies of the software, or if
|
| 32 |
+
you modify it: responsibilities to respect the freedom of others.
|
| 33 |
+
|
| 34 |
+
For example, if you distribute copies of such a program, whether
|
| 35 |
+
gratis or for a fee, you must pass on to the recipients the same
|
| 36 |
+
freedoms that you received. You must make sure that they, too, receive
|
| 37 |
+
or can get the source code. And you must show them these terms so they
|
| 38 |
+
know their rights.
|
| 39 |
+
|
| 40 |
+
Developers that use the GNU GPL protect your rights with two steps:
|
| 41 |
+
(1) assert copyright on the software, and (2) offer you this License
|
| 42 |
+
giving you legal permission to copy, distribute and/or modify it.
|
| 43 |
+
|
| 44 |
+
For the developers' and authors' protection, the GPL clearly explains
|
| 45 |
+
that there is no warranty for this free software. For both users' and
|
| 46 |
+
authors' sake, the GPL requires that modified versions be marked as
|
| 47 |
+
changed, so that their problems will not be attributed erroneously to
|
| 48 |
+
authors of previous versions.
|
| 49 |
+
|
| 50 |
+
Some devices are designed to deny users access to install or run
|
| 51 |
+
modified versions of the software inside them, although the manufacturer
|
| 52 |
+
can do so. This is fundamentally incompatible with the aim of
|
| 53 |
+
protecting users' freedom to change the software. The systematic
|
| 54 |
+
pattern of such abuse occurs in the area of products for individuals to
|
| 55 |
+
use, which is precisely where it is most unacceptable. Therefore, we
|
| 56 |
+
have designed this version of the GPL to prohibit the practice for those
|
| 57 |
+
products. If such problems arise substantially in other domains, we
|
| 58 |
+
stand ready to extend this provision to those domains in future versions
|
| 59 |
+
of the GPL, as needed to protect the freedom of users.
|
| 60 |
+
|
| 61 |
+
Finally, every program is threatened constantly by software patents.
|
| 62 |
+
States should not allow patents to restrict development and use of
|
| 63 |
+
software on general-purpose computers, but in those that do, we wish to
|
| 64 |
+
avoid the special danger that patents applied to a free program could
|
| 65 |
+
make it effectively proprietary. To prevent this, the GPL assures that
|
| 66 |
+
patents cannot be used to render the program non-free.
|
| 67 |
+
|
| 68 |
+
The precise terms and conditions for copying, distribution and
|
| 69 |
+
modification follow.
|
| 70 |
+
|
| 71 |
+
TERMS AND CONDITIONS
|
| 72 |
+
|
| 73 |
+
0. Definitions.
|
| 74 |
+
|
| 75 |
+
"This License" refers to version 3 of the GNU General Public License.
|
| 76 |
+
|
| 77 |
+
"Copyright" also means copyright-like laws that apply to other kinds of
|
| 78 |
+
works, such as semiconductor masks.
|
| 79 |
+
|
| 80 |
+
"The Program" refers to any copyrightable work licensed under this
|
| 81 |
+
License. Each licensee is addressed as "you". "Licensees" and
|
| 82 |
+
"recipients" may be individuals or organizations.
|
| 83 |
+
|
| 84 |
+
To "modify" a work means to copy from or adapt all or part of the work
|
| 85 |
+
in a fashion requiring copyright permission, other than the making of an
|
| 86 |
+
exact copy. The resulting work is called a "modified version" of the
|
| 87 |
+
earlier work or a work "based on" the earlier work.
|
| 88 |
+
|
| 89 |
+
A "covered work" means either the unmodified Program or a work based
|
| 90 |
+
on the Program.
|
| 91 |
+
|
| 92 |
+
To "propagate" a work means to do anything with it that, without
|
| 93 |
+
permission, would make you directly or secondarily liable for
|
| 94 |
+
infringement under applicable copyright law, except executing it on a
|
| 95 |
+
computer or modifying a private copy. Propagation includes copying,
|
| 96 |
+
distribution (with or without modification), making available to the
|
| 97 |
+
public, and in some countries other activities as well.
|
| 98 |
+
|
| 99 |
+
To "convey" a work means any kind of propagation that enables other
|
| 100 |
+
parties to make or receive copies. Mere interaction with a user through
|
| 101 |
+
a computer network, with no transfer of a copy, is not conveying.
|
| 102 |
+
|
| 103 |
+
An interactive user interface displays "Appropriate Legal Notices"
|
| 104 |
+
to the extent that it includes a convenient and prominently visible
|
| 105 |
+
feature that (1) displays an appropriate copyright notice, and (2)
|
| 106 |
+
tells the user that there is no warranty for the work (except to the
|
| 107 |
+
extent that warranties are provided), that licensees may convey the
|
| 108 |
+
work under this License, and how to view a copy of this License. If
|
| 109 |
+
the interface presents a list of user commands or options, such as a
|
| 110 |
+
menu, a prominent item in the list meets this criterion.
|
| 111 |
+
|
| 112 |
+
1. Source Code.
|
| 113 |
+
|
| 114 |
+
The "source code" for a work means the preferred form of the work
|
| 115 |
+
for making modifications to it. "Object code" means any non-source
|
| 116 |
+
form of a work.
|
| 117 |
+
|
| 118 |
+
A "Standard Interface" means an interface that either is an official
|
| 119 |
+
standard defined by a recognized standards body, or, in the case of
|
| 120 |
+
interfaces specified for a particular programming language, one that
|
| 121 |
+
is widely used among developers working in that language.
|
| 122 |
+
|
| 123 |
+
The "System Libraries" of an executable work include anything, other
|
| 124 |
+
than the work as a whole, that (a) is included in the normal form of
|
| 125 |
+
packaging a Major Component, but which is not part of that Major
|
| 126 |
+
Component, and (b) serves only to enable use of the work with that
|
| 127 |
+
Major Component, or to implement a Standard Interface for which an
|
| 128 |
+
implementation is available to the public in source code form. A
|
| 129 |
+
"Major Component", in this context, means a major essential component
|
| 130 |
+
(kernel, window system, and so on) of the specific operating system
|
| 131 |
+
(if any) on which the executable work runs, or a compiler used to
|
| 132 |
+
produce the work, or an object code interpreter used to run it.
|
| 133 |
+
|
| 134 |
+
The "Corresponding Source" for a work in object code form means all
|
| 135 |
+
the source code needed to generate, install, and (for an executable
|
| 136 |
+
work) run the object code and to modify the work, including scripts to
|
| 137 |
+
control those activities. However, it does not include the work's
|
| 138 |
+
System Libraries, or general-purpose tools or generally available free
|
| 139 |
+
programs which are used unmodified in performing those activities but
|
| 140 |
+
which are not part of the work. For example, Corresponding Source
|
| 141 |
+
includes interface definition files associated with source files for
|
| 142 |
+
the work, and the source code for shared libraries and dynamically
|
| 143 |
+
linked subprograms that the work is specifically designed to require,
|
| 144 |
+
such as by intimate data communication or control flow between those
|
| 145 |
+
subprograms and other parts of the work.
|
| 146 |
+
|
| 147 |
+
The Corresponding Source need not include anything that users
|
| 148 |
+
can regenerate automatically from other parts of the Corresponding
|
| 149 |
+
Source.
|
| 150 |
+
|
| 151 |
+
The Corresponding Source for a work in source code form is that
|
| 152 |
+
same work.
|
| 153 |
+
|
| 154 |
+
2. Basic Permissions.
|
| 155 |
+
|
| 156 |
+
All rights granted under this License are granted for the term of
|
| 157 |
+
copyright on the Program, and are irrevocable provided the stated
|
| 158 |
+
conditions are met. This License explicitly affirms your unlimited
|
| 159 |
+
permission to run the unmodified Program. The output from running a
|
| 160 |
+
covered work is covered by this License only if the output, given its
|
| 161 |
+
content, constitutes a covered work. This License acknowledges your
|
| 162 |
+
rights of fair use or other equivalent, as provided by copyright law.
|
| 163 |
+
|
| 164 |
+
You may make, run and propagate covered works that you do not
|
| 165 |
+
convey, without conditions so long as your license otherwise remains
|
| 166 |
+
in force. You may convey covered works to others for the sole purpose
|
| 167 |
+
of having them make modifications exclusively for you, or provide you
|
| 168 |
+
with facilities for running those works, provided that you comply with
|
| 169 |
+
the terms of this License in conveying all material for which you do
|
| 170 |
+
not control copyright. Those thus making or running the covered works
|
| 171 |
+
for you must do so exclusively on your behalf, under your direction
|
| 172 |
+
and control, on terms that prohibit them from making any copies of
|
| 173 |
+
your copyrighted material outside their relationship with you.
|
| 174 |
+
|
| 175 |
+
Conveying under any other circumstances is permitted solely under
|
| 176 |
+
the conditions stated below. Sublicensing is not allowed; section 10
|
| 177 |
+
makes it unnecessary.
|
| 178 |
+
|
| 179 |
+
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
| 180 |
+
|
| 181 |
+
No covered work shall be deemed part of an effective technological
|
| 182 |
+
measure under any applicable law fulfilling obligations under article
|
| 183 |
+
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
| 184 |
+
similar laws prohibiting or restricting circumvention of such
|
| 185 |
+
measures.
|
| 186 |
+
|
| 187 |
+
When you convey a covered work, you waive any legal power to forbid
|
| 188 |
+
circumvention of technological measures to the extent such circumvention
|
| 189 |
+
is effected by exercising rights under this License with respect to
|
| 190 |
+
the covered work, and you disclaim any intention to limit operation or
|
| 191 |
+
modification of the work as a means of enforcing, against the work's
|
| 192 |
+
users, your or third parties' legal rights to forbid circumvention of
|
| 193 |
+
technological measures.
|
| 194 |
+
|
| 195 |
+
4. Conveying Verbatim Copies.
|
| 196 |
+
|
| 197 |
+
You may convey verbatim copies of the Program's source code as you
|
| 198 |
+
receive it, in any medium, provided that you conspicuously and
|
| 199 |
+
appropriately publish on each copy an appropriate copyright notice;
|
| 200 |
+
keep intact all notices stating that this License and any
|
| 201 |
+
non-permissive terms added in accord with section 7 apply to the code;
|
| 202 |
+
keep intact all notices of the absence of any warranty; and give all
|
| 203 |
+
recipients a copy of this License along with the Program.
|
| 204 |
+
|
| 205 |
+
You may charge any price or no price for each copy that you convey,
|
| 206 |
+
and you may offer support or warranty protection for a fee.
|
| 207 |
+
|
| 208 |
+
5. Conveying Modified Source Versions.
|
| 209 |
+
|
| 210 |
+
You may convey a work based on the Program, or the modifications to
|
| 211 |
+
produce it from the Program, in the form of source code under the
|
| 212 |
+
terms of section 4, provided that you also meet all of these conditions:
|
| 213 |
+
|
| 214 |
+
a) The work must carry prominent notices stating that you modified
|
| 215 |
+
it, and giving a relevant date.
|
| 216 |
+
|
| 217 |
+
b) The work must carry prominent notices stating that it is
|
| 218 |
+
released under this License and any conditions added under section
|
| 219 |
+
7. This requirement modifies the requirement in section 4 to
|
| 220 |
+
"keep intact all notices".
|
| 221 |
+
|
| 222 |
+
c) You must license the entire work, as a whole, under this
|
| 223 |
+
License to anyone who comes into possession of a copy. This
|
| 224 |
+
License will therefore apply, along with any applicable section 7
|
| 225 |
+
additional terms, to the whole of the work, and all its parts,
|
| 226 |
+
regardless of how they are packaged. This License gives no
|
| 227 |
+
permission to license the work in any other way, but it does not
|
| 228 |
+
invalidate such permission if you have separately received it.
|
| 229 |
+
|
| 230 |
+
d) If the work has interactive user interfaces, each must display
|
| 231 |
+
Appropriate Legal Notices; however, if the Program has interactive
|
| 232 |
+
interfaces that do not display Appropriate Legal Notices, your
|
| 233 |
+
work need not make them do so.
|
| 234 |
+
|
| 235 |
+
A compilation of a covered work with other separate and independent
|
| 236 |
+
works, which are not by their nature extensions of the covered work,
|
| 237 |
+
and which are not combined with it such as to form a larger program,
|
| 238 |
+
in or on a volume of a storage or distribution medium, is called an
|
| 239 |
+
"aggregate" if the compilation and its resulting copyright are not
|
| 240 |
+
used to limit the access or legal rights of the compilation's users
|
| 241 |
+
beyond what the individual works permit. Inclusion of a covered work
|
| 242 |
+
in an aggregate does not cause this License to apply to the other
|
| 243 |
+
parts of the aggregate.
|
| 244 |
+
|
| 245 |
+
6. Conveying Non-Source Forms.
|
| 246 |
+
|
| 247 |
+
You may convey a covered work in object code form under the terms
|
| 248 |
+
of sections 4 and 5, provided that you also convey the
|
| 249 |
+
machine-readable Corresponding Source under the terms of this License,
|
| 250 |
+
in one of these ways:
|
| 251 |
+
|
| 252 |
+
a) Convey the object code in, or embodied in, a physical product
|
| 253 |
+
(including a physical distribution medium), accompanied by the
|
| 254 |
+
Corresponding Source fixed on a durable physical medium
|
| 255 |
+
customarily used for software interchange.
|
| 256 |
+
|
| 257 |
+
b) Convey the object code in, or embodied in, a physical product
|
| 258 |
+
(including a physical distribution medium), accompanied by a
|
| 259 |
+
written offer, valid for at least three years and valid for as
|
| 260 |
+
long as you offer spare parts or customer support for that product
|
| 261 |
+
model, to give anyone who possesses the object code either (1) a
|
| 262 |
+
copy of the Corresponding Source for all the software in the
|
| 263 |
+
product that is covered by this License, on a durable physical
|
| 264 |
+
medium customarily used for software interchange, for a price no
|
| 265 |
+
more than your reasonable cost of physically performing this
|
| 266 |
+
conveying of source, or (2) access to copy the
|
| 267 |
+
Corresponding Source from a network server at no charge.
|
| 268 |
+
|
| 269 |
+
c) Convey individual copies of the object code with a copy of the
|
| 270 |
+
written offer to provide the Corresponding Source. This
|
| 271 |
+
alternative is allowed only occasionally and noncommercially, and
|
| 272 |
+
only if you received the object code with such an offer, in accord
|
| 273 |
+
with subsection 6b.
|
| 274 |
+
|
| 275 |
+
d) Convey the object code by offering access from a designated
|
| 276 |
+
place (gratis or for a charge), and offer equivalent access to the
|
| 277 |
+
Corresponding Source in the same way through the same place at no
|
| 278 |
+
further charge. You need not require recipients to copy the
|
| 279 |
+
Corresponding Source along with the object code. If the place to
|
| 280 |
+
copy the object code is a network server, the Corresponding Source
|
| 281 |
+
may be on a different server (operated by you or a third party)
|
| 282 |
+
that supports equivalent copying facilities, provided you maintain
|
| 283 |
+
clear directions next to the object code saying where to find the
|
| 284 |
+
Corresponding Source. Regardless of what server hosts the
|
| 285 |
+
Corresponding Source, you remain obligated to ensure that it is
|
| 286 |
+
available for as long as needed to satisfy these requirements.
|
| 287 |
+
|
| 288 |
+
e) Convey the object code using peer-to-peer transmission, provided
|
| 289 |
+
you inform other peers where the object code and Corresponding
|
| 290 |
+
Source of the work are being offered to the general public at no
|
| 291 |
+
charge under subsection 6d.
|
| 292 |
+
|
| 293 |
+
A separable portion of the object code, whose source code is excluded
|
| 294 |
+
from the Corresponding Source as a System Library, need not be
|
| 295 |
+
included in conveying the object code work.
|
| 296 |
+
|
| 297 |
+
A "User Product" is either (1) a "consumer product", which means any
|
| 298 |
+
tangible personal property which is normally used for personal, family,
|
| 299 |
+
or household purposes, or (2) anything designed or sold for incorporation
|
| 300 |
+
into a dwelling. In determining whether a product is a consumer product,
|
| 301 |
+
doubtful cases shall be resolved in favor of coverage. For a particular
|
| 302 |
+
product received by a particular user, "normally used" refers to a
|
| 303 |
+
typical or common use of that class of product, regardless of the status
|
| 304 |
+
of the particular user or of the way in which the particular user
|
| 305 |
+
actually uses, or expects or is expected to use, the product. A product
|
| 306 |
+
is a consumer product regardless of whether the product has substantial
|
| 307 |
+
commercial, industrial or non-consumer uses, unless such uses represent
|
| 308 |
+
the only significant mode of use of the product.
|
| 309 |
+
|
| 310 |
+
"Installation Information" for a User Product means any methods,
|
| 311 |
+
procedures, authorization keys, or other information required to install
|
| 312 |
+
and execute modified versions of a covered work in that User Product from
|
| 313 |
+
a modified version of its Corresponding Source. The information must
|
| 314 |
+
suffice to ensure that the continued functioning of the modified object
|
| 315 |
+
code is in no case prevented or interfered with solely because
|
| 316 |
+
modification has been made.
|
| 317 |
+
|
| 318 |
+
If you convey an object code work under this section in, or with, or
|
| 319 |
+
specifically for use in, a User Product, and the conveying occurs as
|
| 320 |
+
part of a transaction in which the right of possession and use of the
|
| 321 |
+
User Product is transferred to the recipient in perpetuity or for a
|
| 322 |
+
fixed term (regardless of how the transaction is characterized), the
|
| 323 |
+
Corresponding Source conveyed under this section must be accompanied
|
| 324 |
+
by the Installation Information. But this requirement does not apply
|
| 325 |
+
if neither you nor any third party retains the ability to install
|
| 326 |
+
modified object code on the User Product (for example, the work has
|
| 327 |
+
been installed in ROM).
|
| 328 |
+
|
| 329 |
+
The requirement to provide Installation Information does not include a
|
| 330 |
+
requirement to continue to provide support service, warranty, or updates
|
| 331 |
+
for a work that has been modified or installed by the recipient, or for
|
| 332 |
+
the User Product in which it has been modified or installed. Access to a
|
| 333 |
+
network may be denied when the modification itself materially and
|
| 334 |
+
adversely affects the operation of the network or violates the rules and
|
| 335 |
+
protocols for communication across the network.
|
| 336 |
+
|
| 337 |
+
Corresponding Source conveyed, and Installation Information provided,
|
| 338 |
+
in accord with this section must be in a format that is publicly
|
| 339 |
+
documented (and with an implementation available to the public in
|
| 340 |
+
source code form), and must require no special password or key for
|
| 341 |
+
unpacking, reading or copying.
|
| 342 |
+
|
| 343 |
+
7. Additional Terms.
|
| 344 |
+
|
| 345 |
+
"Additional permissions" are terms that supplement the terms of this
|
| 346 |
+
License by making exceptions from one or more of its conditions.
|
| 347 |
+
Additional permissions that are applicable to the entire Program shall
|
| 348 |
+
be treated as though they were included in this License, to the extent
|
| 349 |
+
that they are valid under applicable law. If additional permissions
|
| 350 |
+
apply only to part of the Program, that part may be used separately
|
| 351 |
+
under those permissions, but the entire Program remains governed by
|
| 352 |
+
this License without regard to the additional permissions.
|
| 353 |
+
|
| 354 |
+
When you convey a copy of a covered work, you may at your option
|
| 355 |
+
remove any additional permissions from that copy, or from any part of
|
| 356 |
+
it. (Additional permissions may be written to require their own
|
| 357 |
+
removal in certain cases when you modify the work.) You may place
|
| 358 |
+
additional permissions on material, added by you to a covered work,
|
| 359 |
+
for which you have or can give appropriate copyright permission.
|
| 360 |
+
|
| 361 |
+
Notwithstanding any other provision of this License, for material you
|
| 362 |
+
add to a covered work, you may (if authorized by the copyright holders of
|
| 363 |
+
that material) supplement the terms of this License with terms:
|
| 364 |
+
|
| 365 |
+
a) Disclaiming warranty or limiting liability differently from the
|
| 366 |
+
terms of sections 15 and 16 of this License; or
|
| 367 |
+
|
| 368 |
+
b) Requiring preservation of specified reasonable legal notices or
|
| 369 |
+
author attributions in that material or in the Appropriate Legal
|
| 370 |
+
Notices displayed by works containing it; or
|
| 371 |
+
|
| 372 |
+
c) Prohibiting misrepresentation of the origin of that material, or
|
| 373 |
+
requiring that modified versions of such material be marked in
|
| 374 |
+
reasonable ways as different from the original version; or
|
| 375 |
+
|
| 376 |
+
d) Limiting the use for publicity purposes of names of licensors or
|
| 377 |
+
authors of the material; or
|
| 378 |
+
|
| 379 |
+
e) Declining to grant rights under trademark law for use of some
|
| 380 |
+
trade names, trademarks, or service marks; or
|
| 381 |
+
|
| 382 |
+
f) Requiring indemnification of licensors and authors of that
|
| 383 |
+
material by anyone who conveys the material (or modified versions of
|
| 384 |
+
it) with contractual assumptions of liability to the recipient, for
|
| 385 |
+
any liability that these contractual assumptions directly impose on
|
| 386 |
+
those licensors and authors.
|
| 387 |
+
|
| 388 |
+
All other non-permissive additional terms are considered "further
|
| 389 |
+
restrictions" within the meaning of section 10. If the Program as you
|
| 390 |
+
received it, or any part of it, contains a notice stating that it is
|
| 391 |
+
governed by this License along with a term that is a further
|
| 392 |
+
restriction, you may remove that term. If a license document contains
|
| 393 |
+
a further restriction but permits relicensing or conveying under this
|
| 394 |
+
License, you may add to a covered work material governed by the terms
|
| 395 |
+
of that license document, provided that the further restriction does
|
| 396 |
+
not survive such relicensing or conveying.
|
| 397 |
+
|
| 398 |
+
If you add terms to a covered work in accord with this section, you
|
| 399 |
+
must place, in the relevant source files, a statement of the
|
| 400 |
+
additional terms that apply to those files, or a notice indicating
|
| 401 |
+
where to find the applicable terms.
|
| 402 |
+
|
| 403 |
+
Additional terms, permissive or non-permissive, may be stated in the
|
| 404 |
+
form of a separately written license, or stated as exceptions;
|
| 405 |
+
the above requirements apply either way.
|
| 406 |
+
|
| 407 |
+
8. Termination.
|
| 408 |
+
|
| 409 |
+
You may not propagate or modify a covered work except as expressly
|
| 410 |
+
provided under this License. Any attempt otherwise to propagate or
|
| 411 |
+
modify it is void, and will automatically terminate your rights under
|
| 412 |
+
this License (including any patent licenses granted under the third
|
| 413 |
+
paragraph of section 11).
|
| 414 |
+
|
| 415 |
+
However, if you cease all violation of this License, then your
|
| 416 |
+
license from a particular copyright holder is reinstated (a)
|
| 417 |
+
provisionally, unless and until the copyright holder explicitly and
|
| 418 |
+
finally terminates your license, and (b) permanently, if the copyright
|
| 419 |
+
holder fails to notify you of the violation by some reasonable means
|
| 420 |
+
prior to 60 days after the cessation.
|
| 421 |
+
|
| 422 |
+
Moreover, your license from a particular copyright holder is
|
| 423 |
+
reinstated permanently if the copyright holder notifies you of the
|
| 424 |
+
violation by some reasonable means, this is the first time you have
|
| 425 |
+
received notice of violation of this License (for any work) from that
|
| 426 |
+
copyright holder, and you cure the violation prior to 30 days after
|
| 427 |
+
your receipt of the notice.
|
| 428 |
+
|
| 429 |
+
Termination of your rights under this section does not terminate the
|
| 430 |
+
licenses of parties who have received copies or rights from you under
|
| 431 |
+
this License. If your rights have been terminated and not permanently
|
| 432 |
+
reinstated, you do not qualify to receive new licenses for the same
|
| 433 |
+
material under section 10.
|
| 434 |
+
|
| 435 |
+
9. Acceptance Not Required for Having Copies.
|
| 436 |
+
|
| 437 |
+
You are not required to accept this License in order to receive or
|
| 438 |
+
run a copy of the Program. Ancillary propagation of a covered work
|
| 439 |
+
occurring solely as a consequence of using peer-to-peer transmission
|
| 440 |
+
to receive a copy likewise does not require acceptance. However,
|
| 441 |
+
nothing other than this License grants you permission to propagate or
|
| 442 |
+
modify any covered work. These actions infringe copyright if you do
|
| 443 |
+
not accept this License. Therefore, by modifying or propagating a
|
| 444 |
+
covered work, you indicate your acceptance of this License to do so.
|
| 445 |
+
|
| 446 |
+
10. Automatic Licensing of Downstream Recipients.
|
| 447 |
+
|
| 448 |
+
Each time you convey a covered work, the recipient automatically
|
| 449 |
+
receives a license from the original licensors, to run, modify and
|
| 450 |
+
propagate that work, subject to this License. You are not responsible
|
| 451 |
+
for enforcing compliance by third parties with this License.
|
| 452 |
+
|
| 453 |
+
An "entity transaction" is a transaction transferring control of an
|
| 454 |
+
organization, or substantially all assets of one, or subdividing an
|
| 455 |
+
organization, or merging organizations. If propagation of a covered
|
| 456 |
+
work results from an entity transaction, each party to that
|
| 457 |
+
transaction who receives a copy of the work also receives whatever
|
| 458 |
+
licenses to the work the party's predecessor in interest had or could
|
| 459 |
+
give under the previous paragraph, plus a right to possession of the
|
| 460 |
+
Corresponding Source of the work from the predecessor in interest, if
|
| 461 |
+
the predecessor has it or can get it with reasonable efforts.
|
| 462 |
+
|
| 463 |
+
You may not impose any further restrictions on the exercise of the
|
| 464 |
+
rights granted or affirmed under this License. For example, you may
|
| 465 |
+
not impose a license fee, royalty, or other charge for exercise of
|
| 466 |
+
rights granted under this License, and you may not initiate litigation
|
| 467 |
+
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
| 468 |
+
any patent claim is infringed by making, using, selling, offering for
|
| 469 |
+
sale, or importing the Program or any portion of it.
|
| 470 |
+
|
| 471 |
+
11. Patents.
|
| 472 |
+
|
| 473 |
+
A "contributor" is a copyright holder who authorizes use under this
|
| 474 |
+
License of the Program or a work on which the Program is based. The
|
| 475 |
+
work thus licensed is called the contributor's "contributor version".
|
| 476 |
+
|
| 477 |
+
A contributor's "essential patent claims" are all patent claims
|
| 478 |
+
owned or controlled by the contributor, whether already acquired or
|
| 479 |
+
hereafter acquired, that would be infringed by some manner, permitted
|
| 480 |
+
by this License, of making, using, or selling its contributor version,
|
| 481 |
+
but do not include claims that would be infringed only as a
|
| 482 |
+
consequence of further modification of the contributor version. For
|
| 483 |
+
purposes of this definition, "control" includes the right to grant
|
| 484 |
+
patent sublicenses in a manner consistent with the requirements of
|
| 485 |
+
this License.
|
| 486 |
+
|
| 487 |
+
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
| 488 |
+
patent license under the contributor's essential patent claims, to
|
| 489 |
+
make, use, sell, offer for sale, import and otherwise run, modify and
|
| 490 |
+
propagate the contents of its contributor version.
|
| 491 |
+
|
| 492 |
+
In the following three paragraphs, a "patent license" is any express
|
| 493 |
+
agreement or commitment, however denominated, not to enforce a patent
|
| 494 |
+
(such as an express permission to practice a patent or covenant not to
|
| 495 |
+
sue for patent infringement). To "grant" such a patent license to a
|
| 496 |
+
party means to make such an agreement or commitment not to enforce a
|
| 497 |
+
patent against the party.
|
| 498 |
+
|
| 499 |
+
If you convey a covered work, knowingly relying on a patent license,
|
| 500 |
+
and the Corresponding Source of the work is not available for anyone
|
| 501 |
+
to copy, free of charge and under the terms of this License, through a
|
| 502 |
+
publicly available network server or other readily accessible means,
|
| 503 |
+
then you must either (1) cause the Corresponding Source to be so
|
| 504 |
+
available, or (2) arrange to deprive yourself of the benefit of the
|
| 505 |
+
patent license for this particular work, or (3) arrange, in a manner
|
| 506 |
+
consistent with the requirements of this License, to extend the patent
|
| 507 |
+
license to downstream recipients. "Knowingly relying" means you have
|
| 508 |
+
actual knowledge that, but for the patent license, your conveying the
|
| 509 |
+
covered work in a country, or your recipient's use of the covered work
|
| 510 |
+
in a country, would infringe one or more identifiable patents in that
|
| 511 |
+
country that you have reason to believe are valid.
|
| 512 |
+
|
| 513 |
+
If, pursuant to or in connection with a single transaction or
|
| 514 |
+
arrangement, you convey, or propagate by procuring conveyance of, a
|
| 515 |
+
covered work, and grant a patent license to some of the parties
|
| 516 |
+
receiving the covered work authorizing them to use, propagate, modify
|
| 517 |
+
or convey a specific copy of the covered work, then the patent license
|
| 518 |
+
you grant is automatically extended to all recipients of the covered
|
| 519 |
+
work and works based on it.
|
| 520 |
+
|
| 521 |
+
A patent license is "discriminatory" if it does not include within
|
| 522 |
+
the scope of its coverage, prohibits the exercise of, or is
|
| 523 |
+
conditioned on the non-exercise of one or more of the rights that are
|
| 524 |
+
specifically granted under this License. You may not convey a covered
|
| 525 |
+
work if you are a party to an arrangement with a third party that is
|
| 526 |
+
in the business of distributing software, under which you make payment
|
| 527 |
+
to the third party based on the extent of your activity of conveying
|
| 528 |
+
the work, and under which the third party grants, to any of the
|
| 529 |
+
parties who would receive the covered work from you, a discriminatory
|
| 530 |
+
patent license (a) in connection with copies of the covered work
|
| 531 |
+
conveyed by you (or copies made from those copies), or (b) primarily
|
| 532 |
+
for and in connection with specific products or compilations that
|
| 533 |
+
contain the covered work, unless you entered into that arrangement,
|
| 534 |
+
or that patent license was granted, prior to 28 March 2007.
|
| 535 |
+
|
| 536 |
+
Nothing in this License shall be construed as excluding or limiting
|
| 537 |
+
any implied license or other defenses to infringement that may
|
| 538 |
+
otherwise be available to you under applicable patent law.
|
| 539 |
+
|
| 540 |
+
12. No Surrender of Others' Freedom.
|
| 541 |
+
|
| 542 |
+
If conditions are imposed on you (whether by court order, agreement or
|
| 543 |
+
otherwise) that contradict the conditions of this License, they do not
|
| 544 |
+
excuse you from the conditions of this License. If you cannot convey a
|
| 545 |
+
covered work so as to satisfy simultaneously your obligations under this
|
| 546 |
+
License and any other pertinent obligations, then as a consequence you may
|
| 547 |
+
not convey it at all. For example, if you agree to terms that obligate you
|
| 548 |
+
to collect a royalty for further conveying from those to whom you convey
|
| 549 |
+
the Program, the only way you could satisfy both those terms and this
|
| 550 |
+
License would be to refrain entirely from conveying the Program.
|
| 551 |
+
|
| 552 |
+
13. Use with the GNU Affero General Public License.
|
| 553 |
+
|
| 554 |
+
Notwithstanding any other provision of this License, you have
|
| 555 |
+
permission to link or combine any covered work with a work licensed
|
| 556 |
+
under version 3 of the GNU Affero General Public License into a single
|
| 557 |
+
combined work, and to convey the resulting work. The terms of this
|
| 558 |
+
License will continue to apply to the part which is the covered work,
|
| 559 |
+
but the special requirements of the GNU Affero General Public License,
|
| 560 |
+
section 13, concerning interaction through a network will apply to the
|
| 561 |
+
combination as such.
|
| 562 |
+
|
| 563 |
+
14. Revised Versions of this License.
|
| 564 |
+
|
| 565 |
+
The Free Software Foundation may publish revised and/or new versions of
|
| 566 |
+
the GNU General Public License from time to time. Such new versions will
|
| 567 |
+
be similar in spirit to the present version, but may differ in detail to
|
| 568 |
+
address new problems or concerns.
|
| 569 |
+
|
| 570 |
+
Each version is given a distinguishing version number. If the
|
| 571 |
+
Program specifies that a certain numbered version of the GNU General
|
| 572 |
+
Public License "or any later version" applies to it, you have the
|
| 573 |
+
option of following the terms and conditions either of that numbered
|
| 574 |
+
version or of any later version published by the Free Software
|
| 575 |
+
Foundation. If the Program does not specify a version number of the
|
| 576 |
+
GNU General Public License, you may choose any version ever published
|
| 577 |
+
by the Free Software Foundation.
|
| 578 |
+
|
| 579 |
+
If the Program specifies that a proxy can decide which future
|
| 580 |
+
versions of the GNU General Public License can be used, that proxy's
|
| 581 |
+
public statement of acceptance of a version permanently authorizes you
|
| 582 |
+
to choose that version for the Program.
|
| 583 |
+
|
| 584 |
+
Later license versions may give you additional or different
|
| 585 |
+
permissions. However, no additional obligations are imposed on any
|
| 586 |
+
author or copyright holder as a result of your choosing to follow a
|
| 587 |
+
later version.
|
| 588 |
+
|
| 589 |
+
15. Disclaimer of Warranty.
|
| 590 |
+
|
| 591 |
+
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
| 592 |
+
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
| 593 |
+
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
| 594 |
+
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
| 595 |
+
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
| 596 |
+
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
| 597 |
+
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
| 598 |
+
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
| 599 |
+
|
| 600 |
+
16. Limitation of Liability.
|
| 601 |
+
|
| 602 |
+
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
| 603 |
+
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
| 604 |
+
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
| 605 |
+
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
| 606 |
+
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
| 607 |
+
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
| 608 |
+
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
| 609 |
+
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
| 610 |
+
SUCH DAMAGES.
|
| 611 |
+
|
| 612 |
+
17. Interpretation of Sections 15 and 16.
|
| 613 |
+
|
| 614 |
+
If the disclaimer of warranty and limitation of liability provided
|
| 615 |
+
above cannot be given local legal effect according to their terms,
|
| 616 |
+
reviewing courts shall apply local law that most closely approximates
|
| 617 |
+
an absolute waiver of all civil liability in connection with the
|
| 618 |
+
Program, unless a warranty or assumption of liability accompanies a
|
| 619 |
+
copy of the Program in return for a fee.
|
| 620 |
+
|
| 621 |
+
END OF TERMS AND CONDITIONS
|
| 622 |
+
|
| 623 |
+
How to Apply These Terms to Your New Programs
|
| 624 |
+
|
| 625 |
+
If you develop a new program, and you want it to be of the greatest
|
| 626 |
+
possible use to the public, the best way to achieve this is to make it
|
| 627 |
+
free software which everyone can redistribute and change under these terms.
|
| 628 |
+
|
| 629 |
+
To do so, attach the following notices to the program. It is safest
|
| 630 |
+
to attach them to the start of each source file to most effectively
|
| 631 |
+
state the exclusion of warranty; and each file should have at least
|
| 632 |
+
the "copyright" line and a pointer to where the full notice is found.
|
| 633 |
+
|
| 634 |
+
<one line to give the program's name and a brief idea of what it does.>
|
| 635 |
+
Copyright (C) <year> <name of author>
|
| 636 |
+
|
| 637 |
+
This program is free software: you can redistribute it and/or modify
|
| 638 |
+
it under the terms of the GNU General Public License as published by
|
| 639 |
+
the Free Software Foundation, either version 3 of the License, or
|
| 640 |
+
(at your option) any later version.
|
| 641 |
+
|
| 642 |
+
This program is distributed in the hope that it will be useful,
|
| 643 |
+
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
| 644 |
+
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
| 645 |
+
GNU General Public License for more details.
|
| 646 |
+
|
| 647 |
+
You should have received a copy of the GNU General Public License
|
| 648 |
+
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
| 649 |
+
|
| 650 |
+
Also add information on how to contact you by electronic and paper mail.
|
| 651 |
+
|
| 652 |
+
If the program does terminal interaction, make it output a short
|
| 653 |
+
notice like this when it starts in an interactive mode:
|
| 654 |
+
|
| 655 |
+
<program> Copyright (C) <year> <name of author>
|
| 656 |
+
This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
| 657 |
+
This is free software, and you are welcome to redistribute it
|
| 658 |
+
under certain conditions; type `show c' for details.
|
| 659 |
+
|
| 660 |
+
The hypothetical commands `show w' and `show c' should show the appropriate
|
| 661 |
+
parts of the General Public License. Of course, your program's commands
|
| 662 |
+
might be different; for a GUI interface, you would use an "about box".
|
| 663 |
+
|
| 664 |
+
You should also get your employer (if you work as a programmer) or school,
|
| 665 |
+
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
| 666 |
+
For more information on this, and how to apply and follow the GNU GPL, see
|
| 667 |
+
<https://www.gnu.org/licenses/>.
|
| 668 |
+
|
| 669 |
+
The GNU General Public License does not permit incorporating your program
|
| 670 |
+
into proprietary programs. If your program is a subroutine library, you
|
| 671 |
+
may consider it more useful to permit linking proprietary applications with
|
| 672 |
+
the library. If this is what you want to do, use the GNU Lesser General
|
| 673 |
+
Public License instead of this License. But first, please read
|
| 674 |
+
<https://www.gnu.org/licenses/why-not-lgpl.html>.
|
ComfyUI-Easy-Use/README.ZH_CN.md
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<a href="https://space.bilibili.com/1840885116">视频介绍</a> |
|
| 5 |
+
文档 (康明孙) |
|
| 6 |
+
<a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">工作流合集</a> |
|
| 7 |
+
<a href="#%EF%B8%8F-donation">捐助</a>
|
| 8 |
+
<br><br>
|
| 9 |
+
<a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-e9e9e9"></a>
|
| 10 |
+
<a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-0b8cf5"></a>
|
| 11 |
+
</div>
|
| 12 |
+
|
| 13 |
+
**ComfyUI-Easy-Use** 是一个化繁为简的节点整合包, 在 [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) 的基础上进行延展,并针对了诸多主流的节点包做了整合与优化,以达到更快更方便使用ComfyUI的目的,在保证自由度的同时还原了本属于Stable Diffusion的极致畅快出图体验。
|
| 14 |
+
|
| 15 |
+
## 👨🏻🎨 特色介绍
|
| 16 |
+
|
| 17 |
+
- 沿用了 [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) 的思路,大大减少了折腾工作流的时间成本。
|
| 18 |
+
- UI界面美化,首次安装的用户,如需使用UI主题,请在 Settings -> Color Palette 中自行切换主题并**刷新页面**即可
|
| 19 |
+
- 增加了预采样参数配置的节点,可与采样节点分离,更方便预览。
|
| 20 |
+
- 支持通配符与Lora的提示词节点,如需使用Lora Block Weight用法,需先保证自定义节点包中安装了 [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
|
| 21 |
+
- 可多选的风格化提示词选择器,默认是Fooocus的样式json,可自定义json放在styles底下,samples文件夹里可放预览图(名称和name一致,图片文件名如有空格需转为下划线'_')
|
| 22 |
+
- 加载器可开启A1111提示词风格模式,可重现与webui生成近乎相同的图像,需先安装 [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes)
|
| 23 |
+
- 可使用`easy latentNoisy`或`easy preSamplingNoiseIn`节点实现对潜空间的噪声注入
|
| 24 |
+
- 简化 SD1.x、SD2.x、SDXL、SVD、Zero123等流程
|
| 25 |
+
- 简化 Stable Cascade [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#1-13-stable-cascade)
|
| 26 |
+
- 简化 Layer Diffuse [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-3-layerdiffusion)
|
| 27 |
+
- 简化 InstantID [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid), 需先保证自定义节点包中安装了 [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
|
| 28 |
+
- 简化 IPAdapter, 需先保证自定义节点包中安装最新版v2的 [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)
|
| 29 |
+
- 扩展 XYplot 的可用性
|
| 30 |
+
- 整合了Fooocus Inpaint功能
|
| 31 |
+
- 整合了常用的逻辑计算、转换类型、展示所有类型等
|
| 32 |
+
- 支持节点上checkpoint、lora模型子目录分类及预览图 (请在设置中开启上下文菜单嵌套子目录)
|
| 33 |
+
- 支持BriaAI的RMBG-1.4模型的背景去除节点,[技术参考](https://huggingface.co/briaai/RMBG-1.4)
|
| 34 |
+
- 支持 强制清理comfyUI模型显存占用
|
| 35 |
+
- 支持Stable Diffusion 3 多账号API节点
|
| 36 |
+
- 支持IC-Light的应用 [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-5-ic-light) | [代码整合来源](https://github.com/huchenlei/ComfyUI-IC-Light) | [技术参考](https://github.com/lllyasviel/IC-Light)
|
| 37 |
+
- 中文提示词自动识别,使用[opus-mt-zh-en模型](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)
|
| 38 |
+
- 支持 sd3 模型
|
| 39 |
+
- 支持 kolors 模型
|
| 40 |
+
- 支持 flux 模型
|
| 41 |
+
|
| 42 |
+
## 👨🏻🔧 安装
|
| 43 |
+
|
| 44 |
+
1. 将存储库克隆到 **custom_nodes** 目录并安装依赖
|
| 45 |
+
```shell
|
| 46 |
+
#1. git下载
|
| 47 |
+
git clone https://github.com/yolain/ComfyUI-Easy-Use
|
| 48 |
+
#2. 安装依赖
|
| 49 |
+
双击install.bat安装依赖
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
## 👨🏻🚀 计划
|
| 53 |
+
|
| 54 |
+
- [x] 更新便于维护的新前端代码
|
| 55 |
+
- [x] 使用sass维护css样式
|
| 56 |
+
- [x] 对原有扩展进行优化
|
| 57 |
+
- [x] 增加新的组件(如节点时间统计等)
|
| 58 |
+
- [ ] 在[ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows)中上传更多的工作流(如kolors,sd3等),并更新english版本的readme
|
| 59 |
+
- [ ] 更详细功能介绍的 gitbook
|
| 60 |
+
|
| 61 |
+
## 📜 更新日志
|
| 62 |
+
|
| 63 |
+
**v1.2.2**
|
| 64 |
+
|
| 65 |
+
- 增加 v2 版本新前端代码
|
| 66 |
+
- 增加 `easy fluxLoader`
|
| 67 |
+
- 增加 `controlnetApply` 相关节点对sd3和hunyuanDiT的支持
|
| 68 |
+
|
| 69 |
+
**v1.2.1**
|
| 70 |
+
|
| 71 |
+
- 增加 `easy ipadapterApplyFaceIDKolors`
|
| 72 |
+
- `easy ipadapterApply` 和 `easy ipadapterApplyADV` 增加 **PLUS (kolors genernal)** 和 **FACEID PLUS KOLORS** 预置项
|
| 73 |
+
- `easy imageRemBg` 增加 **inspyrenet** 选项
|
| 74 |
+
- 增加 `easy controlnetLoader++`
|
| 75 |
+
- 去除 `easy positive` `easy negative` 等prompt节点的自动将中文翻译功能,自动翻译仅在 `easy a1111Loader` 等不支持中文TE的加载器中生效
|
| 76 |
+
- 增加 `easy kolorsLoader` - 可灵加载器,参考了 [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) 和 [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper) 的代码。
|
| 77 |
+
|
| 78 |
+
**v1.2.0**
|
| 79 |
+
|
| 80 |
+
- 增加 `easy pulIDApply` 和 `easy pulIDApplyADV`
|
| 81 |
+
- 增加 `easy hunyuanDiTLoader` 和 `easy pixArtLoader`
|
| 82 |
+
- 当新菜单的位置在上或者下时增加上 crystools 的显示,推荐开两个就好(如果后续crystools有更新UI适配我可能会删除掉)
|
| 83 |
+
- 增加 **easy sliderControl** - 滑块控制节点,当前可用于控制ipadapterMS的参数 (双击滑块可重置为默认值)
|
| 84 |
+
- 增加 **layer_weights** 属性在 `easy ipadapterApplyADV` 节点
|
| 85 |
+
|
| 86 |
+
**v1.1.9**
|
| 87 |
+
|
| 88 |
+
- 增加 新的调度器 **gitsScheduler**
|
| 89 |
+
- 增加 `easy imageBatchToImageList` 和 `easy imageListToImageBatch` (修复Impact版的一点小问题)
|
| 90 |
+
- 递归模型子目录嵌套
|
| 91 |
+
- 支持 sd3 模型
|
| 92 |
+
- 增加 `easy applyInpaint` - 局部重绘全模式节点 (相比与之前的kSamplerInpating节点逻辑会更合理些)
|
| 93 |
+
|
| 94 |
+
**v1.1.8**
|
| 95 |
+
|
| 96 |
+
- 增加中文提示词自动翻译,使用[opus-mt-zh-en模型](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en), 默认已对wildcard、lora正则处理, 其他需要保留的中文,可使用`@你的提示词@`包裹 (若依赖安装完成后报错, 请重启),测算大约会占0.3GB显存
|
| 97 |
+
- 增加 `easy controlnetStack` - controlnet堆
|
| 98 |
+
- 增加 `easy applyBrushNet` - [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
|
| 99 |
+
- 增加 `easy applyPowerPaint` - [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
|
| 100 |
+
|
| 101 |
+
**v1.1.7**
|
| 102 |
+
|
| 103 |
+
- 修复 一些模型(如controlnet模型等)未成功写入缓存,导致修改前置节点束参数(如提示词)需要二次载入模型的问题
|
| 104 |
+
- 增加 `easy prompt` - 主体和光影预置项,后期可能会调整
|
| 105 |
+
- 增加 `easy icLightApply` - 重绘光影, 从[ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)优化
|
| 106 |
+
- 增加 `easy imageSplitGrid` - 图像网格拆分
|
| 107 |
+
- `easy kSamplerInpainting` 的 **additional** 属性增加差异扩散和brushnet等相关选项
|
| 108 |
+
- 增加 brushnet模型加载的支持 - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
|
| 109 |
+
- 增加 `easy applyFooocusInpaint` - Fooocus内补节点 替代原有的 FooocusInpaintLoader
|
| 110 |
+
- 移除 `easy fooocusInpaintLoader` - 容易bug,不再使用
|
| 111 |
+
- 修改 easy kSampler等采样器中并联的model 不再替换输出中pipe里的model
|
| 112 |
+
|
| 113 |
+
**v1.1.6**
|
| 114 |
+
|
| 115 |
+
- 增加步调齐整适配 - 在所有的预采样和全采样器节点中的 调度器(schedulder) 增加了 **alignYourSteps** 选项
|
| 116 |
+
- `easy kSampler` 和 `easy fullkSampler` 的 **image_output** 增加 **Preview&Choose**选项
|
| 117 |
+
- 增加 `easy styleAlignedBatchAlign` - 风格对齐 [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
|
| 118 |
+
- 增加 `easy ckptNames`
|
| 119 |
+
- 增加 `easy controlnetNames`
|
| 120 |
+
- 增加 `easy imagesSplitimage` - 批次图像拆分单张
|
| 121 |
+
- 增加 `easy imageCount` - 图像数量
|
| 122 |
+
- 增加 `easy textSwitch` - 文字切换
|
| 123 |
+
|
| 124 |
+
**v1.1.5**
|
| 125 |
+
|
| 126 |
+
- 重写 `easy cleanGPUUsed` - 可强制清理comfyUI的模型显存占用
|
| 127 |
+
- 增加 `easy humanSegmentation` - 多类分割、人像分割
|
| 128 |
+
- 增加 `easy imageColorMatch`
|
| 129 |
+
- 增加 `easy ipadapterApplyRegional`
|
| 130 |
+
- 增加 `easy ipadapterApplyFromParams`
|
| 131 |
+
- 增加 `easy imageInterrogator` - 图像反推
|
| 132 |
+
- 增加 `easy stableDiffusion3API` - 简易的Stable Diffusion 3 多账号API节点
|
| 133 |
+
|
| 134 |
+
**v1.1.4**
|
| 135 |
+
|
| 136 |
+
- 增加 `easy imageChooser` - 从[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker)简化的图片选择器
|
| 137 |
+
- 增加 `easy preSamplingCustom` - 自定义预采样,可支持cosXL-edit
|
| 138 |
+
- 增加 `easy ipadapterStyleComposition`
|
| 139 |
+
- 增加 在Loaders上右键菜单可查看 checkpoints、lora 信息
|
| 140 |
+
- 修复 `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` 以兼容ComfyUI Revision>=2098 [0542088e] 以上版本
|
| 141 |
+
- 修复 FooocusInpaint修改ModelPatcher计算权重引发的问题,理应在生成model后重置ModelPatcher为默认值
|
| 142 |
+
|
| 143 |
+
**v1.1.3**
|
| 144 |
+
|
| 145 |
+
- `easy ipadapterApply` 增加 **COMPOSITION** 预置项
|
| 146 |
+
- 增加 对[ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) lora模型 的加载支持
|
| 147 |
+
- 增加 `easy promptLine`
|
| 148 |
+
- 增加 `easy promptReplace`
|
| 149 |
+
- 增加 `easy promptConcat`
|
| 150 |
+
- `easy wildcards` 增加 **multiline_mode**属性
|
| 151 |
+
- 增加 当节点需要下载模型时,若huggingface连接超时,会切换至镜像地址下载模型
|
| 152 |
+
|
| 153 |
+
<details>
|
| 154 |
+
<summary><b>v1.1.2</b></summary>
|
| 155 |
+
|
| 156 |
+
- 改写 EasyUse 相关节点的部分插槽推荐节点
|
| 157 |
+
- 增加 **启用上下文菜单自动嵌套子目录** 设置项,默认为启用状态,可分类子目录及checkpoints、loras预览图
|
| 158 |
+
- 增加 `easy sv3dLoader`
|
| 159 |
+
- 增加 `easy dynamiCrafterLoader`
|
| 160 |
+
- 增加 `easy ipadapterApply`
|
| 161 |
+
- 增加 `easy ipadapterApplyADV`
|
| 162 |
+
- 增加 `easy ipadapterApplyEncoder`
|
| 163 |
+
- 增加 `easy ipadapterApplyEmbeds`
|
| 164 |
+
- 增加 `easy preMaskDetailerFix`
|
| 165 |
+
- `easy kSamplerInpainting` 增加 **additional** 属性,可设置成 Differential Diffusion 或 Only InpaintModelConditioning
|
| 166 |
+
- 修复 `easy stylesSelector` 当未选择样式时,原���提示词发生了变化
|
| 167 |
+
- 修复 `easy pipeEdit` 提示词输入lora时报错
|
| 168 |
+
- 修复 layerDiffuse xyplot相关bug
|
| 169 |
+
</details>
|
| 170 |
+
|
| 171 |
+
<details>
|
| 172 |
+
<summary><b>v1.1.1/b></summary>
|
| 173 |
+
|
| 174 |
+
- 修复首次添加含seed的节点且当前模式为control_before_generate时,seed为0的问题
|
| 175 |
+
- `easy preSamplingAdvanced` 增加 **return_with_leftover_noise**
|
| 176 |
+
- 修复 `easy stylesSelector` 当选择自定义样式文件时运行队列报错
|
| 177 |
+
- `easy preSamplingLayerDiffusion` 增加 mask 可选传入参数
|
| 178 |
+
- 将所有 **seed_num** 调整回 **seed**
|
| 179 |
+
- 修补官方BUG: 当control_mode为before 在首次加载页面时未修改节点中widget名称为 control_before_generate
|
| 180 |
+
- 去除强制**control_before_generate**设定
|
| 181 |
+
- 增加 `easy imageRemBg` - 默认为BriaAI的RMBG-1.4模型, 移除背景效果更加,速度更快
|
| 182 |
+
</details>
|
| 183 |
+
|
| 184 |
+
<details>
|
| 185 |
+
<summary><b>v1.1.0</b></summary>
|
| 186 |
+
|
| 187 |
+
- 增加 `easy imageSplitList` - 拆分每 N 张图像
|
| 188 |
+
- 增加 `easy preSamplingDiffusionADDTL` - 可配置前景、背景、blended的additional_prompt等
|
| 189 |
+
- 增加 `easy preSamplingNoiseIn` 可替代需要前置的`easy latentNoisy`节点 实现效果更好的噪声注入
|
| 190 |
+
- `easy pipeEdit` 增加 条件拼接模式选择,可选择替换、合并、联结、平均、设置条件时间
|
| 191 |
+
- 增加 `easy pipeEdit` - 可编辑Pipe的节点(包含可重新输入提示词)
|
| 192 |
+
- 增加 `easy preSamplingLayerDiffusion` 与 `easy kSamplerLayerDiffusion` (连接 `easy kSampler` 也能通)
|
| 193 |
+
- 增加 在 加载器、预采样、采样器、Controlnet等节点上右键可快速替换同类型节点的便捷菜单
|
| 194 |
+
- 增加 `easy instantIDApplyADV` 可连入 positive 与 negative
|
| 195 |
+
- 修复 `easy wildcards` 读取lora未填写完整路径时未自动检索导致加载lora失败的问题
|
| 196 |
+
- 修复 `easy instantIDApply` mask 未传入正确值
|
| 197 |
+
- 修复 在 非a1111提示词风格下 BREAK 不生效的问题
|
| 198 |
+
</details>
|
| 199 |
+
|
| 200 |
+
<details>
|
| 201 |
+
<summary><b>v1.0.9</b></summary>
|
| 202 |
+
|
| 203 |
+
- 修复未安装 ComfyUI-Impack-Pack 和 ComfyUI_InstantID 时报错
|
| 204 |
+
- 修复 `easy pipeIn` - pipe设为可不必选
|
| 205 |
+
- 增加 `easy instantIDApply` - 需要先安装 [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID), 工作流参考[示例](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid)
|
| 206 |
+
- 修复 `easy detailerFix` 未添加到保存图片格式化扩展名可用节点列表
|
| 207 |
+
- 修复 `easy XYInputs: PromptSR` 在替换负面提示词时报错
|
| 208 |
+
</details>
|
| 209 |
+
|
| 210 |
+
<details>
|
| 211 |
+
<summary><b>v1.0.8</b></summary>
|
| 212 |
+
|
| 213 |
+
- `easy cascadeLoader` stage_c 与 stage_b 支持checkpoint模型 (需要下载[checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints))
|
| 214 |
+
- `easy styleSelector` 搜索框修改为不区分大小写匹配
|
| 215 |
+
- `easy fullLoader` 增加 **positive**、**negative**、**latent** 输出项
|
| 216 |
+
- 修复 SDXLClipModel 在 ComfyUI 修订版本号 2016[c2cb8e88] 及以上的报错(判断了版本号可兼容老版本)
|
| 217 |
+
- 修复 `easy detailerFix` 批次大小大于1时生成出错
|
| 218 |
+
- 修复`easy preSampling`等 latent传入后无法根据批次索引生成的问题
|
| 219 |
+
- 修复 `easy svdLoader` 报错
|
| 220 |
+
- 优化代码,减少了诸多冗余,提升运行速度
|
| 221 |
+
- 去除中文翻译对照文本
|
| 222 |
+
|
| 223 |
+
(翻译对照已由 [AIGODLIKE-COMFYUI-TRANSLATION](https://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translation) 统一维护啦!
|
| 224 |
+
首次下载或者版本较早的朋友请更新 AIGODLIKE-COMFYUI-TRANSLATION 和本节点包至最新版本。)
|
| 225 |
+
</details>
|
| 226 |
+
|
| 227 |
+
<details>
|
| 228 |
+
<summary><b>v1.0.7</b></summary>
|
| 229 |
+
|
| 230 |
+
- 增加 `easy cascadeLoader` - stable cascade 加载器
|
| 231 |
+
- 增加 `easy preSamplingCascade` - stabled cascade stage_c 预采样参数
|
| 232 |
+
- 增加 `easy fullCascadeKSampler` - stable cascade stage_c 完整版采样器
|
| 233 |
+
- 增加 `easy cascadeKSampler` - stable cascade stage-c ksampler simple
|
| 234 |
+
</details>
|
| 235 |
+
|
| 236 |
+
<details>
|
| 237 |
+
<summary><b>v1.0.6</b></summary>
|
| 238 |
+
|
| 239 |
+
- 增加 `easy XYInputs: Checkpoint`
|
| 240 |
+
- 增加 `easy XYInputs: Lora`
|
| 241 |
+
- `easy seed` 增加固定种子值时可手动切换随机种
|
| 242 |
+
- 修复 `easy fullLoader`等加载器切换lora时自动调整节点大小的问题
|
| 243 |
+
- 去除原有ttn的图片保存逻辑并适配ComfyUI默认的图片保存格式化扩展
|
| 244 |
+
</details>
|
| 245 |
+
|
| 246 |
+
<details>
|
| 247 |
+
<summary><b>v1.0.5</b></summary>
|
| 248 |
+
|
| 249 |
+
- 增加 `easy isSDXL`
|
| 250 |
+
- `easy svdLoader` 增加提示词控制, 可配合open_clip模型进行使用
|
| 251 |
+
- `easy wildcards` 增加 **populated_text** 可输出通配填充后文本
|
| 252 |
+
</details>
|
| 253 |
+
|
| 254 |
+
<details>
|
| 255 |
+
<summary><b>v1.0.4</b></summary>
|
| 256 |
+
|
| 257 |
+
- 增加 `easy showLoaderSettingsNames` 可显示与输出加载器部件中的 模型与VAE名称
|
| 258 |
+
- 增加 `easy promptList` - 提示词列表
|
| 259 |
+
- 增加 `easy fooocusInpaintLoader` - Fooocus内补节点(仅支持XL模型的流程)
|
| 260 |
+
- 增加 **Logic** 逻辑类节点 - 包含类型、计算、判断和转换类型等
|
| 261 |
+
- 增加 `easy imageSave` - 带日期转换和宽高格式化的图像保存节点
|
| 262 |
+
- 增加 `easy joinImageBatch` - 合并图像批次
|
| 263 |
+
- `easy showAnything` 增加支持转换其他类型(如:tensor类型的条件、图像等)
|
| 264 |
+
- `easy kSamplerInpainting` 增加 **patch** 传入值,配合Fooocus内补节点使用
|
| 265 |
+
- `easy imageSave` 增加 **only_preivew**
|
| 266 |
+
|
| 267 |
+
- 修复 xyplot在pillow>9.5中报错
|
| 268 |
+
- 修复 `easy wildcards` 在使用PS扩展插件运行时报错
|
| 269 |
+
- 修复 `easy latentCompositeMaskedWithCond`
|
| 270 |
+
- 修复 `easy XYInputs: ControlNet` 报错
|
| 271 |
+
- 修复 `easy loraStack` **toggle** 为 disabled 时报错
|
| 272 |
+
|
| 273 |
+
- 修改首次安装节点包不再自动替换主题,需手动调整并刷新页面
|
| 274 |
+
</details>
|
| 275 |
+
|
| 276 |
+
<details>
|
| 277 |
+
<summary><b>v1.0.3</b></summary>
|
| 278 |
+
|
| 279 |
+
- 增加 `easy stylesSelector` 风格化提示词选择器
|
| 280 |
+
- 增加队列进度条设置项,默认为未启用状态
|
| 281 |
+
- `easy controlnetLoader` 和 `easy controlnetLoaderADV` 增加参数 **scale_soft_weights**
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
- 修复 `easy XYInputs: Sampler/Scheduler` 报错
|
| 285 |
+
- 修复 右侧菜单 点击按钮时老是跑位的问题
|
| 286 |
+
- 修复 styles 路径在其他环境报错
|
| 287 |
+
- 修复 `easy comfyLoader` 读取错误
|
| 288 |
+
- 修复 xyPlot 在连接 zero123 时报错
|
| 289 |
+
- 修复加载器中提示词为组件时报错
|
| 290 |
+
- 修复 `easy getNode` 和 `easy setNode` 加载时标题未更改
|
| 291 |
+
- 修复所有采样器中存储图片使用子目录前缀不生效的问题
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
- 调整UI主题
|
| 295 |
+
</details>
|
| 296 |
+
|
| 297 |
+
<details>
|
| 298 |
+
<summary><b>v1.0.2</b></summary>
|
| 299 |
+
|
| 300 |
+
- 增加 **autocomplete** 文件夹,如果您安装了 [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts), 将在启动时合并该文件夹下的所有txt文件并覆盖到pyssss包里的autocomplete.txt文件。
|
| 301 |
+
- 增加 `easy XYPlotAdvanced` 和 `easy XYInputs` 等相关节点
|
| 302 |
+
- 增加 **Alt+1到9** 快捷键,可快速粘贴 Node templates 的节点预设 (对应 1到9 顺序)
|
| 303 |
+
|
| 304 |
+
- 修复 `easy imageInsetCrop` 测量值为百分比时步进为1
|
| 305 |
+
- 修复 开启 `a1111_prompt_style` 时XY图表无法使用的问题
|
| 306 |
+
- 右键菜单中增加了一个 `📜Groups Map(EasyUse)`
|
| 307 |
+
|
| 308 |
+
- 修复在Comfy新版本中UI加载失败
|
| 309 |
+
- 修复 `easy pipeToBasicPipe` 报错
|
| 310 |
+
- 修改 `easy fullLoader` 和 `easy a1111Loader` 中的 **a1111_prompt_style** 默认值为 False
|
| 311 |
+
- `easy XYInputs ModelMergeBlocks` 支持csv文件导入数值
|
| 312 |
+
|
| 313 |
+
- 替换了XY图生成时的字体文件
|
| 314 |
+
|
| 315 |
+
- 移除 `easy imageRemBg`
|
| 316 |
+
- 移除包中的介绍图和工作流文件,减少包体积
|
| 317 |
+
|
| 318 |
+
</details>
|
| 319 |
+
|
| 320 |
+
<details>
|
| 321 |
+
<summary><b>v1.0.1</b></summary>
|
| 322 |
+
|
| 323 |
+
- 新增 `easy seed` - 简易随机种
|
| 324 |
+
- `easy preDetailerFix` 新增了 `optional_image` 传入图像可选,如未传默认取值为pipe里的图像
|
| 325 |
+
- 新增 `easy kSamplerInpainting` 用于内补潜空间的采样器
|
| 326 |
+
- 新增 `easy pipeToBasicPipe` 用于转换到Impact的某些节点上
|
| 327 |
+
|
| 328 |
+
- 修复 `easy comfyLoader` 报错
|
| 329 |
+
- 修复所有包含输出图片尺寸的节点取值方式无法批处理的问题
|
| 330 |
+
- 修复 `width` 和 `height` 无法在 `easy svdLoader` 自定义的报错问题
|
| 331 |
+
- 修复所有采样器预览图片的地址链接 (解决在 MACOS 系统中图片无法在采样器中预览的问题)
|
| 332 |
+
- 修复 `vae_name` 在 `easy fullLoader` 和 `easy a1111Loader` 和 `easy comfyLoader` 中选择但未替换原始vae问题
|
| 333 |
+
- 修复 `easy fullkSampler` 除pipe外其他输出值的报错
|
| 334 |
+
- 修复 `easy hiresFix` 输入连接pipe和image、vae同时存在时报错
|
| 335 |
+
- 修复 `easy fullLoader` 中 `model_override` 连接后未执行
|
| 336 |
+
- 修复 因新增`easy seed` 导致action错误
|
| 337 |
+
- 修复 `easy xyplot` 的字体文件路径读取错误
|
| 338 |
+
- 修复 convert 到 `easy seed` 随机种无法固定的问题
|
| 339 |
+
- 修复 `easy pipeIn` 值传入的报错问题
|
| 340 |
+
- 修复 `easy zero123Loader` 和 `easy svdLoader` 读取模型时将模型加入到缓存中
|
| 341 |
+
- 修复 `easy kSampler` `easy kSamplerTiled` `easy detailerFix` 的 `image_output` 默认值为 Preview
|
| 342 |
+
- `easy fullLoader` 和 `easy a1111Loader` 新增了 `a1111_prompt_style` 参数可以重现和webui生成相同的图像,当前您需要安装 [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) 才能使用此功能
|
| 343 |
+
</details>
|
| 344 |
+
|
| 345 |
+
<details>
|
| 346 |
+
<summary><b>v1.0.0</b></summary>
|
| 347 |
+
|
| 348 |
+
- 新增`easy positive` - 简易正面提示词文本
|
| 349 |
+
- 新增`easy negative` - 简易负面提示词文本
|
| 350 |
+
- 新增`easy wildcards` - 支持通配符和Lora选择的提示词文本
|
| 351 |
+
- 新增`easy portraitMaster` - 肖像大师v2.2
|
| 352 |
+
- 新增`easy loraStack` - Lora堆
|
| 353 |
+
- 新增`easy fullLoader` - 完整版的加载器
|
| 354 |
+
- 新增`easy zero123Loader` - 简易zero123加载器
|
| 355 |
+
- 新增`easy svdLoader` - 简易svd加载器
|
| 356 |
+
- 新增`easy fullkSampler` - 完整版的采样器(无分离)
|
| 357 |
+
- 新增`easy hiresFix` - 支持Pipe的高清修复
|
| 358 |
+
- 新增`easy predetailerFix` `easy DetailerFix` - 支持Pipe的细节修复
|
| 359 |
+
- 新增`easy ultralyticsDetectorPipe` `easy samLoaderPipe` - 检测加载器(细节修复的输入项)
|
| 360 |
+
- 新增`easy pipein` `easy pipeout` - Pipe的输入与输出
|
| 361 |
+
- 新增`easy xyPlot` - 简易的xyplot (后续会更新更多可控参数)
|
| 362 |
+
- 新增`easy imageRemoveBG` - 图像去除背景
|
| 363 |
+
- 新增`easy imagePixelPerfect` - 图像完美像素
|
| 364 |
+
- 新增`easy poseEditor` - 姿势编辑器
|
| 365 |
+
- 新增UI主题(黑曜石)- 默认自动加载UI, 也可在设置中自行更替
|
| 366 |
+
|
| 367 |
+
- 修复 `easy globalSeed` 不生效问题
|
| 368 |
+
- 修复所有的`seed_num` 因 [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) 实时更新图表导致值错乱的问题
|
| 369 |
+
- 修复`easy imageSize` `easy imageSizeBySide` `easy imageSizeByLongerSide` 可作为终节点
|
| 370 |
+
- 修复 `seed_num` (随机种子值) 在历史记录中读取无法一致的Bug
|
| 371 |
+
</details>
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
<details>
|
| 375 |
+
<summary><b>v0.5</b></summary>
|
| 376 |
+
|
| 377 |
+
- 新增 `easy controlnetLoaderADV` 节点
|
| 378 |
+
- 新增 `easy imageSizeBySide` 节点,可选输出为长边或短边
|
| 379 |
+
- 新增 `easy LLLiteLoader` 节点,如果您预先安装过 kohya-ss/ControlNet-LLLite-ComfyUI 包,请将 models 里的模型文件移动至 ComfyUI\models\controlnet\ (即comfy默认的controlnet路径里,请勿修改模型的文件名,不然会读取不到)。
|
| 380 |
+
- 新增 `easy imageSize` 和 `easy imageSizeByLongerSize` 输出的尺寸显示。
|
| 381 |
+
- 新增 `easy showSpentTime` 节点用于展示图片推理花费时间与VAE解码花费时间。
|
| 382 |
+
- `easy controlnetLoaderADV` 和 `easy controlnetLoader` 新增 `control_net` 可选传入参数
|
| 383 |
+
- `easy preSampling` 和 `easy preSamplingAdvanced` 新增 `image_to_latent` 可选传入参数
|
| 384 |
+
- `easy a1111Loader` 和 `easy comfyLoader` 新增 `batch_size` 传入参数
|
| 385 |
+
|
| 386 |
+
- 修改 `easy controlnetLoader` 到 loader 分类底下。
|
| 387 |
+
</details>
|
| 388 |
+
|
| 389 |
+
## 整合参考到的相关节点包
|
| 390 |
+
|
| 391 |
+
声明: 非常尊重这些原作者们的付出,开源不易,我仅仅只是做了一些整合与优化。
|
| 392 |
+
|
| 393 |
+
| 节点名 (搜索名) | 相关的库 | 库相关的节点 |
|
| 394 |
+
|:-------------------------------|:----------------------------------------------------------------------------|:------------------------|
|
| 395 |
+
| easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
|
| 396 |
+
| easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
|
| 397 |
+
| easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
|
| 398 |
+
| easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
|
| 399 |
+
| easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
|
| 400 |
+
| easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
|
| 401 |
+
| easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 402 |
+
| dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 403 |
+
| easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
|
| 404 |
+
| easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
|
| 405 |
+
| easy if | [ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) | IfExecute |
|
| 406 |
+
| easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply等 |
|
| 407 |
+
| easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
|
| 408 |
+
| easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
|
| 409 |
+
| easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
|
| 410 |
+
| easy icLightApply | [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light) | ICLightApply等 |
|
| 411 |
+
| easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
|
| 412 |
+
|
| 413 |
+
## Credits
|
| 414 |
+
|
| 415 |
+
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) - 功能强大且模块化的Stable Diffusion GUI
|
| 416 |
+
|
| 417 |
+
[ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI管理器
|
| 418 |
+
|
| 419 |
+
[tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - 管道节点(节点束)让用户减少了不必要的连接
|
| 420 |
+
|
| 421 |
+
[ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - diffus3的获取与设置点让用户可以分离工作流构成
|
| 422 |
+
|
| 423 |
+
[ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - 常规整合包1
|
| 424 |
+
|
| 425 |
+
[ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - 常规整合包2
|
| 426 |
+
|
| 427 |
+
[ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) - ComfyUI逻辑运算
|
| 428 |
+
|
| 429 |
+
[ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - 让模型生成不受训练分辨率限制
|
| 430 |
+
|
| 431 |
+
[ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - 风格迁移
|
| 432 |
+
|
| 433 |
+
[ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - 人脸迁移
|
| 434 |
+
|
| 435 |
+
[ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - 人脸迁移
|
| 436 |
+
|
| 437 |
+
[ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss 小蛇🐍脚本
|
| 438 |
+
|
| 439 |
+
[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - 图片选择器
|
| 440 |
+
|
| 441 |
+
[ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet) - BrushNet 内补节点
|
| 442 |
+
|
| 443 |
+
[ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT架构相关节点(Pixart、混元DiT等)
|
| 444 |
+
|
| 445 |
+
## ☕️ Donation
|
| 446 |
+
|
| 447 |
+
**Comfyui-Easy-Use** 是一个 GPL 许可的开源项目。为了项目取得更好、可持续的发展,我希望能够获得更多的支持。 如果我的自定义节点为您的一天增添了价值,请考虑喝杯咖啡来进一步补充能量! 💖感谢您的支持,每一杯咖啡都是我创作的动力!
|
| 448 |
+
|
| 449 |
+
- [BiliBili充电](https://space.bilibili.com/1840885116)
|
| 450 |
+
- [爱发电](https://afdian.com/a/yolain)
|
| 451 |
+
- [Wechat/Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
|
| 452 |
+
|
| 453 |
+
感谢您的捐助,我将用这些费用来租用 GPU 或购买其他 GPT 服务,以便更好地调试和完善 ComfyUI-Easy-Use 功能
|
| 454 |
+
|
| 455 |
+
## 🌟Stargazers
|
| 456 |
+
|
| 457 |
+
My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
|
| 458 |
+
|
| 459 |
+
[](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
|
ComfyUI-Easy-Use/README.en.md
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="right">
|
| 2 |
+
<a href="./README.md">中文</a> | <strong>English</strong>
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
<div align="center">
|
| 6 |
+
|
| 7 |
+
# ComfyUI Easy Use
|
| 8 |
+
</div>
|
| 9 |
+
|
| 10 |
+
**ComfyUI-Easy-Use** is a simplified node integration package, which is extended on the basis of [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), and has been integrated and optimized for many mainstream node packages to achieve the purpose of faster and more convenient use of ComfyUI. While ensuring the degree of freedom, it restores the ultimate smooth image production experience that belongs to Stable Diffusion.
|
| 11 |
+
|
| 12 |
+
[](https://github.com/yolain/ComfyUI-Yolain-Workflows)
|
| 13 |
+
|
| 14 |
+
## 👨🏻🎨 Introduce
|
| 15 |
+
|
| 16 |
+
- Inspire by [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), which greatly reduces the time cost of tossing workflows。
|
| 17 |
+
- UI interface beautification, the first time you install the user, if you need to use the UI theme, please switch the theme in Settings -> Color Palette and refresh page.
|
| 18 |
+
- Added a node for pre-sampling parameter configuration, which can be separated from the sampling node for easier previewing
|
| 19 |
+
- Wildcards and lora's are supported, for Lora Block Weight usage, ensure that the custom node package has the [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
|
| 20 |
+
- Multi-selectable styled cue word selector, default is Fooocus style json, custom json can be placed under styles, samples folder can be placed in the preview image (name and name consistent, image file name such as spaces need to be converted to underscores '_')
|
| 21 |
+
- The loader enables the A1111 prompt mode, which reproduces nearly identical images to those generated by webui, and needs to be installed [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) first.
|
| 22 |
+
- Noise injection into the latent space can be achieved using the `easy latentNoisy` or `easy preSamplingNoiseIn` node
|
| 23 |
+
- Simplified processes for SD1.x, SD2.x, SDXL, SVD, Zero123, etc. [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableDiffusion)
|
| 24 |
+
- Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableCascade)
|
| 25 |
+
- Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#LayerDiffusion),The first time you use it you may need to run `pip install -r requirements.txt` to install the required dependencies.
|
| 26 |
+
- Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#InstantID), You need to make sure that the custom node package has the [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
|
| 27 |
+
- Extending the usability of XYplot
|
| 28 |
+
- Fooocus Inpaint integration
|
| 29 |
+
- Integration of common logical calculations, conversion of types, display of all types, etc.
|
| 30 |
+
- Background removal nodes for the RMBG-1.4 model supporting BriaAI, [BriaAI Guide](https://huggingface.co/briaai/RMBG-1.4)
|
| 31 |
+
- Forcibly cleared the memory usage of the comfy UI model are supported
|
| 32 |
+
- Stable Diffusion 3 multi-account API nodes are supported
|
| 33 |
+
- Support Stable Diffusion 3 model
|
| 34 |
+
- Support Kolors model
|
| 35 |
+
|
| 36 |
+
## 👨🏻🔧 Installation
|
| 37 |
+
Clone the repo into the **custom_nodes** directory and install the requirements:
|
| 38 |
+
```shell
|
| 39 |
+
#1. Clone the repo
|
| 40 |
+
git clone https://github.com/yolain/ComfyUI-Easy-Use
|
| 41 |
+
#2. Install the requirements
|
| 42 |
+
Double-click install.bat to install the required dependencies
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
## ☕️ Plan
|
| 46 |
+
|
| 47 |
+
- [ ] Updated new front-end code for easier maintenance
|
| 48 |
+
- [x] Maintain css styles using sass
|
| 49 |
+
- [ ] Optimize existing extensions
|
| 50 |
+
- [ ] Add new components
|
| 51 |
+
- [ ] Add light theme
|
| 52 |
+
- [ ] Upload new workflows to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and translate readme to english version.
|
| 53 |
+
- [ ] Write gitbook with more detailed function introdution
|
| 54 |
+
|
| 55 |
+
## 📜 Changelog
|
| 56 |
+
|
| 57 |
+
**v1.2.1**
|
| 58 |
+
|
| 59 |
+
- Added **inspyrenet** to `easy imageRemBg`
|
| 60 |
+
- Added `easy controlnetLoader++`
|
| 61 |
+
- Added **PLUS (kolors genernal)** preset to `easy ipadapterApply` and `easy ipadapterApplyADV` (Supported kolors ipadapter)
|
| 62 |
+
- Added `easy kolorsLoader` - Code based on [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ)'s and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)'s repo, thanks for their contribution.
|
| 63 |
+
|
| 64 |
+
**v1.2.0**
|
| 65 |
+
|
| 66 |
+
- Added `easy pulIDApply` and `easy pulIDApplyADV`
|
| 67 |
+
- Added `easy huanyuanDiTLoader` and `easy pixArtLoader`
|
| 68 |
+
- Added **easy sliderControl** - Slider control node, which can currently be used to control the parameters of ipadapterMS (double-click the slider to reset to default)
|
| 69 |
+
- Added **layer_weights** in `easy ipadapterApplyADV`
|
| 70 |
+
|
| 71 |
+
**v1.1.9**
|
| 72 |
+
|
| 73 |
+
- Added **gitsScheduler**
|
| 74 |
+
- Added `easy imageBatchToImageList` and `easy imageListToImageBatch`
|
| 75 |
+
- Recursive subcategories nested for models
|
| 76 |
+
- Support for Stable Diffusion 3 model
|
| 77 |
+
- Added `easy applyInpaint` - All inpainting mode in this node
|
| 78 |
+
|
| 79 |
+
**v1.1.8**
|
| 80 |
+
|
| 81 |
+
- Added `easy controlnetStack`
|
| 82 |
+
- Added `easy applyBrushNet` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
|
| 83 |
+
- Added `easy applyPowerPaint` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
|
| 84 |
+
|
| 85 |
+
**v1.1.7**
|
| 86 |
+
|
| 87 |
+
- Added `easy prompt` - Subject and light presets, maybe adjusted later
|
| 88 |
+
- Added `easy icLightApply` - Light and shadow migration, Code based on [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
|
| 89 |
+
- Added `easy imageSplitGrid`
|
| 90 |
+
- `easy kSamplerInpainting` added options such as different diffusion and brushnet in **additional** widget
|
| 91 |
+
- Support for brushnet model loading - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
|
| 92 |
+
- Added `easy applyFooocusInpaint` - Replace FooocusInpaintLoader
|
| 93 |
+
- Removed `easy fooocusInpaintLoader`
|
| 94 |
+
|
| 95 |
+
**v1.1.6**
|
| 96 |
+
|
| 97 |
+
- Added **alignYourSteps** to **schedulder** widget in all `easy preSampling` and `easy fullkSampler`
|
| 98 |
+
- Added **Preview&Choose** to **image_output** widget in `easy kSampler` & `easy fullkSampler`
|
| 99 |
+
- Added `easy styleAlignedBatchAlign` - Credit of [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
|
| 100 |
+
- Added `easy ckptNames`
|
| 101 |
+
- Added `easy controlnetNames`
|
| 102 |
+
- Added `easy imagesSplitimage` - Batch images split into single images
|
| 103 |
+
- Added `easy imageCount` - Get Image Count
|
| 104 |
+
- Added `easy textSwitch` - Text Switch
|
| 105 |
+
|
| 106 |
+
**v1.1.5**
|
| 107 |
+
|
| 108 |
+
- Rewrite `easy cleanGPUUsed` - the memory usage of the comfyUI can to be cleared
|
| 109 |
+
- Added `easy humanSegmentation` - Human Part Segmentation
|
| 110 |
+
- Added `easy imageColorMatch`
|
| 111 |
+
- Added `easy ipadapterApplyRegional`
|
| 112 |
+
- Added `easy ipadapterApplyFromParams`
|
| 113 |
+
- Added `easy imageInterrogator` - Image To Prompt
|
| 114 |
+
- Added `easy stableDiffusion3API` - Easy Stable Diffusion 3 Multiple accounts API Node
|
| 115 |
+
|
| 116 |
+
**v1.1.4**
|
| 117 |
+
|
| 118 |
+
- Added `easy preSamplingCustom` - Custom-PreSampling, can be supported cosXL-edit
|
| 119 |
+
- Added `easy ipadapterStyleComposition`
|
| 120 |
+
- Added the right-click menu to view checkpoints and lora information in all Loaders
|
| 121 |
+
- Fixed `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` compatible with ComfyUI Revision>=2098 [0542088e] or later
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
**v1.1.3**
|
| 125 |
+
|
| 126 |
+
- `easy ipadapterApply` Added **COMPOSITION** preset
|
| 127 |
+
- Supported [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) when load ResAdapter lora
|
| 128 |
+
- Added `easy promptLine`
|
| 129 |
+
- Added `easy promptReplace`
|
| 130 |
+
- Added `easy promptConcat`
|
| 131 |
+
- `easy wildcards` Added **multiline_mode**
|
| 132 |
+
|
| 133 |
+
**v1.1.2**
|
| 134 |
+
|
| 135 |
+
- Optimized some of the recommended nodes for slots related to EasyUse
|
| 136 |
+
- Added **Enable ContextMenu Auto Nest Subdirectories** The setting item is enabled by default, and it can be classified into subdirectories, checkpoints and loras previews
|
| 137 |
+
- Added `easy sv3dLoader`
|
| 138 |
+
- Added `easy dynamiCrafterLoader`
|
| 139 |
+
- Added `easy ipadapterApply`
|
| 140 |
+
- Added `easy ipadapterApplyADV`
|
| 141 |
+
- Added `easy ipadapterApplyEncoder`
|
| 142 |
+
- Added `easy ipadapterApplyEmbeds`
|
| 143 |
+
- Added `easy preMaskDetailerFix`
|
| 144 |
+
- Fixed `easy stylesSelector` is change the prompt when not select the style
|
| 145 |
+
- Fixed `easy pipeEdit` error when add lora to prompt
|
| 146 |
+
- Fixed layerDiffuse xyplot bug
|
| 147 |
+
- `easy kSamplerInpainting` add *additional* widget,you can choose 'Differential Diffusion' or 'Only InpaintModelConditioning'
|
| 148 |
+
|
| 149 |
+
**v1.1.1**
|
| 150 |
+
|
| 151 |
+
- The issue that the seed is 0 when a node with a seed control is added and **control before generate** is fixed for the first time run queue prompt.
|
| 152 |
+
- `easy preSamplingAdvanced` Added **return_with_leftover_noise**
|
| 153 |
+
- Fixed `easy stylesSelector` error when choose the custom file
|
| 154 |
+
- `easy preSamplingLayerDiffusion` Added optional input parameter for mask
|
| 155 |
+
- Renamed all nodes widget name named seed_num to seed
|
| 156 |
+
- Remove forced **control_before_generate** settings。 If you want to use control_before_generate, change widget_value_control_mode to before in system settings
|
| 157 |
+
- Added `easy imageRemBg` - The default is BriaAI's RMBG-1.4 model, which removes the background effect more and faster
|
| 158 |
+
|
| 159 |
+
<details>
|
| 160 |
+
<summary><b>v1.1.0</b></summary>
|
| 161 |
+
|
| 162 |
+
- Added `easy imageSplitList` - to split every N images
|
| 163 |
+
- Added `easy preSamplingDiffusionADDTL` - It can modify foreground、background or blended additional prompt
|
| 164 |
+
- Added `easy preSamplingNoiseIn` It can replace the `easy latentNoisy` node that needs to be fronted to achieve better noise injection
|
| 165 |
+
- `easy pipeEdit` Added conditioning splicing mode selection, you can choose to replace, concat, combine, average, and set timestep range
|
| 166 |
+
- Added `easy pipeEdit` - nodes that can edit pipes (including re-enterable prompts)
|
| 167 |
+
- Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion`
|
| 168 |
+
- Added a convenient menu to right-click on nodes such as Loader, Presampler, Sampler, Controlnet, etc. to quickly replace nodes of the same type
|
| 169 |
+
- Added `easy instantIDApplyADV` can link positive and negative
|
| 170 |
+
- Fixed layerDiffusion error when batch size greater than 1
|
| 171 |
+
- Fixed `easy wildcards` When LoRa is not filled in completely, LoRa is not automatically retrieved, resulting in failure to load LoRa
|
| 172 |
+
- Fixed the issue that 'BREAK' non-initiation when didn't use a1111 prompt style
|
| 173 |
+
- Fixed `easy instantIDApply` mask not input right
|
| 174 |
+
</details>
|
| 175 |
+
|
| 176 |
+
<details>
|
| 177 |
+
<summary><b>v1.0.9</b></summary>
|
| 178 |
+
|
| 179 |
+
- Fixed the error when ComfyUI-Impack-Pack and ComfyUI_InstantID were not installed
|
| 180 |
+
- Fixed `easy pipeIn`
|
| 181 |
+
- Added `easy instantIDApply` - you need installed [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) fisrt, Workflow[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#InstantID)
|
| 182 |
+
- Fixed `easy detailerFix` not added to the list of nodes available for saving images formatting extensions
|
| 183 |
+
- Fixed `easy XYInputs: PromptSR` errors are reported when replacing negative prompts
|
| 184 |
+
</details>
|
| 185 |
+
|
| 186 |
+
<details>
|
| 187 |
+
<summary><b>v1.0.8</b></summary>
|
| 188 |
+
|
| 189 |
+
- `easy cascadeLoader` stage_c and stage_b support the checkpoint model (Download [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints) models)
|
| 190 |
+
- `easy styleSelector` The search box is modified to be case-insensitive
|
| 191 |
+
- `easy fullLoader` **positive**、**negative**、**latent** added to the output items
|
| 192 |
+
- Fixed the issue that 'easy preSampling' and other similar node, latent could not be generated based on the batch index after passing in
|
| 193 |
+
- Fixed `easy svdLoader` error when the positive or negative is empty
|
| 194 |
+
- Fixed the error of SDXLClipModel in ComfyUI revision 2016[c2cb8e88] and above (the revision number was judged to be compatible with the old revision)
|
| 195 |
+
- Fixed `easy detailerFix` generation error when batch size is greater than 1
|
| 196 |
+
- Optimize the code, reduce a lot of redundant code and improve the running speed
|
| 197 |
+
</details>
|
| 198 |
+
|
| 199 |
+
<details>
|
| 200 |
+
<summary><b>v1.0.7</b></summary>
|
| 201 |
+
|
| 202 |
+
- Added `easy cascadeLoader` - stable cascade Loader
|
| 203 |
+
- Added `easy preSamplingCascade` - stable cascade preSampling Settings
|
| 204 |
+
- Added `easy fullCascadeKSampler` - stable cascade stage-c ksampler full
|
| 205 |
+
- Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
|
| 206 |
+
-
|
| 207 |
+
- Optimize the image to image[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#image-to-image)
|
| 208 |
+
</details>
|
| 209 |
+
|
| 210 |
+
<details>
|
| 211 |
+
<summary><b>v1.0.6</b></summary>
|
| 212 |
+
|
| 213 |
+
- Added `easy XYInputs: Checkpoint`
|
| 214 |
+
- Added `easy XYInputs: Lora`
|
| 215 |
+
- `easy seed` can manually switch the random seed when increasing the fixed seed value
|
| 216 |
+
- Fixed `easy fullLoader` and all loaders to automatically adjust the node size when switching LoRa
|
| 217 |
+
- Removed the original ttn image saving logic and adapted to the default image saving format extension of ComfyUI
|
| 218 |
+
</details>
|
| 219 |
+
|
| 220 |
+
<details>
|
| 221 |
+
<summary><b>v1.0.5</b></summary>
|
| 222 |
+
|
| 223 |
+
- Added `easy isSDXL`
|
| 224 |
+
- Added prompt word control on `easy svdLoader`, which can be used with open_clip model
|
| 225 |
+
- Added **populated_text** on `easy wildcards`, wildcard populated text can be output
|
| 226 |
+
</details>
|
| 227 |
+
|
| 228 |
+
<details>
|
| 229 |
+
<summary><b>v1.0.4</b></summary>
|
| 230 |
+
|
| 231 |
+
- `easy showAnything` added support for converting other types (e.g., tensor conditions, images, etc.)
|
| 232 |
+
- Added `easy showLoaderSettingsNames` can display the model and VAE name in the output loader assembly
|
| 233 |
+
- Added `easy promptList`
|
| 234 |
+
- Added `easy fooocusInpaintLoader` (only the process of SDXLModel is supported)
|
| 235 |
+
- Added **Logic** nodes
|
| 236 |
+
- Added `easy imageSave` - Image saving node with date conversion and aspect and height formatting
|
| 237 |
+
- Added `easy joinImageBatch`
|
| 238 |
+
- `easy kSamplerInpainting` Added the **patch** input value to be used with the FooocusInpaintLoader node
|
| 239 |
+
|
| 240 |
+
- Fixed xyplot error when with Pillow>9.5
|
| 241 |
+
- Fixed `easy wildcards` An error is reported when running with the PS extension
|
| 242 |
+
- Fixed `easy XYInputs: ControlNet` Error
|
| 243 |
+
- Fixed `easy loraStack` error when **toggle** is disabled
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
- Changing the first-time install node package no longer automatically replaces the theme, you need to manually adjust and refresh the page
|
| 247 |
+
- `easy imageSave` added **only_preivew**
|
| 248 |
+
- Adjust the `easy latentCompositeMaskedWithCond` node
|
| 249 |
+
</details>
|
| 250 |
+
|
| 251 |
+
<details>
|
| 252 |
+
<summary><b>v1.0.3</b></summary>
|
| 253 |
+
|
| 254 |
+
- Added `easy stylesSelector`
|
| 255 |
+
- Added **scale_soft_weights** in `easy controlnetLoader` and `easy controlnetLoaderADV`
|
| 256 |
+
- Added the queue progress bar setting item, which is not enabled by default
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
- Fixed `easy XYInputs: Sampler/Scheduler` Error
|
| 260 |
+
- Fixed the right menu has a problem when clicking the button
|
| 261 |
+
- Fixed `easy comfyLoader` error
|
| 262 |
+
- Fixed xyPlot error when connecting to zero123
|
| 263 |
+
- Fixed the error message in the loader when the prompt word was component
|
| 264 |
+
- Fixed `easy getNode` and `easy setNode` the title does not change when loading
|
| 265 |
+
- Fixed all samplers using subdirectories to store images
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
- Adjust the UI theme, divided into two sets of styles: the official default background and the dark black background, which can be switched in the color palette in the settings
|
| 269 |
+
- Modify the styles path to be compatible with other environments
|
| 270 |
+
</details>
|
| 271 |
+
|
| 272 |
+
<details>
|
| 273 |
+
<summary><b>v1.0.2</b></summary>
|
| 274 |
+
|
| 275 |
+
- Added `easy XYPlotAdvanced` and some nodes about `easy XYInputs`
|
| 276 |
+
- Added **Alt+1-Alt+9** Shortcut keys to quickly paste node presets for Node templates (corresponding to 1~9 sequences)
|
| 277 |
+
- Added a `📜Groups Map(EasyUse)` to the context menu.
|
| 278 |
+
- An `autocomplete` folder has been added, If you have [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) installed, the txt files in that folder will be merged and overwritten to the autocomplete .txt file of the pyssss package at startup.
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
- Fixed XYPlot is not working when `a1111_prompt_style` is True
|
| 282 |
+
- Fixed UI loading failure in the new version of ComfyUI
|
| 283 |
+
- `easy XYInputs ModelMergeBlocks` Values can be imported from CSV files
|
| 284 |
+
- Fixed `easy pipeToBasicPipe` Bug
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
- Removed `easy imageRemBg`
|
| 288 |
+
- Remove the introductory diagram and workflow files from the package to reduce the package size
|
| 289 |
+
- Replaced the font file used in the generation of XY diagrams
|
| 290 |
+
</details>
|
| 291 |
+
|
| 292 |
+
<details>
|
| 293 |
+
<summary><b>v1.0.1</b></summary>
|
| 294 |
+
|
| 295 |
+
- Fixed `easy comfyLoader` error
|
| 296 |
+
- Fixed All nodes that contain the value of the image size
|
| 297 |
+
- Added `easy kSamplerInpainting`
|
| 298 |
+
- Added `easy pipeToBasicPipe`
|
| 299 |
+
- Fixed `width` and `height` can not customize in `easy svdLoader`
|
| 300 |
+
- Fixed all preview image path (Previously, it was not possible to preview the image on the Mac system)
|
| 301 |
+
- Fixed `vae_name` is not working in `easy fullLoader` and `easy a1111Loader` and `easy comfyLoader`
|
| 302 |
+
- Fixed `easy fullkSampler` outputs error
|
| 303 |
+
- Fixed `model_override` is not working in `easy fullLoader`
|
| 304 |
+
- Fixed `easy hiresFix` error
|
| 305 |
+
- Fixed `easy xyplot` font file path error
|
| 306 |
+
- Fixed seed that cannot be fixed when you convert `seed_num` to `easy seed`
|
| 307 |
+
- Fixed `easy pipeIn` inputs bug
|
| 308 |
+
- `easy preDetailerFix` have added a new parameter `optional_image`
|
| 309 |
+
- Fixed `easy zero123Loader` and `easy svdLoader` model into cache.
|
| 310 |
+
- Added `easy seed`
|
| 311 |
+
- Fixed `image_output` default value is "Preview"
|
| 312 |
+
- `easy fullLoader` and `easy a1111Loader` have added a new parameter `a1111_prompt_style`,that can reproduce the same image generated from stable-diffusion-webui on comfyui, but you need to install [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) to use this feature in the current version
|
| 313 |
+
</details>
|
| 314 |
+
|
| 315 |
+
<details>
|
| 316 |
+
<summary><b>v1.0.0</b></summary>
|
| 317 |
+
|
| 318 |
+
- Added `easy positive` - simple positive prompt text
|
| 319 |
+
- Added `easy negative` - simple negative prompt text
|
| 320 |
+
- Added `easy wildcards` - support for wildcards and hint text selected by Lora
|
| 321 |
+
- Added `easy portraitMaster` - PortraitMaster v2.2
|
| 322 |
+
- Added `easy loraStack` - Lora stack
|
| 323 |
+
- Added `easy fullLoader` - full version of the loader
|
| 324 |
+
- Added `easy zero123Loader` - simple zero123 loader
|
| 325 |
+
- Added `easy svdLoader` - easy svd loader
|
| 326 |
+
- Added `easy fullkSampler` - full version of the sampler (no separation)
|
| 327 |
+
- Added `easy hiresFix` - support for HD repair of Pipe
|
| 328 |
+
- Added `easy predetailerFix` and `easy DetailerFix` - support for Pipe detail fixing
|
| 329 |
+
- Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - Detect loader (detail fixed input)
|
| 330 |
+
- Added `easy pipein` `easy pipeout` - Pipe input and output
|
| 331 |
+
- Added `easy xyPlot` - simple xyplot (more controllable parameters will be updated in the future)
|
| 332 |
+
- Added `easy imageRemoveBG` - image to remove background
|
| 333 |
+
- Added `easy imagePixelPerfect` - image pixel perfect
|
| 334 |
+
- Added `easy poseEditor` - Pose editor
|
| 335 |
+
- New UI Theme (Obsidian) - Auto-load UI by default, which can also be changed in the settings
|
| 336 |
+
|
| 337 |
+
- Fixed `easy globalSeed` is not working
|
| 338 |
+
- Fixed an issue where all `seed_num` values were out of order due to [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updating the chart in real time
|
| 339 |
+
- Fixed `easy imageSize`, `easy imageSizeBySide`, `easy imageSizeByLongerSide` as end nodes
|
| 340 |
+
- Fixed the bug that `seed_num` (random seed value) could not be read consistently in history
|
| 341 |
+
</details>
|
| 342 |
+
|
| 343 |
+
<details>
|
| 344 |
+
<summary><b>Updated at 12/14/2023</b></summary>
|
| 345 |
+
|
| 346 |
+
- `easy a1111Loader` and `easy comfyLoader` added `batch_size` of required input parameters
|
| 347 |
+
- Added the `easy controlnetLoaderADV` node
|
| 348 |
+
- `easy controlnetLoaderADV` and `easy controlnetLoader` added `control_net ` of optional input parameters
|
| 349 |
+
- `easy preSampling` and `easy preSamplingAdvanced` added `image_to_latent` optional input parameters
|
| 350 |
+
- Added the `easy imageSizeBySide` node, which can be output as a long side or a short side
|
| 351 |
+
</details>
|
| 352 |
+
|
| 353 |
+
<details>
|
| 354 |
+
<summary><b>Updated at 12/13/2023</b></summary>
|
| 355 |
+
|
| 356 |
+
- Added the `easy LLLiteLoader` node, if you have pre-installed the kohya-ss/ControlNet-LLLite-ComfyUI package, please move the model files in the models to `ComfyUI\models\controlnet\` (i.e. in the default controlnet path of comfy, please do not change the file name of the model, otherwise it will not be read).
|
| 357 |
+
- Modify `easy controlnetLoader` to the bottom of the loader category.
|
| 358 |
+
- Added size display for `easy imageSize` and `easy imageSizeByLongerSize` outputs.
|
| 359 |
+
</details>
|
| 360 |
+
|
| 361 |
+
<details>
|
| 362 |
+
<summary><b>Updated at 12/11/2023</b></summary>
|
| 363 |
+
- Added the `showSpentTime` node to display the time spent on image diffusion and the time spent on VAE decoding images
|
| 364 |
+
</details>
|
| 365 |
+
|
| 366 |
+
## The relevant node package involved
|
| 367 |
+
|
| 368 |
+
Disclaimer: Opened source was not easy. I have a lot of respect for the contributions of these original authors. I just did some integration and optimization.
|
| 369 |
+
|
| 370 |
+
| Nodes Name(Search Name) | Related libraries | Library-related node |
|
| 371 |
+
|:-------------------------------|:----------------------------------------------------------------------------|:-------------------------|
|
| 372 |
+
| easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
|
| 373 |
+
| easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
|
| 374 |
+
| easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
|
| 375 |
+
| easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
|
| 376 |
+
| easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
|
| 377 |
+
| easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
|
| 378 |
+
| easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 379 |
+
| dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 380 |
+
| easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
|
| 381 |
+
| easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
|
| 382 |
+
| easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply... |
|
| 383 |
+
| easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
|
| 384 |
+
| easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
|
| 385 |
+
| easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
|
| 386 |
+
| easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
## Credits
|
| 390 |
+
|
| 391 |
+
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Powerful and modular Stable Diffusion GUI
|
| 392 |
+
|
| 393 |
+
[ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI Manager
|
| 394 |
+
|
| 395 |
+
[tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) allow users to reduce unnecessary connections
|
| 396 |
+
|
| 397 |
+
[ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - Diffus3 gets and sets points that allow the user to detach the composition of the workflow
|
| 398 |
+
|
| 399 |
+
[ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General modpack 1
|
| 400 |
+
|
| 401 |
+
[ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General Modpack 2
|
| 402 |
+
|
| 403 |
+
[ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Make model generation independent of training resolution
|
| 404 |
+
|
| 405 |
+
[ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style migration
|
| 406 |
+
|
| 407 |
+
[ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face migration
|
| 408 |
+
|
| 409 |
+
[ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face migration
|
| 410 |
+
|
| 411 |
+
[ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss🐍
|
| 412 |
+
|
| 413 |
+
[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image Preview Chooser
|
| 414 |
+
|
| 415 |
+
[ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT custom nodes
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
## 🌟Stargazers
|
| 419 |
+
|
| 420 |
+
My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
|
| 421 |
+
|
| 422 |
+
[](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
|
ComfyUI-Easy-Use/README.md
ADDED
|
@@ -0,0 +1,448 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+

|
| 2 |
+
|
| 3 |
+
<div align="center">
|
| 4 |
+
<a href="https://space.bilibili.com/1840885116">Video Tutorial</a> |
|
| 5 |
+
Docs (Cooming Soon) |
|
| 6 |
+
<a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">Workflow Collection</a> |
|
| 7 |
+
<a href="#%EF%B8%8F-donation">Donation</a>
|
| 8 |
+
<br><br>
|
| 9 |
+
<a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-0b8cf5"></a>
|
| 10 |
+
<a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-e9e9e9"></a>
|
| 11 |
+
</div>
|
| 12 |
+
|
| 13 |
+
**ComfyUI-Easy-Use** is an efficiency custom nodes integration package, which is extended on the basis of [TinyTerraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes). It has been integrated and optimized for many popular awesome custom nodes to achieve the purpose of faster and more convenient use of ComfyUI. While ensuring the degree of freedom, it restores the ultimate smooth image production experience that belongs to Stable Diffusion.
|
| 14 |
+
|
| 15 |
+
## 👨🏻🎨 Introduce
|
| 16 |
+
|
| 17 |
+
- Inspire by [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), which greatly reduces the time cost of tossing workflows。
|
| 18 |
+
- UI interface beautification, the first time you install the user, if you need to use the UI theme, please switch the theme in Settings -> Color Palette and refresh page.
|
| 19 |
+
- Added a node for pre-sampling parameter configuration, which can be separated from the sampling node for easier previewing
|
| 20 |
+
- Wildcards and lora's are supported, for Lora Block Weight usage, ensure that the custom node package has the [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
|
| 21 |
+
- Multi-selectable styled cue word selector, default is Fooocus style json, custom json can be placed under styles, samples folder can be placed in the preview image (name and name consistent, image file name such as spaces need to be converted to underscores '_')
|
| 22 |
+
- The loader enables the A1111 prompt mode, which reproduces nearly identical images to those generated by webui, and needs to be installed [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) first.
|
| 23 |
+
- Noise injection into the latent space can be achieved using the `easy latentNoisy` or `easy preSamplingNoiseIn` node
|
| 24 |
+
- Simplified processes for SD1.x, SD2.x, SDXL, SVD, Zero123, etc. [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableDiffusion)
|
| 25 |
+
- Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableCascade)
|
| 26 |
+
- Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#LayerDiffusion),The first time you use it you may need to run `pip install -r requirements.txt` to install the required dependencies.
|
| 27 |
+
- Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#InstantID), You need to make sure that the custom node package has the [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
|
| 28 |
+
- Extending the usability of XYplot
|
| 29 |
+
- Fooocus Inpaint integration
|
| 30 |
+
- Integration of common logical calculations, conversion of types, display of all types, etc.
|
| 31 |
+
- Background removal nodes for the RMBG-1.4 model supporting BriaAI, [BriaAI Guide](https://huggingface.co/briaai/RMBG-1.4)
|
| 32 |
+
- Forcibly cleared the memory usage of the comfy UI model are supported
|
| 33 |
+
- Stable Diffusion 3 multi-account API nodes are supported
|
| 34 |
+
- Support SD3's model
|
| 35 |
+
- Support Kolors‘s model
|
| 36 |
+
- Support Flux's model
|
| 37 |
+
|
| 38 |
+
## 👨🏻🔧 Installation
|
| 39 |
+
Clone the repo into the **custom_nodes** directory and install the requirements:
|
| 40 |
+
```shell
|
| 41 |
+
#1. Clone the repo
|
| 42 |
+
git clone https://github.com/yolain/ComfyUI-Easy-Use
|
| 43 |
+
#2. Install the requirements
|
| 44 |
+
Double-click install.bat to install the required dependencies
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
## 👨🏻🚀 Plan
|
| 48 |
+
|
| 49 |
+
- [x] Updated new front-end code for easier maintenance
|
| 50 |
+
- [x] Maintain css styles using sass
|
| 51 |
+
- [x] Optimize existing extensions
|
| 52 |
+
- [x] Add new components
|
| 53 |
+
- [ ] Upload new workflows to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and translate readme to english version.
|
| 54 |
+
- [ ] Write gitbook with more detailed function introdution
|
| 55 |
+
|
| 56 |
+
## 📜 Changelog
|
| 57 |
+
|
| 58 |
+
**v1.2.2**
|
| 59 |
+
|
| 60 |
+
- Added v2 web frond-end code
|
| 61 |
+
- Added `easy fluxLoader`
|
| 62 |
+
- Added support for `controlnetApply` Related nodes with SD3 and hunyuanDiT
|
| 63 |
+
|
| 64 |
+
**v1.2.1**
|
| 65 |
+
|
| 66 |
+
- Added `easy ipadapterApplyFaceIDKolors`
|
| 67 |
+
- Added **inspyrenet** to `easy imageRemBg`
|
| 68 |
+
- Added `easy controlnetLoader++`
|
| 69 |
+
- Added **PLUS (kolors genernal)** and **FACEID PLUS KOLORS** preset to `easy ipadapterApply` and `easy ipadapterApplyADV` (Supported kolors ipadapter)
|
| 70 |
+
- Added `easy kolorsLoader` - Code based on [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ)'s and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)'s repo, thanks for their contribution.
|
| 71 |
+
|
| 72 |
+
**v1.2.0**
|
| 73 |
+
|
| 74 |
+
- Added `easy pulIDApply` and `easy pulIDApplyADV`
|
| 75 |
+
- Added `easy huanyuanDiTLoader` and `easy pixArtLoader`
|
| 76 |
+
- Added **easy sliderControl** - Slider control node, which can currently be used to control the parameters of ipadapterMS (double-click the slider to reset to default)
|
| 77 |
+
- Added **layer_weights** in `easy ipadapterApplyADV`
|
| 78 |
+
|
| 79 |
+
**v1.1.9**
|
| 80 |
+
|
| 81 |
+
- Added **gitsScheduler**
|
| 82 |
+
- Added `easy imageBatchToImageList` and `easy imageListToImageBatch`
|
| 83 |
+
- Recursive subcategories nested for models
|
| 84 |
+
- Support for Stable Diffusion 3 model
|
| 85 |
+
- Added `easy applyInpaint` - All inpainting mode in this node
|
| 86 |
+
|
| 87 |
+
**v1.1.8**
|
| 88 |
+
|
| 89 |
+
- Added `easy controlnetStack`
|
| 90 |
+
- Added `easy applyBrushNet` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
|
| 91 |
+
- Added `easy applyPowerPaint` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
|
| 92 |
+
|
| 93 |
+
**v1.1.7**
|
| 94 |
+
|
| 95 |
+
- Added `easy prompt` - Subject and light presets, maybe adjusted later
|
| 96 |
+
- Added `easy icLightApply` - Light and shadow migration, Code based on [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
|
| 97 |
+
- Added `easy imageSplitGrid`
|
| 98 |
+
- `easy kSamplerInpainting` added options such as different diffusion and brushnet in **additional** widget
|
| 99 |
+
- Support for brushnet model loading - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
|
| 100 |
+
- Added `easy applyFooocusInpaint` - Replace FooocusInpaintLoader
|
| 101 |
+
- Removed `easy fooocusInpaintLoader`
|
| 102 |
+
|
| 103 |
+
**v1.1.6**
|
| 104 |
+
|
| 105 |
+
- Added **alignYourSteps** to **schedulder** widget in all `easy preSampling` and `easy fullkSampler`
|
| 106 |
+
- Added **Preview&Choose** to **image_output** widget in `easy kSampler` & `easy fullkSampler`
|
| 107 |
+
- Added `easy styleAlignedBatchAlign` - Credit of [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
|
| 108 |
+
- Added `easy ckptNames`
|
| 109 |
+
- Added `easy controlnetNames`
|
| 110 |
+
- Added `easy imagesSplitimage` - Batch images split into single images
|
| 111 |
+
- Added `easy imageCount` - Get Image Count
|
| 112 |
+
- Added `easy textSwitch` - Text Switch
|
| 113 |
+
|
| 114 |
+
**v1.1.5**
|
| 115 |
+
|
| 116 |
+
- Rewrite `easy cleanGPUUsed` - the memory usage of the comfyUI can to be cleared
|
| 117 |
+
- Added `easy humanSegmentation` - Human Part Segmentation
|
| 118 |
+
- Added `easy imageColorMatch`
|
| 119 |
+
- Added `easy ipadapterApplyRegional`
|
| 120 |
+
- Added `easy ipadapterApplyFromParams`
|
| 121 |
+
- Added `easy imageInterrogator` - Image To Prompt
|
| 122 |
+
- Added `easy stableDiffusion3API` - Easy Stable Diffusion 3 Multiple accounts API Node
|
| 123 |
+
|
| 124 |
+
**v1.1.4**
|
| 125 |
+
|
| 126 |
+
- Added `easy preSamplingCustom` - Custom-PreSampling, can be supported cosXL-edit
|
| 127 |
+
- Added `easy ipadapterStyleComposition`
|
| 128 |
+
- Added the right-click menu to view checkpoints and lora information in all Loaders
|
| 129 |
+
- Fixed `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` compatible with ComfyUI Revision>=2098 [0542088e] or later
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
**v1.1.3**
|
| 133 |
+
|
| 134 |
+
- `easy ipadapterApply` Added **COMPOSITION** preset
|
| 135 |
+
- Supported [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) when load ResAdapter lora
|
| 136 |
+
- Added `easy promptLine`
|
| 137 |
+
- Added `easy promptReplace`
|
| 138 |
+
- Added `easy promptConcat`
|
| 139 |
+
- `easy wildcards` Added **multiline_mode**
|
| 140 |
+
|
| 141 |
+
<details>
|
| 142 |
+
<summary><b>v1.1.2</b></summary>
|
| 143 |
+
|
| 144 |
+
- Optimized some of the recommended nodes for slots related to EasyUse
|
| 145 |
+
- Added **Enable ContextMenu Auto Nest Subdirectories** The setting item is enabled by default, and it can be classified into subdirectories, checkpoints and loras previews
|
| 146 |
+
- Added `easy sv3dLoader`
|
| 147 |
+
- Added `easy dynamiCrafterLoader`
|
| 148 |
+
- Added `easy ipadapterApply`
|
| 149 |
+
- Added `easy ipadapterApplyADV`
|
| 150 |
+
- Added `easy ipadapterApplyEncoder`
|
| 151 |
+
- Added `easy ipadapterApplyEmbeds`
|
| 152 |
+
- Added `easy preMaskDetailerFix`
|
| 153 |
+
- Fixed `easy stylesSelector` is change the prompt when not select the style
|
| 154 |
+
- Fixed `easy pipeEdit` error when add lora to prompt
|
| 155 |
+
- Fixed layerDiffuse xyplot bug
|
| 156 |
+
- `easy kSamplerInpainting` add *additional* widget,you can choose 'Differential Diffusion' or 'Only InpaintModelConditioning'
|
| 157 |
+
</details>
|
| 158 |
+
|
| 159 |
+
<details>
|
| 160 |
+
<summary><b>v1.1.1</b></summary>
|
| 161 |
+
|
| 162 |
+
- The issue that the seed is 0 when a node with a seed control is added and **control before generate** is fixed for the first time run queue prompt.
|
| 163 |
+
- `easy preSamplingAdvanced` Added **return_with_leftover_noise**
|
| 164 |
+
- Fixed `easy stylesSelector` error when choose the custom file
|
| 165 |
+
- `easy preSamplingLayerDiffusion` Added optional input parameter for mask
|
| 166 |
+
- Renamed all nodes widget name named seed_num to seed
|
| 167 |
+
- Remove forced **control_before_generate** settings。 If you want to use control_before_generate, change widget_value_control_mode to before in system settings
|
| 168 |
+
- Added `easy imageRemBg` - The default is BriaAI's RMBG-1.4 model, which removes the background effect more and faster
|
| 169 |
+
</details>
|
| 170 |
+
|
| 171 |
+
<details>
|
| 172 |
+
<summary><b>v1.1.0</b></summary>
|
| 173 |
+
|
| 174 |
+
- Added `easy imageSplitList` - to split every N images
|
| 175 |
+
- Added `easy preSamplingDiffusionADDTL` - It can modify foreground、background or blended additional prompt
|
| 176 |
+
- Added `easy preSamplingNoiseIn` It can replace the `easy latentNoisy` node that needs to be fronted to achieve better noise injection
|
| 177 |
+
- `easy pipeEdit` Added conditioning splicing mode selection, you can choose to replace, concat, combine, average, and set timestep range
|
| 178 |
+
- Added `easy pipeEdit` - nodes that can edit pipes (including re-enterable prompts)
|
| 179 |
+
- Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion`
|
| 180 |
+
- Added a convenient menu to right-click on nodes such as Loader, Presampler, Sampler, Controlnet, etc. to quickly replace nodes of the same type
|
| 181 |
+
- Added `easy instantIDApplyADV` can link positive and negative
|
| 182 |
+
- Fixed layerDiffusion error when batch size greater than 1
|
| 183 |
+
- Fixed `easy wildcards` When LoRa is not filled in completely, LoRa is not automatically retrieved, resulting in failure to load LoRa
|
| 184 |
+
- Fixed the issue that 'BREAK' non-initiation when didn't use a1111 prompt style
|
| 185 |
+
- Fixed `easy instantIDApply` mask not input right
|
| 186 |
+
</details>
|
| 187 |
+
|
| 188 |
+
<details>
|
| 189 |
+
<summary><b>v1.0.9</b></summary>
|
| 190 |
+
|
| 191 |
+
- Fixed the error when ComfyUI-Impack-Pack and ComfyUI_InstantID were not installed
|
| 192 |
+
- Fixed `easy pipeIn`
|
| 193 |
+
- Added `easy instantIDApply` - you need installed [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) fisrt, Workflow[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#InstantID)
|
| 194 |
+
- Fixed `easy detailerFix` not added to the list of nodes available for saving images formatting extensions
|
| 195 |
+
- Fixed `easy XYInputs: PromptSR` errors are reported when replacing negative prompts
|
| 196 |
+
</details>
|
| 197 |
+
|
| 198 |
+
<details>
|
| 199 |
+
<summary><b>v1.0.8</b></summary>
|
| 200 |
+
|
| 201 |
+
- `easy cascadeLoader` stage_c and stage_b support the checkpoint model (Download [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints) models)
|
| 202 |
+
- `easy styleSelector` The search box is modified to be case-insensitive
|
| 203 |
+
- `easy fullLoader` **positive**、**negative**、**latent** added to the output items
|
| 204 |
+
- Fixed the issue that 'easy preSampling' and other similar node, latent could not be generated based on the batch index after passing in
|
| 205 |
+
- Fixed `easy svdLoader` error when the positive or negative is empty
|
| 206 |
+
- Fixed the error of SDXLClipModel in ComfyUI revision 2016[c2cb8e88] and above (the revision number was judged to be compatible with the old revision)
|
| 207 |
+
- Fixed `easy detailerFix` generation error when batch size is greater than 1
|
| 208 |
+
- Optimize the code, reduce a lot of redundant code and improve the running speed
|
| 209 |
+
</details>
|
| 210 |
+
|
| 211 |
+
<details>
|
| 212 |
+
<summary><b>v1.0.7</b></summary>
|
| 213 |
+
|
| 214 |
+
- Added `easy cascadeLoader` - stable cascade Loader
|
| 215 |
+
- Added `easy preSamplingCascade` - stable cascade preSampling Settings
|
| 216 |
+
- Added `easy fullCascadeKSampler` - stable cascade stage-c ksampler full
|
| 217 |
+
- Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
|
| 218 |
+
-
|
| 219 |
+
- Optimize the image to image[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#image-to-image)
|
| 220 |
+
</details>
|
| 221 |
+
|
| 222 |
+
<details>
|
| 223 |
+
<summary><b>v1.0.6</b></summary>
|
| 224 |
+
|
| 225 |
+
- Added `easy XYInputs: Checkpoint`
|
| 226 |
+
- Added `easy XYInputs: Lora`
|
| 227 |
+
- `easy seed` can manually switch the random seed when increasing the fixed seed value
|
| 228 |
+
- Fixed `easy fullLoader` and all loaders to automatically adjust the node size when switching LoRa
|
| 229 |
+
- Removed the original ttn image saving logic and adapted to the default image saving format extension of ComfyUI
|
| 230 |
+
</details>
|
| 231 |
+
|
| 232 |
+
<details>
|
| 233 |
+
<summary><b>v1.0.5</b></summary>
|
| 234 |
+
|
| 235 |
+
- Added `easy isSDXL`
|
| 236 |
+
- Added prompt word control on `easy svdLoader`, which can be used with open_clip model
|
| 237 |
+
- Added **populated_text** on `easy wildcards`, wildcard populated text can be output
|
| 238 |
+
</details>
|
| 239 |
+
|
| 240 |
+
<details>
|
| 241 |
+
<summary><b>v1.0.4</b></summary>
|
| 242 |
+
|
| 243 |
+
- `easy showAnything` added support for converting other types (e.g., tensor conditions, images, etc.)
|
| 244 |
+
- Added `easy showLoaderSettingsNames` can display the model and VAE name in the output loader assembly
|
| 245 |
+
- Added `easy promptList`
|
| 246 |
+
- Added `easy fooocusInpaintLoader` (only the process of SDXLModel is supported)
|
| 247 |
+
- Added **Logic** nodes
|
| 248 |
+
- Added `easy imageSave` - Image saving node with date conversion and aspect and height formatting
|
| 249 |
+
- Added `easy joinImageBatch`
|
| 250 |
+
- `easy kSamplerInpainting` Added the **patch** input value to be used with the FooocusInpaintLoader node
|
| 251 |
+
|
| 252 |
+
- Fixed xyplot error when with Pillow>9.5
|
| 253 |
+
- Fixed `easy wildcards` An error is reported when running with the PS extension
|
| 254 |
+
- Fixed `easy XYInputs: ControlNet` Error
|
| 255 |
+
- Fixed `easy loraStack` error when **toggle** is disabled
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
- Changing the first-time install node package no longer automatically replaces the theme, you need to manually adjust and refresh the page
|
| 259 |
+
- `easy imageSave` added **only_preivew**
|
| 260 |
+
- Adjust the `easy latentCompositeMaskedWithCond` node
|
| 261 |
+
</details>
|
| 262 |
+
|
| 263 |
+
<details>
|
| 264 |
+
<summary><b>v1.0.3</b></summary>
|
| 265 |
+
|
| 266 |
+
- Added `easy stylesSelector`
|
| 267 |
+
- Added **scale_soft_weights** in `easy controlnetLoader` and `easy controlnetLoaderADV`
|
| 268 |
+
- Added the queue progress bar setting item, which is not enabled by default
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
- Fixed `easy XYInputs: Sampler/Scheduler` Error
|
| 272 |
+
- Fixed the right menu has a problem when clicking the button
|
| 273 |
+
- Fixed `easy comfyLoader` error
|
| 274 |
+
- Fixed xyPlot error when connecting to zero123
|
| 275 |
+
- Fixed the error message in the loader when the prompt word was component
|
| 276 |
+
- Fixed `easy getNode` and `easy setNode` the title does not change when loading
|
| 277 |
+
- Fixed all samplers using subdirectories to store images
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
- Adjust the UI theme, divided into two sets of styles: the official default background and the dark black background, which can be switched in the color palette in the settings
|
| 281 |
+
- Modify the styles path to be compatible with other environments
|
| 282 |
+
</details>
|
| 283 |
+
|
| 284 |
+
<details>
|
| 285 |
+
<summary><b>v1.0.2</b></summary>
|
| 286 |
+
|
| 287 |
+
- Added `easy XYPlotAdvanced` and some nodes about `easy XYInputs`
|
| 288 |
+
- Added **Alt+1-Alt+9** Shortcut keys to quickly paste node presets for Node templates (corresponding to 1~9 sequences)
|
| 289 |
+
- Added a `📜Groups Map(EasyUse)` to the context menu.
|
| 290 |
+
- An `autocomplete` folder has been added, If you have [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) installed, the txt files in that folder will be merged and overwritten to the autocomplete .txt file of the pyssss package at startup.
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
- Fixed XYPlot is not working when `a1111_prompt_style` is True
|
| 294 |
+
- Fixed UI loading failure in the new version of ComfyUI
|
| 295 |
+
- `easy XYInputs ModelMergeBlocks` Values can be imported from CSV files
|
| 296 |
+
- Fixed `easy pipeToBasicPipe` Bug
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
- Removed `easy imageRemBg`
|
| 300 |
+
- Remove the introductory diagram and workflow files from the package to reduce the package size
|
| 301 |
+
- Replaced the font file used in the generation of XY diagrams
|
| 302 |
+
</details>
|
| 303 |
+
|
| 304 |
+
<details>
|
| 305 |
+
<summary><b>v1.0.1</b></summary>
|
| 306 |
+
|
| 307 |
+
- Fixed `easy comfyLoader` error
|
| 308 |
+
- Fixed All nodes that contain the value of the image size
|
| 309 |
+
- Added `easy kSamplerInpainting`
|
| 310 |
+
- Added `easy pipeToBasicPipe`
|
| 311 |
+
- Fixed `width` and `height` can not customize in `easy svdLoader`
|
| 312 |
+
- Fixed all preview image path (Previously, it was not possible to preview the image on the Mac system)
|
| 313 |
+
- Fixed `vae_name` is not working in `easy fullLoader` and `easy a1111Loader` and `easy comfyLoader`
|
| 314 |
+
- Fixed `easy fullkSampler` outputs error
|
| 315 |
+
- Fixed `model_override` is not working in `easy fullLoader`
|
| 316 |
+
- Fixed `easy hiresFix` error
|
| 317 |
+
- Fixed `easy xyplot` font file path error
|
| 318 |
+
- Fixed seed that cannot be fixed when you convert `seed_num` to `easy seed`
|
| 319 |
+
- Fixed `easy pipeIn` inputs bug
|
| 320 |
+
- `easy preDetailerFix` have added a new parameter `optional_image`
|
| 321 |
+
- Fixed `easy zero123Loader` and `easy svdLoader` model into cache.
|
| 322 |
+
- Added `easy seed`
|
| 323 |
+
- Fixed `image_output` default value is "Preview"
|
| 324 |
+
- `easy fullLoader` and `easy a1111Loader` have added a new parameter `a1111_prompt_style`,that can reproduce the same image generated from stable-diffusion-webui on comfyui, but you need to install [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) to use this feature in the current version
|
| 325 |
+
</details>
|
| 326 |
+
|
| 327 |
+
<details>
|
| 328 |
+
<summary><b>v1.0.0</b></summary>
|
| 329 |
+
|
| 330 |
+
- Added `easy positive` - simple positive prompt text
|
| 331 |
+
- Added `easy negative` - simple negative prompt text
|
| 332 |
+
- Added `easy wildcards` - support for wildcards and hint text selected by Lora
|
| 333 |
+
- Added `easy portraitMaster` - PortraitMaster v2.2
|
| 334 |
+
- Added `easy loraStack` - Lora stack
|
| 335 |
+
- Added `easy fullLoader` - full version of the loader
|
| 336 |
+
- Added `easy zero123Loader` - simple zero123 loader
|
| 337 |
+
- Added `easy svdLoader` - easy svd loader
|
| 338 |
+
- Added `easy fullkSampler` - full version of the sampler (no separation)
|
| 339 |
+
- Added `easy hiresFix` - support for HD repair of Pipe
|
| 340 |
+
- Added `easy predetailerFix` and `easy DetailerFix` - support for Pipe detail fixing
|
| 341 |
+
- Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - Detect loader (detail fixed input)
|
| 342 |
+
- Added `easy pipein` `easy pipeout` - Pipe input and output
|
| 343 |
+
- Added `easy xyPlot` - simple xyplot (more controllable parameters will be updated in the future)
|
| 344 |
+
- Added `easy imageRemoveBG` - image to remove background
|
| 345 |
+
- Added `easy imagePixelPerfect` - image pixel perfect
|
| 346 |
+
- Added `easy poseEditor` - Pose editor
|
| 347 |
+
- New UI Theme (Obsidian) - Auto-load UI by default, which can also be changed in the settings
|
| 348 |
+
|
| 349 |
+
- Fixed `easy globalSeed` is not working
|
| 350 |
+
- Fixed an issue where all `seed_num` values were out of order due to [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updating the chart in real time
|
| 351 |
+
- Fixed `easy imageSize`, `easy imageSizeBySide`, `easy imageSizeByLongerSide` as end nodes
|
| 352 |
+
- Fixed the bug that `seed_num` (random seed value) could not be read consistently in history
|
| 353 |
+
</details>
|
| 354 |
+
|
| 355 |
+
<details>
|
| 356 |
+
<summary><b>Updated at 12/14/2023</b></summary>
|
| 357 |
+
|
| 358 |
+
- `easy a1111Loader` and `easy comfyLoader` added `batch_size` of required input parameters
|
| 359 |
+
- Added the `easy controlnetLoaderADV` node
|
| 360 |
+
- `easy controlnetLoaderADV` and `easy controlnetLoader` added `control_net ` of optional input parameters
|
| 361 |
+
- `easy preSampling` and `easy preSamplingAdvanced` added `image_to_latent` optional input parameters
|
| 362 |
+
- Added the `easy imageSizeBySide` node, which can be output as a long side or a short side
|
| 363 |
+
</details>
|
| 364 |
+
|
| 365 |
+
<details>
|
| 366 |
+
<summary><b>Updated at 12/13/2023</b></summary>
|
| 367 |
+
|
| 368 |
+
- Added the `easy LLLiteLoader` node, if you have pre-installed the kohya-ss/ControlNet-LLLite-ComfyUI package, please move the model files in the models to `ComfyUI\models\controlnet\` (i.e. in the default controlnet path of comfy, please do not change the file name of the model, otherwise it will not be read).
|
| 369 |
+
- Modify `easy controlnetLoader` to the bottom of the loader category.
|
| 370 |
+
- Added size display for `easy imageSize` and `easy imageSizeByLongerSize` outputs.
|
| 371 |
+
</details>
|
| 372 |
+
|
| 373 |
+
<details>
|
| 374 |
+
<summary><b>Updated at 12/11/2023</b></summary>
|
| 375 |
+
- Added the `showSpentTime` node to display the time spent on image diffusion and the time spent on VAE decoding images
|
| 376 |
+
</details>
|
| 377 |
+
|
| 378 |
+
## The relevant node package involved
|
| 379 |
+
|
| 380 |
+
Disclaimer: Opened source was not easy. I have a lot of respect for the contributions of these original authors. I just did some integration and optimization.
|
| 381 |
+
|
| 382 |
+
| Nodes Name(Search Name) | Related libraries | Library-related node |
|
| 383 |
+
|:-------------------------------|:----------------------------------------------------------------------------|:-------------------------|
|
| 384 |
+
| easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
|
| 385 |
+
| easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
|
| 386 |
+
| easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
|
| 387 |
+
| easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
|
| 388 |
+
| easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
|
| 389 |
+
| easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
|
| 390 |
+
| easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 391 |
+
| dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
|
| 392 |
+
| easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
|
| 393 |
+
| easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
|
| 394 |
+
| easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply... |
|
| 395 |
+
| easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
|
| 396 |
+
| easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
|
| 397 |
+
| easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
|
| 398 |
+
| easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
## Credits
|
| 402 |
+
|
| 403 |
+
[ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Powerful and modular Stable Diffusion GUI
|
| 404 |
+
|
| 405 |
+
[ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI Manager
|
| 406 |
+
|
| 407 |
+
[tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) allow users to reduce unnecessary connections
|
| 408 |
+
|
| 409 |
+
[ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - Diffus3 gets and sets points that allow the user to detach the composition of the workflow
|
| 410 |
+
|
| 411 |
+
[ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General modpack 1
|
| 412 |
+
|
| 413 |
+
[ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General Modpack 2
|
| 414 |
+
|
| 415 |
+
[ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Make model generation independent of training resolution
|
| 416 |
+
|
| 417 |
+
[ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style migration
|
| 418 |
+
|
| 419 |
+
[ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face migration
|
| 420 |
+
|
| 421 |
+
[ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face migration
|
| 422 |
+
|
| 423 |
+
[ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss🐍
|
| 424 |
+
|
| 425 |
+
[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image Preview Chooser
|
| 426 |
+
|
| 427 |
+
[ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT custom nodes
|
| 428 |
+
|
| 429 |
+
## ☕️ Donation
|
| 430 |
+
|
| 431 |
+
**Comfyui-Easy-Use** is an GPL-licensed open source project. In order to achieve better and sustainable development of the project, i expect to gain more backers. <br>
|
| 432 |
+
If my custom nodes has added value to your day, consider indulging in a coffee to fuel it further! <br>
|
| 433 |
+
💖You can support me in any of the following ways:
|
| 434 |
+
|
| 435 |
+
- [BiliBili](https://space.bilibili.com/1840885116)
|
| 436 |
+
- [Afdian](https://afdian.com/a/yolain)
|
| 437 |
+
- [Wechat / Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
|
| 438 |
+
- 🪙 Wallet Address:
|
| 439 |
+
- ETH: 0x01f7CEd3245CaB3891A0ec8f528178db352EaC74
|
| 440 |
+
- USDT(tron): TP3AnJXkAzfebL2GKmFAvQvXgsxzivweV6
|
| 441 |
+
|
| 442 |
+
(This is a newly created wallet, and if it receives sponsorship, I'll use it to rent GPUs or other GPT services for better debugging and refinement of ComfyUI-Easy-Use features.)
|
| 443 |
+
|
| 444 |
+
## 🌟Stargazers
|
| 445 |
+
|
| 446 |
+
My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
|
| 447 |
+
|
| 448 |
+
[](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
|
ComfyUI-Easy-Use/__init__.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__version__ = "1.2.2"
|
| 2 |
+
|
| 3 |
+
import yaml
|
| 4 |
+
import os
|
| 5 |
+
import folder_paths
|
| 6 |
+
import importlib
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
node_list = [
|
| 10 |
+
"server",
|
| 11 |
+
"api",
|
| 12 |
+
"easyNodes",
|
| 13 |
+
"image",
|
| 14 |
+
"logic"
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
NODE_CLASS_MAPPINGS = {}
|
| 18 |
+
NODE_DISPLAY_NAME_MAPPINGS = {}
|
| 19 |
+
|
| 20 |
+
for module_name in node_list:
|
| 21 |
+
imported_module = importlib.import_module(".py.{}".format(module_name), __name__)
|
| 22 |
+
NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
|
| 23 |
+
NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
|
| 24 |
+
|
| 25 |
+
cwd_path = os.path.dirname(os.path.realpath(__file__))
|
| 26 |
+
comfy_path = folder_paths.base_path
|
| 27 |
+
|
| 28 |
+
#Wildcards
|
| 29 |
+
from .py.libs.wildcards import read_wildcard_dict
|
| 30 |
+
wildcards_path = os.path.join(os.path.dirname(__file__), "wildcards")
|
| 31 |
+
if os.path.exists(wildcards_path):
|
| 32 |
+
read_wildcard_dict(wildcards_path)
|
| 33 |
+
else:
|
| 34 |
+
os.mkdir(wildcards_path)
|
| 35 |
+
|
| 36 |
+
#Styles
|
| 37 |
+
styles_path = os.path.join(os.path.dirname(__file__), "styles")
|
| 38 |
+
samples_path = os.path.join(os.path.dirname(__file__), "styles", "samples")
|
| 39 |
+
if os.path.exists(styles_path):
|
| 40 |
+
if not os.path.exists(samples_path):
|
| 41 |
+
os.mkdir(samples_path)
|
| 42 |
+
else:
|
| 43 |
+
os.mkdir(styles_path)
|
| 44 |
+
os.mkdir(samples_path)
|
| 45 |
+
|
| 46 |
+
# Model thumbnails
|
| 47 |
+
from .py.libs.add_resources import add_static_resource
|
| 48 |
+
from .py.libs.model import easyModelManager
|
| 49 |
+
model_config = easyModelManager().models_config
|
| 50 |
+
for model in model_config:
|
| 51 |
+
paths = folder_paths.get_folder_paths(model)
|
| 52 |
+
for path in paths:
|
| 53 |
+
if not Path(path).exists():
|
| 54 |
+
continue
|
| 55 |
+
add_static_resource(path, path, limit=True)
|
| 56 |
+
|
| 57 |
+
# get comfyui revision
|
| 58 |
+
from .py.libs.utils import compare_revision
|
| 59 |
+
|
| 60 |
+
new_frontend_revision = 2546
|
| 61 |
+
web_default_version = 'v2' if compare_revision(new_frontend_revision) else 'v1'
|
| 62 |
+
# web directory
|
| 63 |
+
config_path = os.path.join(cwd_path, "config.yaml")
|
| 64 |
+
if os.path.isfile(config_path):
|
| 65 |
+
with open(config_path, 'r') as f:
|
| 66 |
+
data = yaml.load(f, Loader=yaml.FullLoader)
|
| 67 |
+
if data and "WEB_VERSION" in data:
|
| 68 |
+
directory = f"web_version/{data['WEB_VERSION']}"
|
| 69 |
+
with open(config_path, 'w') as f:
|
| 70 |
+
yaml.dump(data, f)
|
| 71 |
+
elif web_default_version != 'v1':
|
| 72 |
+
if not data:
|
| 73 |
+
data = {'WEB_VERSION': web_default_version}
|
| 74 |
+
elif 'WEB_VERSION' not in data:
|
| 75 |
+
data = {**data, 'WEB_VERSION': web_default_version}
|
| 76 |
+
with open(config_path, 'w') as f:
|
| 77 |
+
yaml.dump(data, f)
|
| 78 |
+
directory = f"web_version/{web_default_version}"
|
| 79 |
+
else:
|
| 80 |
+
directory = f"web_version/v1"
|
| 81 |
+
if not os.path.exists(os.path.join(cwd_path, directory)):
|
| 82 |
+
print(f"web root {data['WEB_VERSION']} not found, using default")
|
| 83 |
+
directory = f"web_version/{web_default_version}"
|
| 84 |
+
WEB_DIRECTORY = directory
|
| 85 |
+
else:
|
| 86 |
+
directory = f"web_version/{web_default_version}"
|
| 87 |
+
WEB_DIRECTORY = directory
|
| 88 |
+
|
| 89 |
+
__all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
|
| 90 |
+
|
| 91 |
+
print(f'\033[34m[ComfyUI-Easy-Use] server: \033[0mv{__version__} \033[92mLoaded\033[0m')
|
| 92 |
+
print(f'\033[34m[ComfyUI-Easy-Use] web root: \033[0m{os.path.join(cwd_path, directory)} \033[92mLoaded\033[0m')
|
ComfyUI-Easy-Use/config.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
STABILITY_API_DEFAULT: 0
|
| 2 |
+
STABILITY_API_KEY:
|
| 3 |
+
- key: ''
|
| 4 |
+
name: Default
|
| 5 |
+
WEB_VERSION: v2
|
ComfyUI-Easy-Use/install.bat
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@echo off
|
| 2 |
+
|
| 3 |
+
set "requirements_txt=%~dp0\requirements.txt"
|
| 4 |
+
set "requirements_repair_txt=%~dp0\repair_dependency_list.txt"
|
| 5 |
+
set "python_exec=..\..\..\python_embedded\python.exe"
|
| 6 |
+
set "aki_python_exec=..\..\python\python.exe"
|
| 7 |
+
|
| 8 |
+
echo Installing EasyUse Requirements...
|
| 9 |
+
|
| 10 |
+
if exist "%python_exec%" (
|
| 11 |
+
echo Installing with ComfyUI Portable
|
| 12 |
+
"%python_exec%" -s -m pip install -r "%requirements_txt%"
|
| 13 |
+
)^
|
| 14 |
+
else (
|
| 15 |
+
echo Installing with Python
|
| 16 |
+
pip install -r "%requirements_txt%"
|
| 17 |
+
)
|
| 18 |
+
pause
|
ComfyUI-Easy-Use/prestartup_script.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import folder_paths
|
| 2 |
+
import os
|
| 3 |
+
def add_folder_path_and_extensions(folder_name, full_folder_paths, extensions):
|
| 4 |
+
for full_folder_path in full_folder_paths:
|
| 5 |
+
folder_paths.add_model_folder_path(folder_name, full_folder_path)
|
| 6 |
+
if folder_name in folder_paths.folder_names_and_paths:
|
| 7 |
+
current_paths, current_extensions = folder_paths.folder_names_and_paths[folder_name]
|
| 8 |
+
updated_extensions = current_extensions | extensions
|
| 9 |
+
folder_paths.folder_names_and_paths[folder_name] = (current_paths, updated_extensions)
|
| 10 |
+
else:
|
| 11 |
+
folder_paths.folder_names_and_paths[folder_name] = (full_folder_paths, extensions)
|
| 12 |
+
|
| 13 |
+
image_suffixs = set([".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"])
|
| 14 |
+
|
| 15 |
+
model_path = folder_paths.models_dir
|
| 16 |
+
add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(model_path, "ultralytics", "bbox")], folder_paths.supported_pt_extensions)
|
| 17 |
+
add_folder_path_and_extensions("ultralytics_segm", [os.path.join(model_path, "ultralytics", "segm")], folder_paths.supported_pt_extensions)
|
| 18 |
+
add_folder_path_and_extensions("ultralytics", [os.path.join(model_path, "ultralytics")], folder_paths.supported_pt_extensions)
|
| 19 |
+
add_folder_path_and_extensions("mmdets_bbox", [os.path.join(model_path, "mmdets", "bbox")], folder_paths.supported_pt_extensions)
|
| 20 |
+
add_folder_path_and_extensions("mmdets_segm", [os.path.join(model_path, "mmdets", "segm")], folder_paths.supported_pt_extensions)
|
| 21 |
+
add_folder_path_and_extensions("mmdets", [os.path.join(model_path, "mmdets")], folder_paths.supported_pt_extensions)
|
| 22 |
+
add_folder_path_and_extensions("sams", [os.path.join(model_path, "sams")], folder_paths.supported_pt_extensions)
|
| 23 |
+
add_folder_path_and_extensions("onnx", [os.path.join(model_path, "onnx")], {'.onnx'})
|
| 24 |
+
add_folder_path_and_extensions("instantid", [os.path.join(model_path, "instantid")], folder_paths.supported_pt_extensions)
|
| 25 |
+
add_folder_path_and_extensions("pulid", [os.path.join(model_path, "pulid")], folder_paths.supported_pt_extensions)
|
| 26 |
+
add_folder_path_and_extensions("layer_model", [os.path.join(model_path, "layer_model")], folder_paths.supported_pt_extensions)
|
| 27 |
+
add_folder_path_and_extensions("rembg", [os.path.join(model_path, "rembg")], folder_paths.supported_pt_extensions)
|
| 28 |
+
add_folder_path_and_extensions("ipadapter", [os.path.join(model_path, "ipadapter")], folder_paths.supported_pt_extensions)
|
| 29 |
+
add_folder_path_and_extensions("dynamicrafter_models", [os.path.join(model_path, "dynamicrafter_models")], folder_paths.supported_pt_extensions)
|
| 30 |
+
add_folder_path_and_extensions("mediapipe", [os.path.join(model_path, "mediapipe")], set(['.tflite','.pth']))
|
| 31 |
+
add_folder_path_and_extensions("inpaint", [os.path.join(model_path, "inpaint")], folder_paths.supported_pt_extensions)
|
| 32 |
+
add_folder_path_and_extensions("prompt_generator", [os.path.join(model_path, "prompt_generator")], folder_paths.supported_pt_extensions)
|
| 33 |
+
add_folder_path_and_extensions("t5", [os.path.join(model_path, "t5")], folder_paths.supported_pt_extensions)
|
| 34 |
+
add_folder_path_and_extensions("llm", [os.path.join(model_path, "LLM")], folder_paths.supported_pt_extensions)
|
| 35 |
+
|
| 36 |
+
add_folder_path_and_extensions("checkpoints_thumb", [os.path.join(model_path, "checkpoints")], image_suffixs)
|
| 37 |
+
add_folder_path_and_extensions("loras_thumb", [os.path.join(model_path, "loras")], image_suffixs)
|
ComfyUI-Easy-Use/py/__init__.py
ADDED
|
File without changes
|
ComfyUI-Easy-Use/py/api.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import hashlib
|
| 3 |
+
import sys
|
| 4 |
+
import json
|
| 5 |
+
import shutil
|
| 6 |
+
import folder_paths
|
| 7 |
+
from folder_paths import get_directory_by_type
|
| 8 |
+
from server import PromptServer
|
| 9 |
+
from .config import RESOURCES_DIR, FOOOCUS_STYLES_DIR, FOOOCUS_STYLES_SAMPLES
|
| 10 |
+
from .libs.model import easyModelManager
|
| 11 |
+
from .libs.utils import getMetadata, cleanGPUUsedForce, get_local_filepath
|
| 12 |
+
from .libs.cache import remove_cache
|
| 13 |
+
from .libs.translate import has_chinese, zh_to_en
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import aiohttp
|
| 17 |
+
from aiohttp import web
|
| 18 |
+
except ImportError:
|
| 19 |
+
print("Module 'aiohttp' not installed. Please install it via:")
|
| 20 |
+
print("pip install aiohttp")
|
| 21 |
+
sys.exit()
|
| 22 |
+
|
| 23 |
+
@PromptServer.instance.routes.post("/easyuse/cleangpu")
|
| 24 |
+
def cleanGPU(request):
|
| 25 |
+
try:
|
| 26 |
+
cleanGPUUsedForce()
|
| 27 |
+
remove_cache('*')
|
| 28 |
+
return web.Response(status=200)
|
| 29 |
+
except Exception as e:
|
| 30 |
+
return web.Response(status=500)
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
@PromptServer.instance.routes.post("/easyuse/translate")
|
| 34 |
+
async def translate(request):
|
| 35 |
+
post = await request.post()
|
| 36 |
+
text = post.get("text")
|
| 37 |
+
if has_chinese(text):
|
| 38 |
+
return web.json_response({"text": zh_to_en([text])[0]})
|
| 39 |
+
else:
|
| 40 |
+
return web.json_response({"text": text})
|
| 41 |
+
|
| 42 |
+
@PromptServer.instance.routes.get("/easyuse/reboot")
|
| 43 |
+
def reboot(request):
|
| 44 |
+
try:
|
| 45 |
+
sys.stdout.close_log()
|
| 46 |
+
except Exception as e:
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
return os.execv(sys.executable, [sys.executable] + sys.argv)
|
| 50 |
+
|
| 51 |
+
# parse csv
|
| 52 |
+
@PromptServer.instance.routes.post("/easyuse/upload/csv")
|
| 53 |
+
async def parse_csv(request):
|
| 54 |
+
post = await request.post()
|
| 55 |
+
csv = post.get("csv")
|
| 56 |
+
if csv and csv.file:
|
| 57 |
+
file = csv.file
|
| 58 |
+
text = ''
|
| 59 |
+
for line in file.readlines():
|
| 60 |
+
line = str(line.strip())
|
| 61 |
+
line = line.replace("'", "").replace("b",'')
|
| 62 |
+
text += line + '; \n'
|
| 63 |
+
return web.json_response(text)
|
| 64 |
+
|
| 65 |
+
#get style list
|
| 66 |
+
@PromptServer.instance.routes.get("/easyuse/prompt/styles")
|
| 67 |
+
async def getStylesList(request):
|
| 68 |
+
if "name" in request.rel_url.query:
|
| 69 |
+
name = request.rel_url.query["name"]
|
| 70 |
+
if name == 'fooocus_styles':
|
| 71 |
+
file = os.path.join(RESOURCES_DIR, name+'.json')
|
| 72 |
+
cn_file = os.path.join(RESOURCES_DIR, name + '_cn.json')
|
| 73 |
+
else:
|
| 74 |
+
file = os.path.join(FOOOCUS_STYLES_DIR, name+'.json')
|
| 75 |
+
cn_file = os.path.join(FOOOCUS_STYLES_DIR, name + '_cn.json')
|
| 76 |
+
cn_data = None
|
| 77 |
+
if os.path.isfile(cn_file):
|
| 78 |
+
f = open(cn_file, 'r', encoding='utf-8')
|
| 79 |
+
cn_data = json.load(f)
|
| 80 |
+
f.close()
|
| 81 |
+
if os.path.isfile(file):
|
| 82 |
+
f = open(file, 'r', encoding='utf-8')
|
| 83 |
+
data = json.load(f)
|
| 84 |
+
f.close()
|
| 85 |
+
if data:
|
| 86 |
+
ndata = []
|
| 87 |
+
for d in data:
|
| 88 |
+
nd = {}
|
| 89 |
+
name = d['name'].replace('-', ' ')
|
| 90 |
+
words = name.split(' ')
|
| 91 |
+
key = ' '.join(
|
| 92 |
+
word.upper() if word.lower() in ['mre', 'sai', '3d'] else word.capitalize() for word in
|
| 93 |
+
words)
|
| 94 |
+
img_name = '_'.join(words).lower()
|
| 95 |
+
if "name_cn" in d:
|
| 96 |
+
nd['name_cn'] = d['name_cn']
|
| 97 |
+
elif cn_data:
|
| 98 |
+
nd['name_cn'] = cn_data[key] if key in cn_data else key
|
| 99 |
+
nd["name"] = d['name']
|
| 100 |
+
nd['imgName'] = img_name
|
| 101 |
+
if "prompt" in d:
|
| 102 |
+
nd['prompt'] = d['prompt']
|
| 103 |
+
if "negative_prompt" in d:
|
| 104 |
+
nd['negative_prompt'] = d['negative_prompt']
|
| 105 |
+
ndata.append(nd)
|
| 106 |
+
return web.json_response(ndata)
|
| 107 |
+
return web.Response(status=400)
|
| 108 |
+
|
| 109 |
+
# get style preview image
|
| 110 |
+
@PromptServer.instance.routes.get("/easyuse/prompt/styles/image")
|
| 111 |
+
async def getStylesImage(request):
|
| 112 |
+
styles_name = request.rel_url.query["styles_name"] if "styles_name" in request.rel_url.query else None
|
| 113 |
+
if "name" in request.rel_url.query:
|
| 114 |
+
name = request.rel_url.query["name"]
|
| 115 |
+
if os.path.exists(os.path.join(FOOOCUS_STYLES_DIR, 'samples')):
|
| 116 |
+
file = os.path.join(FOOOCUS_STYLES_DIR, 'samples', name + '.jpg')
|
| 117 |
+
if os.path.isfile(file):
|
| 118 |
+
return web.FileResponse(file)
|
| 119 |
+
elif styles_name == 'fooocus_styles':
|
| 120 |
+
return web.Response(text=FOOOCUS_STYLES_SAMPLES + name + '.jpg')
|
| 121 |
+
elif styles_name == 'fooocus_styles':
|
| 122 |
+
return web.Response(text=FOOOCUS_STYLES_SAMPLES + name + '.jpg')
|
| 123 |
+
return web.Response(status=400)
|
| 124 |
+
|
| 125 |
+
# get models lists
|
| 126 |
+
@PromptServer.instance.routes.get("/easyuse/models/list")
|
| 127 |
+
async def getModelsList(request):
|
| 128 |
+
if "type" in request.rel_url.query:
|
| 129 |
+
type = request.rel_url.query["type"]
|
| 130 |
+
if type not in ['checkpoints', 'loras']:
|
| 131 |
+
return web.Response(status=400)
|
| 132 |
+
manager = easyModelManager()
|
| 133 |
+
return web.json_response(manager.get_model_lists(type))
|
| 134 |
+
else:
|
| 135 |
+
return web.Response(status=400)
|
| 136 |
+
|
| 137 |
+
# get models thumbnails
|
| 138 |
+
@PromptServer.instance.routes.get("/easyuse/models/thumbnail")
|
| 139 |
+
async def getModelsThumbnail(request):
|
| 140 |
+
limit = 500
|
| 141 |
+
if "limit" in request.rel_url.query:
|
| 142 |
+
limit = request.rel_url.query.get("limit")
|
| 143 |
+
limit = int(limit)
|
| 144 |
+
checkpoints = folder_paths.get_filename_list("checkpoints_thumb")
|
| 145 |
+
loras = folder_paths.get_filename_list("loras_thumb")
|
| 146 |
+
checkpoints_full = []
|
| 147 |
+
loras_full = []
|
| 148 |
+
if len(checkpoints) + len(loras) >= limit:
|
| 149 |
+
return web.Response(status=400)
|
| 150 |
+
for index, i in enumerate(checkpoints):
|
| 151 |
+
full_path = folder_paths.get_full_path('checkpoints_thumb', str(i))
|
| 152 |
+
if full_path:
|
| 153 |
+
checkpoints_full.append(full_path)
|
| 154 |
+
for index, i in enumerate(loras):
|
| 155 |
+
full_path = folder_paths.get_full_path('loras_thumb', str(i))
|
| 156 |
+
if full_path:
|
| 157 |
+
loras_full.append(full_path)
|
| 158 |
+
return web.json_response(checkpoints_full + loras_full)
|
| 159 |
+
|
| 160 |
+
@PromptServer.instance.routes.post("/easyuse/metadata/notes/{name}")
|
| 161 |
+
async def save_notes(request):
|
| 162 |
+
name = request.match_info["name"]
|
| 163 |
+
pos = name.index("/")
|
| 164 |
+
type = name[0:pos]
|
| 165 |
+
name = name[pos+1:]
|
| 166 |
+
|
| 167 |
+
file_path = None
|
| 168 |
+
if type == "embeddings" or type == "loras":
|
| 169 |
+
name = name.lower()
|
| 170 |
+
files = folder_paths.get_filename_list(type)
|
| 171 |
+
for f in files:
|
| 172 |
+
lower_f = f.lower()
|
| 173 |
+
if lower_f == name:
|
| 174 |
+
file_path = folder_paths.get_full_path(type, f)
|
| 175 |
+
else:
|
| 176 |
+
n = os.path.splitext(f)[0].lower()
|
| 177 |
+
if n == name:
|
| 178 |
+
file_path = folder_paths.get_full_path(type, f)
|
| 179 |
+
|
| 180 |
+
if file_path is not None:
|
| 181 |
+
break
|
| 182 |
+
else:
|
| 183 |
+
file_path = folder_paths.get_full_path(
|
| 184 |
+
type, name)
|
| 185 |
+
if not file_path:
|
| 186 |
+
return web.Response(status=404)
|
| 187 |
+
|
| 188 |
+
file_no_ext = os.path.splitext(file_path)[0]
|
| 189 |
+
info_file = file_no_ext + ".txt"
|
| 190 |
+
with open(info_file, "w") as f:
|
| 191 |
+
f.write(await request.text())
|
| 192 |
+
|
| 193 |
+
return web.Response(status=200)
|
| 194 |
+
|
| 195 |
+
@PromptServer.instance.routes.get("/easyuse/metadata/{name}")
|
| 196 |
+
async def load_metadata(request):
|
| 197 |
+
name = request.match_info["name"]
|
| 198 |
+
pos = name.index("/")
|
| 199 |
+
type = name[0:pos]
|
| 200 |
+
name = name[pos+1:]
|
| 201 |
+
|
| 202 |
+
file_path = None
|
| 203 |
+
if type == "embeddings":
|
| 204 |
+
name = name.lower()
|
| 205 |
+
files = folder_paths.get_filename_list(type)
|
| 206 |
+
for f in files:
|
| 207 |
+
lower_f = f.lower()
|
| 208 |
+
if lower_f == name:
|
| 209 |
+
file_path = folder_paths.get_full_path(type, f)
|
| 210 |
+
else:
|
| 211 |
+
n = os.path.splitext(f)[0].lower()
|
| 212 |
+
if n == name:
|
| 213 |
+
file_path = folder_paths.get_full_path(type, f)
|
| 214 |
+
|
| 215 |
+
if file_path is not None:
|
| 216 |
+
break
|
| 217 |
+
else:
|
| 218 |
+
file_path = folder_paths.get_full_path(type, name)
|
| 219 |
+
if not file_path:
|
| 220 |
+
return web.Response(status=404)
|
| 221 |
+
|
| 222 |
+
try:
|
| 223 |
+
header = getMetadata(file_path)
|
| 224 |
+
header_json = json.loads(header)
|
| 225 |
+
meta = header_json["__metadata__"] if "__metadata__" in header_json else None
|
| 226 |
+
except:
|
| 227 |
+
meta = None
|
| 228 |
+
|
| 229 |
+
if meta is None:
|
| 230 |
+
meta = {}
|
| 231 |
+
|
| 232 |
+
file_no_ext = os.path.splitext(file_path)[0]
|
| 233 |
+
|
| 234 |
+
info_file = file_no_ext + ".txt"
|
| 235 |
+
if os.path.isfile(info_file):
|
| 236 |
+
with open(info_file, "r") as f:
|
| 237 |
+
meta["easyuse.notes"] = f.read()
|
| 238 |
+
|
| 239 |
+
hash_file = file_no_ext + ".sha256"
|
| 240 |
+
if os.path.isfile(hash_file):
|
| 241 |
+
with open(hash_file, "rt") as f:
|
| 242 |
+
meta["easyuse.sha256"] = f.read()
|
| 243 |
+
else:
|
| 244 |
+
with open(file_path, "rb") as f:
|
| 245 |
+
meta["easyuse.sha256"] = hashlib.sha256(f.read()).hexdigest()
|
| 246 |
+
with open(hash_file, "wt") as f:
|
| 247 |
+
f.write(meta["easyuse.sha256"])
|
| 248 |
+
|
| 249 |
+
return web.json_response(meta)
|
| 250 |
+
|
| 251 |
+
@PromptServer.instance.routes.post("/easyuse/save/{name}")
|
| 252 |
+
async def save_preview(request):
|
| 253 |
+
name = request.match_info["name"]
|
| 254 |
+
pos = name.index("/")
|
| 255 |
+
type = name[0:pos]
|
| 256 |
+
name = name[pos+1:]
|
| 257 |
+
|
| 258 |
+
body = await request.json()
|
| 259 |
+
|
| 260 |
+
dir = get_directory_by_type(body.get("type", "output"))
|
| 261 |
+
subfolder = body.get("subfolder", "")
|
| 262 |
+
full_output_folder = os.path.join(dir, os.path.normpath(subfolder))
|
| 263 |
+
|
| 264 |
+
if os.path.commonpath((dir, os.path.abspath(full_output_folder))) != dir:
|
| 265 |
+
return web.Response(status=400)
|
| 266 |
+
|
| 267 |
+
filepath = os.path.join(full_output_folder, body.get("filename", ""))
|
| 268 |
+
image_path = folder_paths.get_full_path(type, name)
|
| 269 |
+
image_path = os.path.splitext(
|
| 270 |
+
image_path)[0] + os.path.splitext(filepath)[1]
|
| 271 |
+
|
| 272 |
+
shutil.copyfile(filepath, image_path)
|
| 273 |
+
|
| 274 |
+
return web.json_response({
|
| 275 |
+
"image": type + "/" + os.path.basename(image_path)
|
| 276 |
+
})
|
| 277 |
+
|
| 278 |
+
@PromptServer.instance.routes.post("/easyuse/model/download")
|
| 279 |
+
async def download_model(request):
|
| 280 |
+
post = await request.post()
|
| 281 |
+
url = post.get("url")
|
| 282 |
+
local_dir = post.get("local_dir")
|
| 283 |
+
if local_dir not in ['checkpoints', 'loras', 'controlnet', 'onnx', 'instantid', 'ipadapter', 'dynamicrafter_models', 'mediapipe', 'rembg', 'layer_model']:
|
| 284 |
+
return web.Response(status=400)
|
| 285 |
+
local_path = os.path.join(folder_paths.models_dir, local_dir)
|
| 286 |
+
try:
|
| 287 |
+
get_local_filepath(url, local_path)
|
| 288 |
+
return web.Response(status=200)
|
| 289 |
+
except:
|
| 290 |
+
return web.Response(status=500)
|
| 291 |
+
|
| 292 |
+
NODE_CLASS_MAPPINGS = {}
|
| 293 |
+
NODE_DISPLAY_NAME_MAPPINGS = {}
|
ComfyUI-Easy-Use/py/bitsandbytes_NF4/__init__.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#credit to comfyanonymous for this module
|
| 2 |
+
#from https://github.com/comfyanonymous/ComfyUI_bitsandbytes_NF4
|
| 3 |
+
import comfy.ops
|
| 4 |
+
import torch
|
| 5 |
+
import folder_paths
|
| 6 |
+
from ..libs.utils import install_package
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
from bitsandbytes.nn.modules import Params4bit, QuantState
|
| 10 |
+
except ImportError:
|
| 11 |
+
Params4bit = torch.nn.Parameter
|
| 12 |
+
raise ImportError("Please install bitsandbytes>=0.43.3")
|
| 13 |
+
|
| 14 |
+
def functional_linear_4bits(x, weight, bias):
|
| 15 |
+
try:
|
| 16 |
+
install_package("bitsandbytes", "0.43.3", True, "0.43.3")
|
| 17 |
+
import bitsandbytes as bnb
|
| 18 |
+
except ImportError:
|
| 19 |
+
raise ImportError("Please install bitsandbytes>=0.43.3")
|
| 20 |
+
|
| 21 |
+
out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
|
| 22 |
+
out = out.to(x)
|
| 23 |
+
return out
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def copy_quant_state(state, device: torch.device = None):
|
| 27 |
+
if state is None:
|
| 28 |
+
return None
|
| 29 |
+
|
| 30 |
+
device = device or state.absmax.device
|
| 31 |
+
|
| 32 |
+
state2 = (
|
| 33 |
+
QuantState(
|
| 34 |
+
absmax=state.state2.absmax.to(device),
|
| 35 |
+
shape=state.state2.shape,
|
| 36 |
+
code=state.state2.code.to(device),
|
| 37 |
+
blocksize=state.state2.blocksize,
|
| 38 |
+
quant_type=state.state2.quant_type,
|
| 39 |
+
dtype=state.state2.dtype,
|
| 40 |
+
)
|
| 41 |
+
if state.nested
|
| 42 |
+
else None
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
return QuantState(
|
| 46 |
+
absmax=state.absmax.to(device),
|
| 47 |
+
shape=state.shape,
|
| 48 |
+
code=state.code.to(device),
|
| 49 |
+
blocksize=state.blocksize,
|
| 50 |
+
quant_type=state.quant_type,
|
| 51 |
+
dtype=state.dtype,
|
| 52 |
+
offset=state.offset.to(device) if state.nested else None,
|
| 53 |
+
state2=state2,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ForgeParams4bit(Params4bit):
|
| 58 |
+
|
| 59 |
+
def to(self, *args, **kwargs):
|
| 60 |
+
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
|
| 61 |
+
if device is not None and device.type == "cuda" and not self.bnb_quantized:
|
| 62 |
+
return self._quantize(device)
|
| 63 |
+
else:
|
| 64 |
+
n = ForgeParams4bit(
|
| 65 |
+
torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
|
| 66 |
+
requires_grad=self.requires_grad,
|
| 67 |
+
quant_state=copy_quant_state(self.quant_state, device),
|
| 68 |
+
blocksize=self.blocksize,
|
| 69 |
+
compress_statistics=self.compress_statistics,
|
| 70 |
+
quant_type=self.quant_type,
|
| 71 |
+
quant_storage=self.quant_storage,
|
| 72 |
+
bnb_quantized=self.bnb_quantized,
|
| 73 |
+
module=self.module
|
| 74 |
+
)
|
| 75 |
+
self.module.quant_state = n.quant_state
|
| 76 |
+
self.data = n.data
|
| 77 |
+
self.quant_state = n.quant_state
|
| 78 |
+
return n
|
| 79 |
+
|
| 80 |
+
class ForgeLoader4Bit(torch.nn.Module):
|
| 81 |
+
def __init__(self, *, device, dtype, quant_type, **kwargs):
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
|
| 84 |
+
self.weight = None
|
| 85 |
+
self.quant_state = None
|
| 86 |
+
self.bias = None
|
| 87 |
+
self.quant_type = quant_type
|
| 88 |
+
|
| 89 |
+
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
| 90 |
+
super()._save_to_state_dict(destination, prefix, keep_vars)
|
| 91 |
+
quant_state = getattr(self.weight, "quant_state", None)
|
| 92 |
+
if quant_state is not None:
|
| 93 |
+
for k, v in quant_state.as_dict(packed=True).items():
|
| 94 |
+
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
|
| 95 |
+
return
|
| 96 |
+
|
| 97 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
|
| 98 |
+
quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
|
| 99 |
+
|
| 100 |
+
if any('bitsandbytes' in k for k in quant_state_keys):
|
| 101 |
+
quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
|
| 102 |
+
|
| 103 |
+
self.weight = ForgeParams4bit().from_prequantized(
|
| 104 |
+
data=state_dict[prefix + 'weight'],
|
| 105 |
+
quantized_stats=quant_state_dict,
|
| 106 |
+
requires_grad=False,
|
| 107 |
+
device=self.dummy.device,
|
| 108 |
+
module=self
|
| 109 |
+
)
|
| 110 |
+
self.quant_state = self.weight.quant_state
|
| 111 |
+
|
| 112 |
+
if prefix + 'bias' in state_dict:
|
| 113 |
+
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
| 114 |
+
|
| 115 |
+
del self.dummy
|
| 116 |
+
elif hasattr(self, 'dummy'):
|
| 117 |
+
if prefix + 'weight' in state_dict:
|
| 118 |
+
self.weight = ForgeParams4bit(
|
| 119 |
+
state_dict[prefix + 'weight'].to(self.dummy),
|
| 120 |
+
requires_grad=False,
|
| 121 |
+
compress_statistics=True,
|
| 122 |
+
quant_type=self.quant_type,
|
| 123 |
+
quant_storage=torch.uint8,
|
| 124 |
+
module=self,
|
| 125 |
+
)
|
| 126 |
+
self.quant_state = self.weight.quant_state
|
| 127 |
+
|
| 128 |
+
if prefix + 'bias' in state_dict:
|
| 129 |
+
self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
|
| 130 |
+
|
| 131 |
+
del self.dummy
|
| 132 |
+
else:
|
| 133 |
+
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
| 134 |
+
|
| 135 |
+
current_device = None
|
| 136 |
+
current_dtype = None
|
| 137 |
+
current_manual_cast_enabled = False
|
| 138 |
+
current_bnb_dtype = None
|
| 139 |
+
|
| 140 |
+
class OPS(comfy.ops.manual_cast):
|
| 141 |
+
class Linear(ForgeLoader4Bit):
|
| 142 |
+
def __init__(self, *args, device=None, dtype=None, **kwargs):
|
| 143 |
+
super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
|
| 144 |
+
self.parameters_manual_cast = current_manual_cast_enabled
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
self.weight.quant_state = self.quant_state
|
| 148 |
+
|
| 149 |
+
if self.bias is not None and self.bias.dtype != x.dtype:
|
| 150 |
+
# Maybe this can also be set to all non-bnb ops since the cost is very low.
|
| 151 |
+
# And it only invokes one time, and most linear does not have bias
|
| 152 |
+
self.bias.data = self.bias.data.to(x.dtype)
|
| 153 |
+
|
| 154 |
+
if not self.parameters_manual_cast:
|
| 155 |
+
return functional_linear_4bits(x, self.weight, self.bias)
|
| 156 |
+
elif not self.weight.bnb_quantized:
|
| 157 |
+
assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
|
| 158 |
+
layer_original_device = self.weight.device
|
| 159 |
+
self.weight = self.weight._quantize(x.device)
|
| 160 |
+
bias = self.bias.to(x.device) if self.bias is not None else None
|
| 161 |
+
out = functional_linear_4bits(x, self.weight, bias)
|
| 162 |
+
self.weight = self.weight.to(layer_original_device)
|
| 163 |
+
return out
|
| 164 |
+
else:
|
| 165 |
+
weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
|
| 166 |
+
with main_stream_worker(weight, bias, signal):
|
| 167 |
+
return functional_linear_4bits(x, weight, bias)
|