JasonSmithSO commited on
Commit
8866644
·
verified ·
1 Parent(s): 0a3dbb2

Upload 578 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +8 -0
  2. ComfyUI-Advanced-ControlNet/LICENSE +674 -0
  3. ComfyUI-Advanced-ControlNet/README.md +202 -0
  4. ComfyUI-Advanced-ControlNet/__init__.py +3 -0
  5. ComfyUI-Advanced-ControlNet/adv_control/control.py +860 -0
  6. ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py +254 -0
  7. ComfyUI-Advanced-ControlNet/adv_control/control_reference.py +833 -0
  8. ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py +949 -0
  9. ComfyUI-Advanced-ControlNet/adv_control/control_svd.py +517 -0
  10. ComfyUI-Advanced-ControlNet/adv_control/logger.py +36 -0
  11. ComfyUI-Advanced-ControlNet/adv_control/nodes.py +235 -0
  12. ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py +71 -0
  13. ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py +461 -0
  14. ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py +67 -0
  15. ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py +90 -0
  16. ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py +163 -0
  17. ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py +224 -0
  18. ComfyUI-Advanced-ControlNet/adv_control/utils.py +915 -0
  19. ComfyUI-Advanced-ControlNet/pyproject.toml +15 -0
  20. ComfyUI-Advanced-ControlNet/requirements.txt +0 -0
  21. ComfyUI-BrushNet/BIG_IMAGE.md +6 -0
  22. ComfyUI-BrushNet/CN.md +39 -0
  23. ComfyUI-BrushNet/LICENSE +201 -0
  24. ComfyUI-BrushNet/PARAMS.md +47 -0
  25. ComfyUI-BrushNet/RAUNET.md +39 -0
  26. ComfyUI-BrushNet/README.md +261 -0
  27. ComfyUI-BrushNet/__init__.py +32 -0
  28. ComfyUI-BrushNet/brushnet/brushnet.json +58 -0
  29. ComfyUI-BrushNet/brushnet/brushnet.py +948 -0
  30. ComfyUI-BrushNet/brushnet/brushnet_ca.py +956 -0
  31. ComfyUI-BrushNet/brushnet/brushnet_xl.json +63 -0
  32. ComfyUI-BrushNet/brushnet/powerpaint.json +57 -0
  33. ComfyUI-BrushNet/brushnet/powerpaint_utils.py +496 -0
  34. ComfyUI-BrushNet/brushnet/unet_2d_blocks.py +0 -0
  35. ComfyUI-BrushNet/brushnet/unet_2d_condition.py +1355 -0
  36. ComfyUI-BrushNet/brushnet_nodes.py +1080 -0
  37. ComfyUI-BrushNet/model_patch.py +134 -0
  38. ComfyUI-BrushNet/raunet_nodes.py +158 -0
  39. ComfyUI-BrushNet/requirements.txt +3 -0
  40. ComfyUI-Easy-Use/LICENSE +674 -0
  41. ComfyUI-Easy-Use/README.ZH_CN.md +459 -0
  42. ComfyUI-Easy-Use/README.en.md +422 -0
  43. ComfyUI-Easy-Use/README.md +448 -0
  44. ComfyUI-Easy-Use/__init__.py +92 -0
  45. ComfyUI-Easy-Use/config.yaml +5 -0
  46. ComfyUI-Easy-Use/install.bat +18 -0
  47. ComfyUI-Easy-Use/prestartup_script.py +37 -0
  48. ComfyUI-Easy-Use/py/__init__.py +0 -0
  49. ComfyUI-Easy-Use/py/api.py +293 -0
  50. ComfyUI-Easy-Use/py/bitsandbytes_NF4/__init__.py +167 -0
.gitattributes CHANGED
@@ -37,3 +37,11 @@ comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text
37
  NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
38
  custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
39
  custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
37
  NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
38
  custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
39
  custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
40
+ ComfyUI-Easy-Use/py/kolors/chatglm/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
41
+ ComfyUI-Easy-Use/resources/OpenSans-Medium.ttf filter=lfs diff=lfs merge=lfs -text
42
+ ComfyUI-KJNodes/docs/images/2024-04-03_20_49_29-ComfyUI.png filter=lfs diff=lfs merge=lfs -text
43
+ ComfyUI-KJNodes/fonts/FreeMono.ttf filter=lfs diff=lfs merge=lfs -text
44
+ ComfyUI-KJNodes/fonts/FreeMonoBoldOblique.otf filter=lfs diff=lfs merge=lfs -text
45
+ ComfyUI-KJNodes/fonts/TTNorms-Black.otf filter=lfs diff=lfs merge=lfs -text
46
+ ComfyUI-Kolors-MZ/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
47
+ ComfyUI-KwaiKolorsWrapper/configs/tokenizer/vocab.txt filter=lfs diff=lfs merge=lfs -text
ComfyUI-Advanced-ControlNet/LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
ComfyUI-Advanced-ControlNet/README.md ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ComfyUI-Advanced-ControlNet
2
+ Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks. The ControlNet nodes here fully support sliding context sampling, like the one used in the [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) nodes. Currently supports ControlNets, T2IAdapters, ControlLoRAs, ControlLLLite, SparseCtrls, SVD-ControlNets, and Reference.
3
+
4
+ Custom weights allow replication of the "My prompt is more important" feature of Auto1111's sd-webui ControlNet extension via Soft Weights, and the "ControlNet is more important" feature can be granularly controlled by changing the uncond_multiplier on the same Soft Weights.
5
+
6
+ ControlNet preprocessors are available through [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux) nodes.
7
+
8
+ ## Features
9
+ - Timestep and latent strength scheduling
10
+ - Attention masks
11
+ - Replicate ***"My prompt is more important"*** feature from sd-webui-controlnet extension via ***Soft Weights***, and allow softness to be tweaked via ***base_multiplier***
12
+ - Replicate ***"ControlNet is more important"*** feature from sd-webui-controlnet extension via ***uncond_multiplier*** on ***Soft Weights***
13
+ - uncond_multiplier=0.0 gives identical results of auto1111's feature, but values between 0.0 and 1.0 can be used without issue to granularly control the setting.
14
+ - ControlNet, T2IAdapter, and ControlLoRA support for sliding context windows
15
+ - ControlLLLite support (requires model_optional to be passed into and out of Apply Advanced ControlNet node)
16
+ - SparseCtrl support
17
+ - SVD-ControlNet support
18
+ - Stable Video Diffusion ControlNets trained by **CiaraRowles**: [Depth](https://huggingface.co/CiaraRowles/temporal-controlnet-depth-svd-v1/tree/main/controlnet), [Lineart](https://huggingface.co/CiaraRowles/temporal-controlnet-lineart-svd-v1/tree/main/controlnet)
19
+ - Reference support
20
+ - Supports ```reference_attn```, ```reference_adain```, and ```refrence_adain+attn``` modes. ```style_fidelity``` and ```ref_weight``` are equivalent to style_fidelity and control_weight in Auto1111, respectively, and strength of the Apply ControlNet is the balance between ref-influenced result and no-ref result. There is also a Reference ControlNet (Finetune) node that allows adjust the style_fidelity, weight, and strength of attn and adain separately.
21
+
22
+ ## Table of Contents:
23
+ - [Scheduling Explanation](#scheduling-explanation)
24
+ - [Nodes](#nodes)
25
+ - [Usage](#usage) (will fill this out soon)
26
+
27
+
28
+ # Scheduling Explanation
29
+
30
+ The two core concepts for scheduling are ***Timestep Keyframes*** and ***Latent Keyframes***.
31
+
32
+ ***Timestep Keyframes*** hold the values that guide the settings for a controlnet, and begin to take effect based on their start_percent, which corresponds to the percentage of the sampling process. They can contain masks for the strengths of each latent, control_net_weights, and latent_keyframes (specific strengths for each latent), all optional.
33
+
34
+ ***Latent Keyframes*** determine the strength of the controlnet for specific latents - all they contain is the batch_index of the latent, and the strength the controlnet should apply for that latent. As a concept, latent keyframes achieve the same affect as a uniform mask with the chosen strength value.
35
+
36
+ ![advcn_image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/e6275264-6c3f-4246-a319-111ee48f4cd9)
37
+
38
+ # Nodes
39
+
40
+ The ControlNet nodes provided here are the ***Apply Advanced ControlNet*** and ***Load Advanced ControlNet Model*** (or diff) nodes. The vanilla ControlNet nodes are also compatible, and can be used almost interchangeably - the only difference is that **at least one of these nodes must be used** for Advanced versions of ControlNets to be used (important for sliding context sampling, like with AnimateDiff-Evolved).
41
+
42
+ Key:
43
+ - 🟩 - required inputs
44
+ - 🟨 - optional inputs
45
+ - 🟦 - start as widgets, can be converted to inputs
46
+ - 🟥 - optional input/output, but not recommended to use unless needed
47
+ - 🟪 - output
48
+
49
+ ## Apply Advanced ControlNet
50
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/dc541d41-70df-4a71-b832-efa65af98f06)
51
+
52
+ Same functionality as the vanilla Apply Advanced ControlNet (Advanced) node, except with Advanced ControlNet features added to it. Automatically converts any ControlNet from ControlNet loaders into Advanced versions.
53
+
54
+ ### Inputs
55
+ - 🟩***positive***: conditioning (positive).
56
+ - 🟩***negative***: conditioning (negative).
57
+ - 🟩***control_net***: loaded controlnet; will be converted to Advanced version automatically by this node, if it's a supported type.
58
+ - 🟩***image***: images to guide controlnets - if the loaded controlnet requires it, they must preprocessed images. If one image provided, will be used for all latents. If more images provided, will use each image separately for each latent. If not enough images to meet latent count, will repeat the images from the beginning to match vanilla ControlNet functionality.
59
+ - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as image input, if you provide more than one mask, each can apply to a different latent.
60
+ - 🟨***timestep_kf***: timestep keyframes to guide controlnet effect throughout sampling steps.
61
+ - 🟨***latent_kf_override***: override for latent keyframes, useful if no other features from timestep keyframes is needed. *NOTE: this latent keyframe will be applied to ALL timesteps, regardless if there are other latent keyframes attached to connected timestep keyframes.*
62
+ - 🟨***weights_override***: override for weights, useful if no other features from timestep keyframes is needed. *NOTE: this weight will be applied to ALL timesteps, regardless if there are other weights attached to connected timestep keyframes.*
63
+ - 🟦***strength***: strength of controlnet; 1.0 is full strength, 0.0 is no effect at all.
64
+ - 🟦***start_percent***: sampling step percentage at which controlnet should start to be applied - no matter what start_percent is set on timestep keyframes, they won't take effect until this start_percent is reached.
65
+ - 🟦***stop_percent***: sampling step percentage at which controlnet should stop being applied - no matter what start_percent is set on timestep keyframes, they won't take effect once this end_percent is reached.
66
+
67
+ ### Outputs
68
+ - 🟪***positive***: conditioning (positive) with applied controlnets
69
+ - 🟪***negative***: conditioning (negative) with applied controlnets
70
+
71
+ ## Load Advanced ControlNet Model
72
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/4a7f58a9-783d-4da4-bf82-bc9c167e4722)
73
+
74
+ Loads a ControlNet model and converts it into an Advanced version that supports all the features in this repo. When used with **Apply Advanced ControlNet** node, there is no reason to use the timestep_keyframe input on this node - use timestep_kf on the Apply node instead.
75
+
76
+ ### Inputs
77
+ - 🟥***timestep_keyframe***: optional and likely unnecessary input to have ControlNet use selected timestep_keyframes - should not be used unless you need to. Useful if this node is not attached to **Apply Advanced ControlNet** node, but still want to use Timestep Keyframe, or to use TK_SHORTCUT outputs from ControlWeights in the same scenario. Will be overriden by the timestep_kf input on **Apply Advanced ControlNet** node, if one is provided there.
78
+ - 🟨***model***: model to plug into the diff version of the node. Some controlnets are designed for receive the model; if you don't know what this does, you probably don't want tot use the diff version of the node.
79
+
80
+ ### Outputs
81
+ - 🟪***CONTROL_NET***: loaded Advanced ControlNet
82
+
83
+ ## Timestep Keyframe
84
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/404f3cfe-5852-4eed-935b-37e32493d1b5)
85
+
86
+ Scheduling node across timesteps (sampling steps) based on the set start_percent. Chaining Timestep Keyframes allows ControlNet scheduling across sampling steps (percentage-wise), through a timestep keyframe schedule.
87
+
88
+ ### Inputs
89
+ - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
90
+ - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
91
+ - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
92
+ - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
93
+ - 🟦***start_percent***: sampling step percentage at which this Timestep Keyframe qualifies to be used. Acts as the 'key' for the Timestep Keyframe in the timestep keyframe schedule.
94
+ - 🟦***strength***: strength of the controlnet; multiplies the controlnet by this value, basically, applied alongside the strength on the Apply ControlNet node. If set to 0.0 will not have any effect during the duration of this Timestep Keyframe's effect, and will increase sampling speed by not doing any work.
95
+ - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
96
+ - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
97
+ - 🟦***guarantee_steps***: when 1 or greater, even if a Timestep Keyframe's start_percent ahead of this one in the schedule is closer to current sampling percentage, this Timestep Keyframe will still be used for the specified amount of steps before moving on to the next selected Timestep Keyframe in the following step. Whether the Timestep Keyframe is used or not, its inputs will still be accounted for inherit_missing purposes.
98
+
99
+ ### Outputs
100
+ - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
101
+
102
+ ## Timestep Keyframe Interpolation
103
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9789617c-202c-4271-92a2-0909bcf9b108)
104
+
105
+ Allows to create Timestep Keyframe with interpolated strength values in a given percent range. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
106
+
107
+ ### Inputs
108
+ - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
109
+ - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
110
+ - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
111
+ - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
112
+ - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
113
+ - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
114
+ - 🟦***strength_start***: strength of the Timestep Keyframe at start of range.
115
+ - 🟦***strength_end***: strength of the Timestep Keyframe at end of range.
116
+ - 🟦***interpolation***: the method of interpolation.
117
+ - 🟦***intervals***: the amount of keyframes to generate in total - the first will have its start_percent equal to start_percent, the last will have its start_percent equal to end_percent.
118
+ - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
119
+ - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
120
+ - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
121
+
122
+ ### Outputs
123
+ - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
124
+
125
+ ## Timestep Keyframe From List
126
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/9e9c23bf-6f82-4ce7-b4d1-3016fd14707d)
127
+
128
+ Allows to create Timestep Keyframe via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes. (The first generated keyframe will have guarantee_steps=1, rest that follow will have guarantee_steps=0).
129
+
130
+ ### Inputs
131
+ - 🟨***prev_timestep_kf***: used to chain Timestep Keyframes together to create a schedule. The order does not matter - the Timestep Keyframes sort themselves automatically by their start_percent. *Any Timestep Keyframe contained in the prev_timestep_keyframe that contains the same start_percent as the Timestep Keyframe will be overwritten.*
132
+ - 🟨***cn_weights***: weights to apply to controlnet while this Timestep Keyframe is in effect. Must be compatible with the loaded controlnet, or will throw an error explaining what weight types are compatible. If inherit_missing is True, if no control_net_weight is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a weight_override, the weight_override will be used during sampling instead of control_net_weight.*
133
+ - 🟨***latent_keyframe***: latent keyframes to apply to controlnet while this Timestep Keyframe is in effect. If inherit_missing is True, if no latent_keyframe is passed in, will attempt to reuse the last-used weights in the timestep keyframe schedule. *If Apply Advanced ControlNet node has a latent_kf_override, the latent_lf_override will be used during sampling instead of latent_keyframe.*
134
+ - 🟨***mask_optional***: attention masks to apply to controlnets; basically, decides what part of the image the controlnet to apply to (and the relative strength, if the mask is not binary). Same as mask_optional on the Apply Advanced ControlNet node, can apply either one maks to all latents, or individual masks for each latent. If inherit_missing is True, if no mask_optional is passed in, will attempt to reuse the last-used mask_optional in the timestep keyframe schedule. It is NOT overriden by mask_optional on the Apply Advanced ControlNet node; will be used together.
135
+ - 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Timestep Keyframe; first will be assigned to start_percent, last will be assigned to end_percent, and the rest spread linearly between.
136
+ - 🟦***start_percent***: sampling step percentage at which the first generated Timestep Keyframe qualifies to be used.
137
+ - 🟦***end_percent***: sampling step percentage at which the last generated Timestep Keyframe qualifies to be used.
138
+ - 🟦***null_latent_kf_strength***: strength to assign to latents that are unaccounted for in the passed in latent_keyframes. Has no effect if no latent_keyframes are passed in, or no batch_indeces are unaccounted in the latent_keyframes for during sampling.
139
+ - 🟦***inherit_missing***: determines if should reuse values from previous Timestep Keyframes for optional values (control_net_weights, latent_keyframe, and mask_option) that are not included on this TimestepKeyframe. To inherit only specific inputs, use default inputs.
140
+ - 🟦***print_keyframes***: if True, will print the Timestep Keyframes generated by this node for debugging purposes.
141
+
142
+ ### Outputs
143
+ - 🟪***TIMESTEP_KF***: the created Timestep Keyframe, that can either be linked to another or into a Timestep Keyframe input.
144
+
145
+ ## Latent Keyframe
146
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7eb2cc4c-255c-4f32-b09b-699f713fada3)
147
+
148
+ A singular Latent Keyframe, selects the strength for a specific batch_index. If batch_index is not present during sampling, will simply have no effect. Can be chained with any other Latent Keyframe-type node to create a latent keyframe schedule.
149
+
150
+ ### Inputs
151
+ - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If a Latent Keyframe contained in prev_latent_keyframes have the same batch_index as this Latent Keyframe, they will take priority over this node's value.*
152
+ - 🟦***batch_index***: index of latent in batch to apply controlnet strength to. Acts as the 'key' for the Latent Keyframe in the latent keyframe schedule.
153
+ - 🟦***strength***: strength of controlnet to apply to the corresponding latent.
154
+
155
+ ### Outputs
156
+ - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
157
+
158
+ ## Latent Keyframe Group
159
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/5ce3b795-f5fc-4dc3-ae30-a4c7f87e278c)
160
+
161
+ Allows to create Latent Keyframes via individual indeces or python-style ranges.
162
+
163
+ ### Inputs
164
+ - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
165
+ - 🟨***latent_optional***: the latents expected to be passed in for sampling; only required if you wish to use negative indeces (will be automatically converted to real values).
166
+ - 🟦***index_strengths***: string list of indeces or python-style ranges of indeces to assign strengths to. If latent_optional is passed in, can contain negative indeces or ranges that contain negative numbers, python-style. The different indeces must be comma separated. Individual latents can be specified by ```batch_index=strength```, like ```0=0.9```. Ranges can be specified by ```start_index_inclusive:end_index_exclusive=strength```, like ```0:8=strength```. Negative indeces are possible when latents_optional has an input, with a string such as ```0,-4=0.25```.
167
+ - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
168
+
169
+ ### Outputs
170
+ - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
171
+
172
+ ## Latent Keyframe Interpolation
173
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/7986c737-83b9-46bc-aab0-ae4c368df446)
174
+
175
+ Allows to create Latent Keyframes with interpolated values in a range.
176
+
177
+ ### Inputs
178
+ - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
179
+ - 🟦***batch_index_from***: starting batch_index of range, included.
180
+ - 🟦***batch_index_to***: end batch_index of range, excluded (python-style range).
181
+ - 🟦***strength_from***: starting strength of interpolation.
182
+ - 🟦***strength_to***: end strength of interpolation.
183
+ - 🟦***interpolation***: the method of interpolation.
184
+ - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
185
+
186
+ ### Outputs
187
+ - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
188
+
189
+ ## Latent Keyframe From List
190
+ ![image](https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet/assets/7365912/6cec701f-6183-4aeb-af5c-cac76f5591b7)
191
+
192
+ Allows to create Latent Keyframes via a list of floats, such as with Batch Value Schedule from [ComfyUI_FizzNodes](https://github.com/FizzleDorf/ComfyUI_FizzNodes) nodes.
193
+
194
+ ### Inputs
195
+ - 🟨***prev_latent_kf***: used to chain Latent Keyframes together to create a schedule. *If any Latent Keyframes contained in prev_latent_keyframes have the same batch_index as a this Latent Keyframe, they will take priority over this node's version.*
196
+ - 🟩***float_strengths***: a list of floats, that will correspond to the strength of each Latent Keyframe; the batch_index is the index of each float value in the list.
197
+ - 🟦***print_keyframes***: if True, will print the Latent Keyframes generated by this node for debugging purposes.
198
+
199
+ ### Outputs
200
+ - 🟪***LATENT_KF***: the created Latent Keyframe, that can either be linked to another or into a Latent Keyframe input.
201
+
202
+ # There are more nodes to document and show usage - will add this soon! TODO
ComfyUI-Advanced-ControlNet/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .adv_control.nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
2
+
3
+ __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
ComfyUI-Advanced-ControlNet/adv_control/control.py ADDED
@@ -0,0 +1,860 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Union
2
+ from torch import Tensor
3
+ import torch
4
+ import os
5
+
6
+ import comfy.utils
7
+ import comfy.model_management
8
+ import comfy.model_detection
9
+ import comfy.controlnet as comfy_cn
10
+ from comfy.controlnet import ControlBase, ControlNet, ControlLora, T2IAdapter, broadcast_image_to
11
+ from comfy.model_patcher import ModelPatcher
12
+
13
+ from .control_sparsectrl import SparseModelPatcher, SparseControlNet, SparseCtrlMotionWrapper, SparseMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
14
+ from .control_lllite import LLLiteModule, LLLitePatch
15
+ from .control_svd import svd_unet_config_from_diffusers_unet, SVDControlNet, svd_unet_to_diffusers
16
+ from .utils import (AdvancedControlBase, TimestepKeyframeGroup, LatentKeyframeGroup, ControlWeightType, ControlWeights, WeightTypeException,
17
+ manual_cast_clean_groupnorm, disable_weight_init_clean_groupnorm, prepare_mask_batch, get_properly_arranged_t2i_weights, load_torch_file_with_dict_factory)
18
+ from .logger import logger
19
+
20
+
21
+ class ControlNetAdvanced(ControlNet, AdvancedControlBase):
22
+ def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
23
+ super().__init__(control_model=control_model, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
24
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controlnet())
25
+
26
+ def get_universal_weights(self) -> ControlWeights:
27
+ raw_weights = [(self.weights.base_multiplier ** float(12 - i)) for i in range(13)]
28
+ return self.weights.copy_with_new_weights(raw_weights)
29
+
30
+ def get_control_advanced(self, x_noisy, t, cond, batched_number):
31
+ # perform special version of get_control that supports sliding context and masks
32
+ return self.sliding_get_control(x_noisy, t, cond, batched_number)
33
+
34
+ def sliding_get_control(self, x_noisy: Tensor, t, cond, batched_number):
35
+ control_prev = None
36
+ if self.previous_controlnet is not None:
37
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
38
+
39
+ if self.timestep_range is not None:
40
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
41
+ if control_prev is not None:
42
+ return control_prev
43
+ else:
44
+ return None
45
+
46
+ dtype = self.control_model.dtype
47
+ if self.manual_cast_dtype is not None:
48
+ dtype = self.manual_cast_dtype
49
+
50
+ output_dtype = x_noisy.dtype
51
+ # make cond_hint appropriate dimensions
52
+ # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
53
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
54
+ if self.cond_hint is not None:
55
+ del self.cond_hint
56
+ self.cond_hint = None
57
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
58
+ if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
59
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
60
+ else:
61
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
62
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
63
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
64
+
65
+ # prepare mask_cond_hint
66
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
67
+
68
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
69
+ # uses 'y' in new ComfyUI update
70
+ y = cond.get('y', None)
71
+ if y is None: # TODO: remove this in the future since no longer used by newest ComfyUI
72
+ y = cond.get('c_adm', None)
73
+ if y is not None:
74
+ y = y.to(dtype)
75
+ timestep = self.model_sampling_current.timestep(t)
76
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
77
+
78
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
79
+ return self.control_merge(None, control, control_prev, output_dtype)
80
+
81
+ def copy(self):
82
+ c = ControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
83
+ self.copy_to(c)
84
+ self.copy_to_advanced(c)
85
+ return c
86
+
87
+ @staticmethod
88
+ def from_vanilla(v: ControlNet, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlNetAdvanced':
89
+ return ControlNetAdvanced(control_model=v.control_model, timestep_keyframes=timestep_keyframe,
90
+ global_average_pooling=v.global_average_pooling, device=v.device, load_device=v.load_device, manual_cast_dtype=v.manual_cast_dtype)
91
+
92
+
93
+ class T2IAdapterAdvanced(T2IAdapter, AdvancedControlBase):
94
+ def __init__(self, t2i_model, timestep_keyframes: TimestepKeyframeGroup, channels_in, compression_ratio=8, upscale_algorithm="nearest_exact", device=None):
95
+ super().__init__(t2i_model=t2i_model, channels_in=channels_in, compression_ratio=compression_ratio, upscale_algorithm=upscale_algorithm, device=device)
96
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.t2iadapter())
97
+
98
+ def control_merge_inject(self, control_input, control_output, control_prev, output_dtype):
99
+ # if has uncond multiplier, need to make sure control shapes are the same batch size as expected
100
+ if self.weights.has_uncond_multiplier:
101
+ if control_input is not None:
102
+ for i in range(len(control_input)):
103
+ x = control_input[i]
104
+ if x is not None:
105
+ if x.size(0) < self.batch_size:
106
+ control_input[i] = x.repeat(self.batched_number, 1, 1, 1)[:self.batch_size]
107
+ if control_output is not None:
108
+ for i in range(len(control_output)):
109
+ x = control_output[i]
110
+ if x is not None:
111
+ if x.size(0) < self.batch_size:
112
+ control_output[i] = x.repeat(self.batched_number, 1, 1, 1)[:self.batch_size]
113
+ return AdvancedControlBase.control_merge_inject(self, control_input, control_output, control_prev, output_dtype)
114
+
115
+ def get_universal_weights(self) -> ControlWeights:
116
+ raw_weights = [(self.weights.base_multiplier ** float(7 - i)) for i in range(8)]
117
+ raw_weights = [raw_weights[-8], raw_weights[-3], raw_weights[-2], raw_weights[-1]]
118
+ raw_weights = get_properly_arranged_t2i_weights(raw_weights)
119
+ return self.weights.copy_with_new_weights(raw_weights)
120
+
121
+ def get_calc_pow(self, idx: int, layers: int) -> int:
122
+ # match how T2IAdapterAdvanced deals with universal weights
123
+ indeces = [7 - i for i in range(8)]
124
+ indeces = [indeces[-8], indeces[-3], indeces[-2], indeces[-1]]
125
+ indeces = get_properly_arranged_t2i_weights(indeces)
126
+ return indeces[idx]
127
+
128
+ def get_control_advanced(self, x_noisy, t, cond, batched_number):
129
+ try:
130
+ # if sub indexes present, replace original hint with subsection
131
+ if self.sub_idxs is not None:
132
+ # cond hints
133
+ full_cond_hint_original = self.cond_hint_original
134
+ del self.cond_hint
135
+ self.cond_hint = None
136
+ self.cond_hint_original = full_cond_hint_original[self.sub_idxs]
137
+ # mask hints
138
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
139
+ return super().get_control(x_noisy, t, cond, batched_number)
140
+ finally:
141
+ if self.sub_idxs is not None:
142
+ # replace original cond hint
143
+ self.cond_hint_original = full_cond_hint_original
144
+ del full_cond_hint_original
145
+
146
+ def copy(self):
147
+ c = T2IAdapterAdvanced(self.t2i_model, self.timestep_keyframes, self.channels_in, self.compression_ratio, self.upscale_algorithm)
148
+ self.copy_to(c)
149
+ self.copy_to_advanced(c)
150
+ return c
151
+
152
+ def cleanup(self):
153
+ super().cleanup()
154
+ self.cleanup_advanced()
155
+
156
+ @staticmethod
157
+ def from_vanilla(v: T2IAdapter, timestep_keyframe: TimestepKeyframeGroup=None) -> 'T2IAdapterAdvanced':
158
+ return T2IAdapterAdvanced(t2i_model=v.t2i_model, timestep_keyframes=timestep_keyframe, channels_in=v.channels_in,
159
+ compression_ratio=v.compression_ratio, upscale_algorithm=v.upscale_algorithm, device=v.device)
160
+
161
+
162
+ class ControlLoraAdvanced(ControlLora, AdvancedControlBase):
163
+ def __init__(self, control_weights, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None):
164
+ super().__init__(control_weights=control_weights, global_average_pooling=global_average_pooling, device=device)
165
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllora())
166
+ # use some functions from ControlNetAdvanced
167
+ self.get_control_advanced = ControlNetAdvanced.get_control_advanced.__get__(self, type(self))
168
+ self.sliding_get_control = ControlNetAdvanced.sliding_get_control.__get__(self, type(self))
169
+
170
+ def get_universal_weights(self) -> ControlWeights:
171
+ raw_weights = [(self.weights.base_multiplier ** float(9 - i)) for i in range(10)]
172
+ return self.weights.copy_with_new_weights(raw_weights)
173
+
174
+ def copy(self):
175
+ c = ControlLoraAdvanced(self.control_weights, self.timestep_keyframes, global_average_pooling=self.global_average_pooling)
176
+ self.copy_to(c)
177
+ self.copy_to_advanced(c)
178
+ return c
179
+
180
+ def cleanup(self):
181
+ super().cleanup()
182
+ self.cleanup_advanced()
183
+
184
+ @staticmethod
185
+ def from_vanilla(v: ControlLora, timestep_keyframe: TimestepKeyframeGroup=None) -> 'ControlLoraAdvanced':
186
+ return ControlLoraAdvanced(control_weights=v.control_weights, timestep_keyframes=timestep_keyframe,
187
+ global_average_pooling=v.global_average_pooling, device=v.device)
188
+
189
+
190
+ class SVDControlNetAdvanced(ControlNetAdvanced):
191
+ def __init__(self, control_model: SVDControlNet, timestep_keyframes: TimestepKeyframeGroup, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
192
+ super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
193
+
194
+ def set_cond_hint(self, *args, **kwargs):
195
+ to_return = super().set_cond_hint(*args, **kwargs)
196
+ # cond hint for SVD-ControlNet needs to be scaled between (-1, 1) instead of (0, 1)
197
+ self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
198
+ return to_return
199
+
200
+ def get_control_advanced(self, x_noisy, t, cond, batched_number):
201
+ control_prev = None
202
+ if self.previous_controlnet is not None:
203
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
204
+
205
+ if self.timestep_range is not None:
206
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
207
+ if control_prev is not None:
208
+ return control_prev
209
+ else:
210
+ return None
211
+
212
+ dtype = self.control_model.dtype
213
+ if self.manual_cast_dtype is not None:
214
+ dtype = self.manual_cast_dtype
215
+
216
+ output_dtype = x_noisy.dtype
217
+ # make cond_hint appropriate dimensions
218
+ # TODO: change this to not require cond_hint upscaling every step when self.sub_idxs are present
219
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
220
+ if self.cond_hint is not None:
221
+ del self.cond_hint
222
+ self.cond_hint = None
223
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
224
+ if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
225
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
226
+ else:
227
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
228
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
229
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
230
+
231
+ # prepare mask_cond_hint
232
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
233
+
234
+ context = cond.get('crossattn_controlnet', cond['c_crossattn'])
235
+ # uses 'y' in new ComfyUI update
236
+ y = cond.get('y', None)
237
+ if y is not None:
238
+ y = y.to(dtype)
239
+ timestep = self.model_sampling_current.timestep(t)
240
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
241
+ # concat c_concat if exists (should exist for SVD), doubling channels to 8
242
+ if cond.get('c_concat', None) is not None:
243
+ x_noisy = torch.cat([x_noisy] + [cond['c_concat']], dim=1)
244
+
245
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, cond=cond)
246
+ return self.control_merge(None, control, control_prev, output_dtype)
247
+
248
+ def copy(self):
249
+ c = SVDControlNetAdvanced(self.control_model, self.timestep_keyframes, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
250
+ self.copy_to(c)
251
+ self.copy_to_advanced(c)
252
+ return c
253
+
254
+
255
+ class SparseCtrlAdvanced(ControlNetAdvanced):
256
+ def __init__(self, control_model, timestep_keyframes: TimestepKeyframeGroup, sparse_settings: SparseSettings=None, global_average_pooling=False, device=None, load_device=None, manual_cast_dtype=None):
257
+ super().__init__(control_model=control_model, timestep_keyframes=timestep_keyframes, global_average_pooling=global_average_pooling, device=device, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
258
+ self.control_model_wrapped = SparseModelPatcher(self.control_model, load_device=load_device, offload_device=comfy.model_management.unet_offload_device())
259
+ self.add_compatible_weight(ControlWeightType.SPARSECTRL)
260
+ self.control_model: SparseControlNet = self.control_model # does nothing except help with IDE hints
261
+ self.sparse_settings = sparse_settings if sparse_settings is not None else SparseSettings.default()
262
+ self.latent_format = None
263
+ self.preprocessed = False
264
+
265
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
266
+ # normal ControlNet stuff
267
+ control_prev = None
268
+ if self.previous_controlnet is not None:
269
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
270
+
271
+ if self.timestep_range is not None:
272
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
273
+ if control_prev is not None:
274
+ return control_prev
275
+ else:
276
+ return None
277
+
278
+ dtype = self.control_model.dtype
279
+ if self.manual_cast_dtype is not None:
280
+ dtype = self.manual_cast_dtype
281
+ output_dtype = x_noisy.dtype
282
+ # set actual input length on motion model
283
+ actual_length = x_noisy.size(0)//batched_number
284
+ full_length = actual_length if self.sub_idxs is None else self.full_latent_length
285
+ self.control_model.set_actual_length(actual_length=actual_length, full_length=full_length)
286
+ # prepare cond_hint, if needed
287
+ dim_mult = 1 if self.control_model.use_simplified_conditioning_embedding else 8
288
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2]*dim_mult != self.cond_hint.shape[2] or x_noisy.shape[3]*dim_mult != self.cond_hint.shape[3]:
289
+ # clear out cond_hint and conditioning_mask
290
+ if self.cond_hint is not None:
291
+ del self.cond_hint
292
+ self.cond_hint = None
293
+ # first, figure out which cond idxs are relevant, and where they fit in
294
+ cond_idxs = self.sparse_settings.sparse_method.get_indexes(hint_length=self.cond_hint_original.size(0), full_length=full_length)
295
+
296
+ range_idxs = list(range(full_length)) if self.sub_idxs is None else self.sub_idxs
297
+ hint_idxs = [] # idxs in cond_idxs
298
+ local_idxs = [] # idx to pun in final cond_hint
299
+ for i,cond_idx in enumerate(cond_idxs):
300
+ if cond_idx in range_idxs:
301
+ hint_idxs.append(i)
302
+ local_idxs.append(range_idxs.index(cond_idx))
303
+ # sub_cond_hint now contains the hints relevant to current x_noisy
304
+ sub_cond_hint = self.cond_hint_original[hint_idxs].to(dtype).to(self.device)
305
+
306
+ # scale cond_hints to match noisy input
307
+ if self.control_model.use_simplified_conditioning_embedding:
308
+ # RGB SparseCtrl; the inputs are latents - use bilinear to avoid blocky artifacts
309
+ sub_cond_hint = self.latent_format.process_in(sub_cond_hint) # multiplies by model scale factor
310
+ sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3], x_noisy.shape[2], "nearest-exact", "center").to(dtype).to(self.device)
311
+ else:
312
+ # other SparseCtrl; inputs are typical images
313
+ sub_cond_hint = comfy.utils.common_upscale(sub_cond_hint, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
314
+ # prepare cond_hint (b, c, h ,w)
315
+ cond_shape = list(sub_cond_hint.shape)
316
+ cond_shape[0] = len(range_idxs)
317
+ self.cond_hint = torch.zeros(cond_shape).to(dtype).to(self.device)
318
+ self.cond_hint[local_idxs] = sub_cond_hint[:]
319
+ # prepare cond_mask (b, 1, h, w)
320
+ cond_shape[1] = 1
321
+ cond_mask = torch.zeros(cond_shape).to(dtype).to(self.device)
322
+ cond_mask[local_idxs] = 1.0
323
+ # combine cond_hint and cond_mask into (b, c+1, h, w)
324
+ if not self.sparse_settings.merged:
325
+ self.cond_hint = torch.cat([self.cond_hint, cond_mask], dim=1)
326
+ del sub_cond_hint
327
+ del cond_mask
328
+ # make cond_hint match x_noisy batch
329
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
330
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
331
+
332
+ # prepare mask_cond_hint
333
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, dtype=dtype)
334
+
335
+ context = cond['c_crossattn']
336
+ y = cond.get('y', None)
337
+ if y is not None:
338
+ y = y.to(dtype)
339
+ timestep = self.model_sampling_current.timestep(t)
340
+ x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
341
+
342
+ control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
343
+ return self.control_merge(None, control, control_prev, output_dtype)
344
+
345
+ def pre_run_advanced(self, model, percent_to_timestep_function):
346
+ super().pre_run_advanced(model, percent_to_timestep_function)
347
+ if type(self.cond_hint_original) == PreprocSparseRGBWrapper:
348
+ if not self.control_model.use_simplified_conditioning_embedding:
349
+ raise ValueError("Any model besides RGB SparseCtrl should NOT have its images go through the RGB SparseCtrl preprocessor.")
350
+ self.cond_hint_original = self.cond_hint_original.condhint
351
+ self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond hint
352
+ if self.control_model.motion_wrapper is not None:
353
+ self.control_model.motion_wrapper.reset()
354
+ self.control_model.motion_wrapper.set_strength(self.sparse_settings.motion_strength)
355
+ self.control_model.motion_wrapper.set_scale_multiplier(self.sparse_settings.motion_scale)
356
+
357
+ def cleanup_advanced(self):
358
+ super().cleanup_advanced()
359
+ if self.latent_format is not None:
360
+ del self.latent_format
361
+ self.latent_format = None
362
+
363
+ def copy(self):
364
+ c = SparseCtrlAdvanced(self.control_model, self.timestep_keyframes, self.sparse_settings, self.global_average_pooling, self.device, self.load_device, self.manual_cast_dtype)
365
+ self.copy_to(c)
366
+ self.copy_to_advanced(c)
367
+ return c
368
+
369
+
370
+ class ControlLLLiteAdvanced(ControlBase, AdvancedControlBase):
371
+ # This ControlNet is more of an attention patch than a traditional controlnet
372
+ def __init__(self, patch_attn1: LLLitePatch, patch_attn2: LLLitePatch, timestep_keyframes: TimestepKeyframeGroup, device=None):
373
+ super().__init__(device)
374
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite(), require_model=True)
375
+ self.patch_attn1 = patch_attn1.set_control(self)
376
+ self.patch_attn2 = patch_attn2.set_control(self)
377
+ self.latent_dims_div2 = None
378
+ self.latent_dims_div4 = None
379
+
380
+ def patch_model(self, model: ModelPatcher):
381
+ model.set_model_attn1_patch(self.patch_attn1)
382
+ model.set_model_attn2_patch(self.patch_attn2)
383
+
384
+ def set_cond_hint(self, *args, **kwargs):
385
+ to_return = super().set_cond_hint(*args, **kwargs)
386
+ # cond hint for LLLite needs to be scaled between (-1, 1) instead of (0, 1)
387
+ self.cond_hint_original = self.cond_hint_original * 2.0 - 1.0
388
+ return to_return
389
+
390
+ def pre_run_advanced(self, *args, **kwargs):
391
+ AdvancedControlBase.pre_run_advanced(self, *args, **kwargs)
392
+ #logger.error(f"in cn: {id(self.patch_attn1)},{id(self.patch_attn2)}")
393
+ self.patch_attn1.set_control(self)
394
+ self.patch_attn2.set_control(self)
395
+ #logger.warn(f"in pre_run_advanced: {id(self)}")
396
+
397
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
398
+ # normal ControlNet stuff
399
+ control_prev = None
400
+ if self.previous_controlnet is not None:
401
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
402
+
403
+ if self.timestep_range is not None:
404
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
405
+ return control_prev
406
+
407
+ dtype = x_noisy.dtype
408
+ # prepare cond_hint
409
+ if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
410
+ if self.cond_hint is not None:
411
+ del self.cond_hint
412
+ self.cond_hint = None
413
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
414
+ if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
415
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original[self.sub_idxs], x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
416
+ else:
417
+ self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
418
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
419
+ self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
420
+ # some special logic here compared to other controlnets:
421
+ # * The cond_emb in attn patches will divide latent dims by 2 or 4, integer
422
+ # * Due to this loss, the cond_emb will become smaller than x input if latent dims are not divisble by 2 or 4
423
+ divisible_by_2_h = x_noisy.shape[2]%2==0
424
+ divisible_by_2_w = x_noisy.shape[3]%2==0
425
+ if not (divisible_by_2_h and divisible_by_2_w):
426
+ #logger.warn(f"{x_noisy.shape} not divisible by 2!")
427
+ new_h = (x_noisy.shape[2]//2)*2
428
+ new_w = (x_noisy.shape[3]//2)*2
429
+ if not divisible_by_2_h:
430
+ new_h += 2
431
+ if not divisible_by_2_w:
432
+ new_w += 2
433
+ self.latent_dims_div2 = (new_h, new_w)
434
+ divisible_by_4_h = x_noisy.shape[2]%4==0
435
+ divisible_by_4_w = x_noisy.shape[3]%4==0
436
+ if not (divisible_by_4_h and divisible_by_4_w):
437
+ #logger.warn(f"{x_noisy.shape} not divisible by 4!")
438
+ new_h = (x_noisy.shape[2]//4)*4
439
+ new_w = (x_noisy.shape[3]//4)*4
440
+ if not divisible_by_4_h:
441
+ new_h += 4
442
+ if not divisible_by_4_w:
443
+ new_w += 4
444
+ self.latent_dims_div4 = (new_h, new_w)
445
+ # prepare mask
446
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number)
447
+ # done preparing; model patches will take care of everything now.
448
+ # return normal controlnet stuff
449
+ return control_prev
450
+
451
+ def cleanup_advanced(self):
452
+ super().cleanup_advanced()
453
+ self.patch_attn1.cleanup()
454
+ self.patch_attn2.cleanup()
455
+ self.latent_dims_div2 = None
456
+ self.latent_dims_div4 = None
457
+
458
+ def copy(self):
459
+ c = ControlLLLiteAdvanced(self.patch_attn1, self.patch_attn2, self.timestep_keyframes)
460
+ self.copy_to(c)
461
+ self.copy_to_advanced(c)
462
+ return c
463
+
464
+ # deepcopy needs to properly keep track of objects to work between model.clone calls!
465
+ # def __deepcopy__(self, *args, **kwargs):
466
+ # self.cleanup_advanced()
467
+ # return self
468
+
469
+ # def get_models(self):
470
+ # # get_models is called once at the start of every KSampler run - use to reset already_patched status
471
+ # out = super().get_models()
472
+ # logger.error(f"in get_models! {id(self)}")
473
+ # return out
474
+
475
+
476
+ def load_controlnet(ckpt_path, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
477
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
478
+ control = None
479
+ # check if a non-vanilla ControlNet
480
+ controlnet_type = ControlWeightType.DEFAULT
481
+ has_controlnet_key = False
482
+ has_motion_modules_key = False
483
+ has_temporal_res_block_key = False
484
+ for key in controlnet_data:
485
+ # LLLite check
486
+ if "lllite" in key:
487
+ controlnet_type = ControlWeightType.CONTROLLLLITE
488
+ break
489
+ # SparseCtrl check
490
+ elif "motion_modules" in key:
491
+ has_motion_modules_key = True
492
+ elif "controlnet" in key:
493
+ has_controlnet_key = True
494
+ # SVD-ControlNet check
495
+ elif "temporal_res_block" in key:
496
+ has_temporal_res_block_key = True
497
+ if has_controlnet_key and has_motion_modules_key:
498
+ controlnet_type = ControlWeightType.SPARSECTRL
499
+ elif has_controlnet_key and has_temporal_res_block_key:
500
+ controlnet_type = ControlWeightType.SVD_CONTROLNET
501
+
502
+ if controlnet_type != ControlWeightType.DEFAULT:
503
+ if controlnet_type == ControlWeightType.CONTROLLLLITE:
504
+ control = load_controllllite(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
505
+ elif controlnet_type == ControlWeightType.SPARSECTRL:
506
+ control = load_sparsectrl(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe, model=model)
507
+ elif controlnet_type == ControlWeightType.SVD_CONTROLNET:
508
+ control = load_svdcontrolnet(ckpt_path, controlnet_data=controlnet_data, timestep_keyframe=timestep_keyframe)
509
+ #raise Exception(f"SVD-ControlNet is not supported yet!")
510
+ #control = comfy_cn.load_controlnet(ckpt_path, model=model)
511
+ # otherwise, load vanilla ControlNet
512
+ else:
513
+ try:
514
+ # hacky way of getting load_torch_file in load_controlnet to use already-present controlnet_data and not redo loading
515
+ orig_load_torch_file = comfy.utils.load_torch_file
516
+ comfy.utils.load_torch_file = load_torch_file_with_dict_factory(controlnet_data, orig_load_torch_file)
517
+ control = comfy_cn.load_controlnet(ckpt_path, model=model)
518
+ finally:
519
+ comfy.utils.load_torch_file = orig_load_torch_file
520
+ return convert_to_advanced(control, timestep_keyframe=timestep_keyframe)
521
+
522
+
523
+ def convert_to_advanced(control, timestep_keyframe: TimestepKeyframeGroup=None):
524
+ # if already advanced, leave it be
525
+ if is_advanced_controlnet(control):
526
+ return control
527
+ # if exactly ControlNet returned, transform it into ControlNetAdvanced
528
+ if type(control) == ControlNet:
529
+ return ControlNetAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
530
+ # if exactly ControlLora returned, transform it into ControlLoraAdvanced
531
+ elif type(control) == ControlLora:
532
+ return ControlLoraAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
533
+ # if T2IAdapter returned, transform it into T2IAdapterAdvanced
534
+ elif isinstance(control, T2IAdapter):
535
+ return T2IAdapterAdvanced.from_vanilla(v=control, timestep_keyframe=timestep_keyframe)
536
+ # otherwise, leave it be - might be something I am not supporting yet
537
+ return control
538
+
539
+
540
+ def is_advanced_controlnet(input_object):
541
+ return hasattr(input_object, "sub_idxs")
542
+
543
+
544
+ def load_sparsectrl(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, sparse_settings=SparseSettings.default(), model=None) -> SparseCtrlAdvanced:
545
+ if controlnet_data is None:
546
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
547
+ # first, separate out motion part from normal controlnet part and attempt to load that portion
548
+ motion_data = {}
549
+ for key in list(controlnet_data.keys()):
550
+ if "temporal" in key:
551
+ motion_data[key] = controlnet_data.pop(key)
552
+ if len(motion_data) == 0:
553
+ raise ValueError(f"No motion-related keys in '{ckpt_path}'; not a valid SparseCtrl model!")
554
+ motion_wrapper: SparseCtrlMotionWrapper = SparseCtrlMotionWrapper(motion_data).to(comfy.model_management.unet_dtype())
555
+ missing, unexpected = motion_wrapper.load_state_dict(motion_data)
556
+ if len(missing) > 0 or len(unexpected) > 0:
557
+ logger.info(f"SparseCtrlMotionWrapper: {missing}, {unexpected}")
558
+
559
+ # now, load as if it was a normal controlnet - mostly copied from comfy load_controlnet function
560
+ controlnet_config = None
561
+ is_diffusers = False
562
+ use_simplified_conditioning_embedding = False
563
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data:
564
+ is_diffusers = True
565
+ if "controlnet_cond_embedding.weight" in controlnet_data:
566
+ is_diffusers = True
567
+ use_simplified_conditioning_embedding = True
568
+ if is_diffusers: #diffusers format
569
+ unet_dtype = comfy.model_management.unet_dtype()
570
+ controlnet_config = comfy.model_detection.unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
571
+ diffusers_keys = comfy.utils.unet_to_diffusers(controlnet_config)
572
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
573
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
574
+
575
+ count = 0
576
+ loop = True
577
+ while loop:
578
+ suffix = [".weight", ".bias"]
579
+ for s in suffix:
580
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
581
+ k_out = "zero_convs.{}.0{}".format(count, s)
582
+ if k_in not in controlnet_data:
583
+ loop = False
584
+ break
585
+ diffusers_keys[k_in] = k_out
586
+ count += 1
587
+ # normal conditioning embedding
588
+ if not use_simplified_conditioning_embedding:
589
+ count = 0
590
+ loop = True
591
+ while loop:
592
+ suffix = [".weight", ".bias"]
593
+ for s in suffix:
594
+ if count == 0:
595
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
596
+ else:
597
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
598
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
599
+ if k_in not in controlnet_data:
600
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
601
+ loop = False
602
+ diffusers_keys[k_in] = k_out
603
+ count += 1
604
+ # simplified conditioning embedding
605
+ else:
606
+ count = 0
607
+ suffix = [".weight", ".bias"]
608
+ for s in suffix:
609
+ k_in = "controlnet_cond_embedding{}".format(s)
610
+ k_out = "input_hint_block.{}{}".format(count, s)
611
+ diffusers_keys[k_in] = k_out
612
+
613
+ new_sd = {}
614
+ for k in diffusers_keys:
615
+ if k in controlnet_data:
616
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
617
+
618
+ leftover_keys = controlnet_data.keys()
619
+ if len(leftover_keys) > 0:
620
+ logger.info("leftover keys:", leftover_keys)
621
+ controlnet_data = new_sd
622
+
623
+ pth_key = 'control_model.zero_convs.0.0.weight'
624
+ pth = False
625
+ key = 'zero_convs.0.0.weight'
626
+ if pth_key in controlnet_data:
627
+ pth = True
628
+ key = pth_key
629
+ prefix = "control_model."
630
+ elif key in controlnet_data:
631
+ prefix = ""
632
+ else:
633
+ raise ValueError("The provided model is not a valid SparseCtrl model! [ErrorCode: HORSERADISH]")
634
+
635
+ if controlnet_config is None:
636
+ unet_dtype = comfy.model_management.unet_dtype()
637
+ controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
638
+ load_device = comfy.model_management.get_torch_device()
639
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
640
+ if manual_cast_dtype is not None:
641
+ controlnet_config["operations"] = manual_cast_clean_groupnorm
642
+ else:
643
+ controlnet_config["operations"] = disable_weight_init_clean_groupnorm
644
+ controlnet_config.pop("out_channels")
645
+ # get proper hint channels
646
+ if use_simplified_conditioning_embedding:
647
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
648
+ controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
649
+ else:
650
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
651
+ controlnet_config["use_simplified_conditioning_embedding"] = use_simplified_conditioning_embedding
652
+ control_model = SparseControlNet(**controlnet_config)
653
+
654
+ if pth:
655
+ if 'difference' in controlnet_data:
656
+ if model is not None:
657
+ comfy.model_management.load_models_gpu([model])
658
+ model_sd = model.model_state_dict()
659
+ for x in controlnet_data:
660
+ c_m = "control_model."
661
+ if x.startswith(c_m):
662
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
663
+ if sd_key in model_sd:
664
+ cd = controlnet_data[x]
665
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
666
+ else:
667
+ logger.warning("WARNING: Loaded a diff SparseCtrl without a model. It will very likely not work.")
668
+
669
+ class WeightsLoader(torch.nn.Module):
670
+ pass
671
+ w = WeightsLoader()
672
+ w.control_model = control_model
673
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
674
+ else:
675
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
676
+ if len(missing) > 0 or len(unexpected) > 0:
677
+ logger.info(f"SparseCtrl ControlNet: {missing}, {unexpected}")
678
+
679
+ global_average_pooling = False
680
+ filename = os.path.splitext(ckpt_path)[0]
681
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
682
+ global_average_pooling = True
683
+
684
+ # both motion portion and controlnet portions are loaded; bring them together if using motion model
685
+ if sparse_settings.use_motion:
686
+ motion_wrapper.inject(control_model)
687
+
688
+ control = SparseCtrlAdvanced(control_model, timestep_keyframes=timestep_keyframe, sparse_settings=sparse_settings, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
689
+ return control
690
+
691
+
692
+ def load_controllllite(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None):
693
+ if controlnet_data is None:
694
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
695
+ # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
696
+ # first, split weights for each module
697
+ module_weights = {}
698
+ for key, value in controlnet_data.items():
699
+ fragments = key.split(".")
700
+ module_name = fragments[0]
701
+ weight_name = ".".join(fragments[1:])
702
+
703
+ if module_name not in module_weights:
704
+ module_weights[module_name] = {}
705
+ module_weights[module_name][weight_name] = value
706
+
707
+ # next, load each module
708
+ modules = {}
709
+ for module_name, weights in module_weights.items():
710
+ # kohya planned to do something about how these should be chosen, so I'm not touching this
711
+ # since I am not familiar with the logic for this
712
+ if "conditioning1.4.weight" in weights:
713
+ depth = 3
714
+ elif weights["conditioning1.2.weight"].shape[-1] == 4:
715
+ depth = 2
716
+ else:
717
+ depth = 1
718
+
719
+ module = LLLiteModule(
720
+ name=module_name,
721
+ is_conv2d=weights["down.0.weight"].ndim == 4,
722
+ in_dim=weights["down.0.weight"].shape[1],
723
+ depth=depth,
724
+ cond_emb_dim=weights["conditioning1.0.weight"].shape[0] * 2,
725
+ mlp_dim=weights["down.0.weight"].shape[0],
726
+ )
727
+ # load weights into module
728
+ module.load_state_dict(weights)
729
+ modules[module_name] = module
730
+ if len(modules) == 1:
731
+ module.is_first = True
732
+
733
+ #logger.info(f"loaded {ckpt_path} successfully, {len(modules)} modules")
734
+
735
+ patch_attn1 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN1)
736
+ patch_attn2 = LLLitePatch(modules=modules, patch_type=LLLitePatch.ATTN2)
737
+ control = ControlLLLiteAdvanced(patch_attn1=patch_attn1, patch_attn2=patch_attn2, timestep_keyframes=timestep_keyframe)
738
+ return control
739
+
740
+
741
+ def load_svdcontrolnet(ckpt_path: str, controlnet_data: dict[str, Tensor]=None, timestep_keyframe: TimestepKeyframeGroup=None, model=None):
742
+ if controlnet_data is None:
743
+ controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
744
+
745
+ controlnet_config = None
746
+ if "controlnet_cond_embedding.conv_in.weight" in controlnet_data: #diffusers format
747
+ unet_dtype = comfy.model_management.unet_dtype()
748
+ controlnet_config = svd_unet_config_from_diffusers_unet(controlnet_data, unet_dtype)
749
+ diffusers_keys = svd_unet_to_diffusers(controlnet_config)
750
+ diffusers_keys["controlnet_mid_block.weight"] = "middle_block_out.0.weight"
751
+ diffusers_keys["controlnet_mid_block.bias"] = "middle_block_out.0.bias"
752
+
753
+ count = 0
754
+ loop = True
755
+ while loop:
756
+ suffix = [".weight", ".bias"]
757
+ for s in suffix:
758
+ k_in = "controlnet_down_blocks.{}{}".format(count, s)
759
+ k_out = "zero_convs.{}.0{}".format(count, s)
760
+ if k_in not in controlnet_data:
761
+ loop = False
762
+ break
763
+ diffusers_keys[k_in] = k_out
764
+ count += 1
765
+
766
+ count = 0
767
+ loop = True
768
+ while loop:
769
+ suffix = [".weight", ".bias"]
770
+ for s in suffix:
771
+ if count == 0:
772
+ k_in = "controlnet_cond_embedding.conv_in{}".format(s)
773
+ else:
774
+ k_in = "controlnet_cond_embedding.blocks.{}{}".format(count - 1, s)
775
+ k_out = "input_hint_block.{}{}".format(count * 2, s)
776
+ if k_in not in controlnet_data:
777
+ k_in = "controlnet_cond_embedding.conv_out{}".format(s)
778
+ loop = False
779
+ diffusers_keys[k_in] = k_out
780
+ count += 1
781
+
782
+ new_sd = {}
783
+ for k in diffusers_keys:
784
+ if k in controlnet_data:
785
+ new_sd[diffusers_keys[k]] = controlnet_data.pop(k)
786
+
787
+ leftover_keys = controlnet_data.keys()
788
+ if len(leftover_keys) > 0:
789
+ spatial_leftover_keys = []
790
+ temporal_leftover_keys = []
791
+ other_leftover_keys = []
792
+ for key in leftover_keys:
793
+ if "spatial" in key:
794
+ spatial_leftover_keys.append(key)
795
+ elif "temporal" in key:
796
+ temporal_leftover_keys.append(key)
797
+ else:
798
+ other_leftover_keys.append(key)
799
+ logger.warn(f"spatial_leftover_keys ({len(spatial_leftover_keys)}): {spatial_leftover_keys}")
800
+ logger.warn(f"temporal_leftover_keys ({len(temporal_leftover_keys)}): {temporal_leftover_keys}")
801
+ logger.warn(f"other_leftover_keys ({len(other_leftover_keys)}): {other_leftover_keys}")
802
+ #print("leftover keys:", leftover_keys)
803
+ controlnet_data = new_sd
804
+
805
+ pth_key = 'control_model.zero_convs.0.0.weight'
806
+ pth = False
807
+ key = 'zero_convs.0.0.weight'
808
+ if pth_key in controlnet_data:
809
+ pth = True
810
+ key = pth_key
811
+ prefix = "control_model."
812
+ elif key in controlnet_data:
813
+ prefix = ""
814
+ else:
815
+ raise ValueError("The provided model is not a valid SVD-ControlNet model! [ErrorCode: MUSTARD]")
816
+
817
+ if controlnet_config is None:
818
+ unet_dtype = comfy.model_management.unet_dtype()
819
+ controlnet_config = comfy.model_detection.model_config_from_unet(controlnet_data, prefix, unet_dtype, True).unet_config
820
+ load_device = comfy.model_management.get_torch_device()
821
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
822
+ if manual_cast_dtype is not None:
823
+ controlnet_config["operations"] = comfy.ops.manual_cast
824
+ controlnet_config.pop("out_channels")
825
+ controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
826
+ control_model = SVDControlNet(**controlnet_config)
827
+
828
+ if pth:
829
+ if 'difference' in controlnet_data:
830
+ if model is not None:
831
+ comfy.model_management.load_models_gpu([model])
832
+ model_sd = model.model_state_dict()
833
+ for x in controlnet_data:
834
+ c_m = "control_model."
835
+ if x.startswith(c_m):
836
+ sd_key = "diffusion_model.{}".format(x[len(c_m):])
837
+ if sd_key in model_sd:
838
+ cd = controlnet_data[x]
839
+ cd += model_sd[sd_key].type(cd.dtype).to(cd.device)
840
+ else:
841
+ print("WARNING: Loaded a diff controlnet without a model. It will very likely not work.")
842
+
843
+ class WeightsLoader(torch.nn.Module):
844
+ pass
845
+ w = WeightsLoader()
846
+ w.control_model = control_model
847
+ missing, unexpected = w.load_state_dict(controlnet_data, strict=False)
848
+ else:
849
+ missing, unexpected = control_model.load_state_dict(controlnet_data, strict=False)
850
+ if len(missing) > 0 or len(unexpected) > 0:
851
+ logger.info(f"SVD-ControlNet: {missing}, {unexpected}")
852
+
853
+ global_average_pooling = False
854
+ filename = os.path.splitext(ckpt_path)[0]
855
+ if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
856
+ global_average_pooling = True
857
+
858
+ control = SVDControlNetAdvanced(control_model, timestep_keyframes=timestep_keyframe, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
859
+ return control
860
+
ComfyUI-Advanced-ControlNet/adv_control/control_lllite.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # adapted from https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI
2
+ # basically, all the LLLite core code is from there, which I then combined with
3
+ # Advanced-ControlNet features and QoL
4
+ import math
5
+ from typing import Union
6
+ from torch import Tensor
7
+ import torch
8
+ import os
9
+
10
+ import comfy.utils
11
+ from comfy.controlnet import ControlBase
12
+
13
+ from .logger import logger
14
+ from .utils import AdvancedControlBase, deepcopy_with_sharing, prepare_mask_batch
15
+
16
+
17
+ def extra_options_to_module_prefix(extra_options):
18
+ # extra_options = {'transformer_index': 2, 'block_index': 8, 'original_shape': [2, 4, 128, 128], 'block': ('input', 7), 'n_heads': 20, 'dim_head': 64}
19
+
20
+ # block is: [('input', 4), ('input', 5), ('input', 7), ('input', 8), ('middle', 0),
21
+ # ('output', 0), ('output', 1), ('output', 2), ('output', 3), ('output', 4), ('output', 5)]
22
+ # transformer_index is: [0, 1, 2, 3, 4, 5, 6, 7, 8], for each block
23
+ # block_index is: 0-1 or 0-9, depends on the block
24
+ # input 7 and 8, middle has 10 blocks
25
+
26
+ # make module name from extra_options
27
+ block = extra_options["block"]
28
+ block_index = extra_options["block_index"]
29
+ if block[0] == "input":
30
+ module_pfx = f"lllite_unet_input_blocks_{block[1]}_1_transformer_blocks_{block_index}"
31
+ elif block[0] == "middle":
32
+ module_pfx = f"lllite_unet_middle_block_1_transformer_blocks_{block_index}"
33
+ elif block[0] == "output":
34
+ module_pfx = f"lllite_unet_output_blocks_{block[1]}_1_transformer_blocks_{block_index}"
35
+ else:
36
+ raise Exception(f"ControlLLLite: invalid block name '{block[0]}'. Expected 'input', 'middle', or 'output'.")
37
+ return module_pfx
38
+
39
+
40
+ class LLLitePatch:
41
+ ATTN1 = "attn1"
42
+ ATTN2 = "attn2"
43
+ def __init__(self, modules: dict[str, 'LLLiteModule'], patch_type: str, control: Union[AdvancedControlBase, ControlBase]=None):
44
+ self.modules = modules
45
+ self.control = control
46
+ self.patch_type = patch_type
47
+ #logger.error(f"create LLLitePatch: {id(self)},{control}")
48
+
49
+ def __call__(self, q, k, v, extra_options):
50
+ #logger.error(f"in __call__: {id(self)}")
51
+ # determine if have anything to run
52
+ if self.control.timestep_range is not None:
53
+ # it turns out comparing single-value tensors to floats is extremely slow
54
+ # a: Tensor = extra_options["sigmas"][0]
55
+ if self.control.t > self.control.timestep_range[0] or self.control.t < self.control.timestep_range[1]:
56
+ return q, k, v
57
+
58
+ module_pfx = extra_options_to_module_prefix(extra_options)
59
+
60
+ is_attn1 = q.shape[-1] == k.shape[-1] # self attention
61
+ if is_attn1:
62
+ module_pfx = module_pfx + "_attn1"
63
+ else:
64
+ module_pfx = module_pfx + "_attn2"
65
+
66
+ module_pfx_to_q = module_pfx + "_to_q"
67
+ module_pfx_to_k = module_pfx + "_to_k"
68
+ module_pfx_to_v = module_pfx + "_to_v"
69
+
70
+ if module_pfx_to_q in self.modules:
71
+ q = q + self.modules[module_pfx_to_q](q, self.control)
72
+ if module_pfx_to_k in self.modules:
73
+ k = k + self.modules[module_pfx_to_k](k, self.control)
74
+ if module_pfx_to_v in self.modules:
75
+ v = v + self.modules[module_pfx_to_v](v, self.control)
76
+
77
+ return q, k, v
78
+
79
+ def to(self, device):
80
+ #logger.info(f"to... has control? {self.control}")
81
+ for d in self.modules.keys():
82
+ self.modules[d] = self.modules[d].to(device)
83
+ return self
84
+
85
+ def set_control(self, control: Union[AdvancedControlBase, ControlBase]) -> 'LLLitePatch':
86
+ self.control = control
87
+ return self
88
+ #logger.error(f"set control for LLLitePatch: {id(self)}, cn: {id(control)}")
89
+
90
+ def clone_with_control(self, control: AdvancedControlBase):
91
+ #logger.error(f"clone-set control for LLLitePatch: {id(self)},{id(control)}")
92
+ return LLLitePatch(self.modules, self.patch_type, control)
93
+
94
+ def cleanup(self):
95
+ #total_cleaned = 0
96
+ for module in self.modules.values():
97
+ module.cleanup()
98
+ # total_cleaned += 1
99
+ #logger.info(f"cleaned modules: {total_cleaned}, {id(self)}")
100
+ #logger.error(f"cleanup LLLitePatch: {id(self)}")
101
+
102
+ # make sure deepcopy does not copy control, and deepcopied LLLitePatch should be assigned to control
103
+ def __deepcopy__(self, memo):
104
+ self.cleanup()
105
+ to_return: LLLitePatch = deepcopy_with_sharing(self, shared_attribute_names = ['control'], memo=memo)
106
+ #logger.warn(f"patch {id(self)} turned into {id(to_return)}")
107
+ try:
108
+ if self.patch_type == self.ATTN1:
109
+ to_return.control.patch_attn1 = to_return
110
+ elif self.patch_type == self.ATTN2:
111
+ to_return.control.patch_attn2 = to_return
112
+ except Exception:
113
+ pass
114
+ return to_return
115
+
116
+
117
+ # TODO: use comfy.ops to support fp8 properly
118
+ class LLLiteModule(torch.nn.Module):
119
+ def __init__(
120
+ self,
121
+ name: str,
122
+ is_conv2d: bool,
123
+ in_dim: int,
124
+ depth: int,
125
+ cond_emb_dim: int,
126
+ mlp_dim: int,
127
+ ):
128
+ super().__init__()
129
+ self.name = name
130
+ self.is_conv2d = is_conv2d
131
+ self.is_first = False
132
+
133
+ modules = []
134
+ modules.append(torch.nn.Conv2d(3, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0)) # to latent (from VAE) size*2
135
+ if depth == 1:
136
+ modules.append(torch.nn.ReLU(inplace=True))
137
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
138
+ elif depth == 2:
139
+ modules.append(torch.nn.ReLU(inplace=True))
140
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=4, stride=4, padding=0))
141
+ elif depth == 3:
142
+ # kernel size 8 is too large, so set it to 4
143
+ modules.append(torch.nn.ReLU(inplace=True))
144
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim // 2, kernel_size=4, stride=4, padding=0))
145
+ modules.append(torch.nn.ReLU(inplace=True))
146
+ modules.append(torch.nn.Conv2d(cond_emb_dim // 2, cond_emb_dim, kernel_size=2, stride=2, padding=0))
147
+
148
+ self.conditioning1 = torch.nn.Sequential(*modules)
149
+
150
+ if self.is_conv2d:
151
+ self.down = torch.nn.Sequential(
152
+ torch.nn.Conv2d(in_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
153
+ torch.nn.ReLU(inplace=True),
154
+ )
155
+ self.mid = torch.nn.Sequential(
156
+ torch.nn.Conv2d(mlp_dim + cond_emb_dim, mlp_dim, kernel_size=1, stride=1, padding=0),
157
+ torch.nn.ReLU(inplace=True),
158
+ )
159
+ self.up = torch.nn.Sequential(
160
+ torch.nn.Conv2d(mlp_dim, in_dim, kernel_size=1, stride=1, padding=0),
161
+ )
162
+ else:
163
+ self.down = torch.nn.Sequential(
164
+ torch.nn.Linear(in_dim, mlp_dim),
165
+ torch.nn.ReLU(inplace=True),
166
+ )
167
+ self.mid = torch.nn.Sequential(
168
+ torch.nn.Linear(mlp_dim + cond_emb_dim, mlp_dim),
169
+ torch.nn.ReLU(inplace=True),
170
+ )
171
+ self.up = torch.nn.Sequential(
172
+ torch.nn.Linear(mlp_dim, in_dim),
173
+ )
174
+
175
+ self.depth = depth
176
+ self.cond_emb = None
177
+ self.cx_shape = None
178
+ self.prev_batch = 0
179
+ self.prev_sub_idxs = None
180
+
181
+ def cleanup(self):
182
+ del self.cond_emb
183
+ self.cond_emb = None
184
+ self.cx_shape = None
185
+ self.prev_batch = 0
186
+ self.prev_sub_idxs = None
187
+
188
+ def forward(self, x: Tensor, control: Union[AdvancedControlBase, ControlBase]):
189
+ mask = None
190
+ mask_tk = None
191
+ #logger.info(x.shape)
192
+ if self.cond_emb is None or control.sub_idxs != self.prev_sub_idxs or x.shape[0] != self.prev_batch:
193
+ # print(f"cond_emb is None, {self.name}")
194
+ cond_hint = control.cond_hint.to(x.device, dtype=x.dtype)
195
+ if control.latent_dims_div2 is not None and x.shape[-1] != 1280:
196
+ cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div2[0] * 8, control.latent_dims_div2[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
197
+ elif control.latent_dims_div4 is not None and x.shape[-1] == 1280:
198
+ cond_hint = comfy.utils.common_upscale(cond_hint, control.latent_dims_div4[0] * 8, control.latent_dims_div4[1] * 8, 'nearest-exact', "center").to(x.device, dtype=x.dtype)
199
+ cx = self.conditioning1(cond_hint)
200
+ self.cx_shape = cx.shape
201
+ if not self.is_conv2d:
202
+ # reshape / b,c,h,w -> b,h*w,c
203
+ n, c, h, w = cx.shape
204
+ cx = cx.view(n, c, h * w).permute(0, 2, 1)
205
+ self.cond_emb = cx
206
+ # save prev values
207
+ self.prev_batch = x.shape[0]
208
+ self.prev_sub_idxs = control.sub_idxs
209
+
210
+ cx: torch.Tensor = self.cond_emb
211
+ # print(f"forward {self.name}, {cx.shape}, {x.shape}")
212
+
213
+ # TODO: make masks work for conv2d (could not find any ControlLLLites at this time that use them)
214
+ # create masks
215
+ if not self.is_conv2d:
216
+ n, c, h, w = self.cx_shape
217
+ if control.mask_cond_hint is not None:
218
+ mask = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
219
+ mask = mask.view(mask.shape[0], 1, h * w).permute(0, 2, 1)
220
+ if control.tk_mask_cond_hint is not None:
221
+ mask_tk = prepare_mask_batch(control.mask_cond_hint, (1, 1, h, w)).to(cx.dtype)
222
+ mask_tk = mask_tk.view(mask_tk.shape[0], 1, h * w).permute(0, 2, 1)
223
+
224
+ # x in uncond/cond doubles batch size
225
+ if x.shape[0] != cx.shape[0]:
226
+ if self.is_conv2d:
227
+ cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1, 1)
228
+ else:
229
+ # print("x.shape[0] != cx.shape[0]", x.shape[0], cx.shape[0])
230
+ cx = cx.repeat(x.shape[0] // cx.shape[0], 1, 1)
231
+ if mask is not None:
232
+ mask = mask.repeat(x.shape[0] // mask.shape[0], 1, 1)
233
+ if mask_tk is not None:
234
+ mask_tk = mask_tk.repeat(x.shape[0] // mask_tk.shape[0], 1, 1)
235
+
236
+ if mask is None:
237
+ mask = 1.0
238
+ elif mask_tk is not None:
239
+ mask = mask * mask_tk
240
+
241
+ #logger.info(f"cs: {cx.shape}, x: {x.shape}, is_conv2d: {self.is_conv2d}")
242
+ cx = torch.cat([cx, self.down(x)], dim=1 if self.is_conv2d else 2)
243
+ cx = self.mid(cx)
244
+ cx = self.up(cx)
245
+ if control.latent_keyframes is not None:
246
+ cx = cx * control.calc_latent_keyframe_mults(x=cx, batched_number=control.batched_number)
247
+ if control.weights is not None and control.weights.has_uncond_multiplier:
248
+ cond_or_uncond = control.batched_number.cond_or_uncond
249
+ actual_length = cx.size(0) // control.batched_number
250
+ for idx, cond_type in enumerate(cond_or_uncond):
251
+ # if uncond, set to weight's uncond_multiplier
252
+ if cond_type == 1:
253
+ cx[actual_length*idx:actual_length*(idx+1)] *= control.weights.uncond_multiplier
254
+ return cx * mask * control.strength * control._current_timestep_keyframe.strength
ComfyUI-Advanced-ControlNet/adv_control/control_reference.py ADDED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Union
2
+
3
+ import math
4
+ import torch
5
+ from torch import Tensor
6
+
7
+ import comfy.sample
8
+ import comfy.model_patcher
9
+ import comfy.utils
10
+ from comfy.controlnet import ControlBase
11
+ from comfy.model_patcher import ModelPatcher
12
+ from comfy.ldm.modules.attention import BasicTransformerBlock
13
+ from comfy.ldm.modules.diffusionmodules import openaimodel
14
+
15
+ from .logger import logger
16
+ from .utils import (AdvancedControlBase, ControlWeights, TimestepKeyframeGroup, AbstractPreprocWrapper,
17
+ deepcopy_with_sharing, prepare_mask_batch, broadcast_image_to_full)
18
+
19
+
20
+ def refcn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
21
+ def get_refcn(control: ControlBase, order: int=-1):
22
+ ref_set: set[ReferenceAdvanced] = set()
23
+ if control is None:
24
+ return ref_set
25
+ if type(control) == ReferenceAdvanced:
26
+ control.order = order
27
+ order -= 1
28
+ ref_set.add(control)
29
+ ref_set.update(get_refcn(control.previous_controlnet, order=order))
30
+ return ref_set
31
+
32
+ def refcn_sample(model: ModelPatcher, *args, **kwargs):
33
+ # check if positive or negative conds contain ref cn
34
+ positive = args[-3]
35
+ negative = args[-2]
36
+ ref_set = set()
37
+ if positive is not None:
38
+ for cond in positive:
39
+ if "control" in cond[1]:
40
+ ref_set.update(get_refcn(cond[1]["control"]))
41
+ if negative is not None:
42
+ for cond in negative:
43
+ if "control" in cond[1]:
44
+ ref_set.update(get_refcn(cond[1]["control"]))
45
+ # if no ref cn found, do original function immediately
46
+ if len(ref_set) == 0:
47
+ return orig_comfy_sample(model, *args, **kwargs)
48
+ # otherwise, injection time
49
+ try:
50
+ # inject
51
+ # storage for all Reference-related injections
52
+ reference_injections = ReferenceInjections()
53
+
54
+ # first, handle attn module injection
55
+ all_modules = torch_dfs(model.model)
56
+ attn_modules: list[RefBasicTransformerBlock] = []
57
+ for module in all_modules:
58
+ if isinstance(module, BasicTransformerBlock):
59
+ attn_modules.append(module)
60
+ attn_modules = [module for module in all_modules if isinstance(module, BasicTransformerBlock)]
61
+ attn_modules = sorted(attn_modules, key=lambda x: -x.norm1.normalized_shape[0])
62
+ for i, module in enumerate(attn_modules):
63
+ injection_holder = InjectionBasicTransformerBlockHolder(block=module, idx=i)
64
+ injection_holder.attn_weight = float(i) / float(len(attn_modules))
65
+ if hasattr(module, "_forward"): # backward compatibility
66
+ module._forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
67
+ else:
68
+ module.forward = _forward_inject_BasicTransformerBlock.__get__(module, type(module))
69
+ module.injection_holder = injection_holder
70
+ reference_injections.attn_modules.append(module)
71
+ # figure out which module is middle block
72
+ if hasattr(model.model.diffusion_model, "middle_block"):
73
+ mid_modules = torch_dfs(model.model.diffusion_model.middle_block)
74
+ mid_attn_modules: list[RefBasicTransformerBlock] = [module for module in mid_modules if isinstance(module, BasicTransformerBlock)]
75
+ for module in mid_attn_modules:
76
+ module.injection_holder.is_middle = True
77
+
78
+ # next, handle gn module injection (TimestepEmbedSequential)
79
+ # TODO: figure out the logic behind these hardcoded indexes
80
+ if type(model.model).__name__ == "SDXL":
81
+ input_block_indices = [4, 5, 7, 8]
82
+ output_block_indices = [0, 1, 2, 3, 4, 5]
83
+ else:
84
+ input_block_indices = [4, 5, 7, 8, 10, 11]
85
+ output_block_indices = [0, 1, 2, 3, 4, 5, 6, 7]
86
+ if hasattr(model.model.diffusion_model, "middle_block"):
87
+ module = model.model.diffusion_model.middle_block
88
+ injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=0, is_middle=True)
89
+ injection_holder.gn_weight = 0.0
90
+ module.injection_holder = injection_holder
91
+ reference_injections.gn_modules.append(module)
92
+ for w, i in enumerate(input_block_indices):
93
+ module = model.model.diffusion_model.input_blocks[i]
94
+ injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_input=True)
95
+ injection_holder.gn_weight = 1.0 - float(w) / float(len(input_block_indices))
96
+ module.injection_holder = injection_holder
97
+ reference_injections.gn_modules.append(module)
98
+ for w, i in enumerate(output_block_indices):
99
+ module = model.model.diffusion_model.output_blocks[i]
100
+ injection_holder = InjectionTimestepEmbedSequentialHolder(block=module, idx=i, is_output=True)
101
+ injection_holder.gn_weight = float(w) / float(len(output_block_indices))
102
+ module.injection_holder = injection_holder
103
+ reference_injections.gn_modules.append(module)
104
+ # hack gn_module forwards and update weights
105
+ for i, module in enumerate(reference_injections.gn_modules):
106
+ module.injection_holder.gn_weight *= 2
107
+
108
+ # handle diffusion_model forward injection
109
+ reference_injections.diffusion_model_orig_forward = model.model.diffusion_model.forward
110
+ model.model.diffusion_model.forward = factory_forward_inject_UNetModel(reference_injections).__get__(model.model.diffusion_model, type(model.model.diffusion_model))
111
+ # store ordered ref cns in model's transformer options
112
+ orig_model_options = model.model_options
113
+ new_model_options = model.model_options.copy()
114
+ new_model_options["transformer_options"] = model.model_options["transformer_options"].copy()
115
+ ref_list: list[ReferenceAdvanced] = list(ref_set)
116
+ new_model_options["transformer_options"][REF_CONTROL_LIST_ALL] = sorted(ref_list, key=lambda x: x.order)
117
+ model.model_options = new_model_options
118
+ # continue with original function
119
+ return orig_comfy_sample(model, *args, **kwargs)
120
+ finally:
121
+ # cleanup injections
122
+ # restore attn modules
123
+ attn_modules: list[RefBasicTransformerBlock] = reference_injections.attn_modules
124
+ for module in attn_modules:
125
+ module.injection_holder.restore(module)
126
+ module.injection_holder.clean()
127
+ del module.injection_holder
128
+ del attn_modules
129
+ # restore gn modules
130
+ gn_modules: list[RefTimestepEmbedSequential] = reference_injections.gn_modules
131
+ for module in gn_modules:
132
+ module.injection_holder.restore(module)
133
+ module.injection_holder.clean()
134
+ del module.injection_holder
135
+ del gn_modules
136
+ # restore diffusion_model forward function
137
+ model.model.diffusion_model.forward = reference_injections.diffusion_model_orig_forward.__get__(model.model.diffusion_model, type(model.model.diffusion_model))
138
+ # restore model_options
139
+ model.model_options = orig_model_options
140
+ # cleanup
141
+ reference_injections.cleanup()
142
+ return refcn_sample
143
+ # inject sample functions
144
+ comfy.sample.sample = refcn_sample_factory(comfy.sample.sample)
145
+ comfy.sample.sample_custom = refcn_sample_factory(comfy.sample.sample_custom, is_custom=True)
146
+
147
+
148
+ REF_ATTN_CONTROL_LIST = "ref_attn_control_list"
149
+ REF_ADAIN_CONTROL_LIST = "ref_adain_control_list"
150
+ REF_CONTROL_LIST_ALL = "ref_control_list_all"
151
+ REF_CONTROL_INFO = "ref_control_info"
152
+ REF_ATTN_MACHINE_STATE = "ref_attn_machine_state"
153
+ REF_ADAIN_MACHINE_STATE = "ref_adain_machine_state"
154
+ REF_COND_IDXS = "ref_cond_idxs"
155
+ REF_UNCOND_IDXS = "ref_uncond_idxs"
156
+
157
+
158
+ class MachineState:
159
+ WRITE = "write"
160
+ READ = "read"
161
+ STYLEALIGN = "stylealign"
162
+ OFF = "off"
163
+
164
+
165
+ class ReferenceType:
166
+ ATTN = "reference_attn"
167
+ ADAIN = "reference_adain"
168
+ ATTN_ADAIN = "reference_attn+adain"
169
+ STYLE_ALIGN = "StyleAlign"
170
+
171
+ _LIST = [ATTN, ADAIN, ATTN_ADAIN]
172
+ _LIST_ATTN = [ATTN, ATTN_ADAIN]
173
+ _LIST_ADAIN = [ADAIN, ATTN_ADAIN]
174
+
175
+ @classmethod
176
+ def is_attn(cls, ref_type: str):
177
+ return ref_type in cls._LIST_ATTN
178
+
179
+ @classmethod
180
+ def is_adain(cls, ref_type: str):
181
+ return ref_type in cls._LIST_ADAIN
182
+
183
+
184
+ class ReferenceOptions:
185
+ def __init__(self, reference_type: str,
186
+ attn_style_fidelity: float, adain_style_fidelity: float,
187
+ attn_ref_weight: float, adain_ref_weight: float,
188
+ attn_strength: float=1.0, adain_strength: float=1.0,
189
+ ref_with_other_cns: bool=False):
190
+ self.reference_type = reference_type
191
+ # attn
192
+ self.original_attn_style_fidelity = attn_style_fidelity
193
+ self.attn_style_fidelity = attn_style_fidelity
194
+ self.attn_ref_weight = attn_ref_weight
195
+ self.attn_strength = attn_strength
196
+ # adain
197
+ self.original_adain_style_fidelity = adain_style_fidelity
198
+ self.adain_style_fidelity = adain_style_fidelity
199
+ self.adain_ref_weight = adain_ref_weight
200
+ self.adain_strength = adain_strength
201
+ # other
202
+ self.ref_with_other_cns = ref_with_other_cns
203
+
204
+ def clone(self):
205
+ return ReferenceOptions(reference_type=self.reference_type,
206
+ attn_style_fidelity=self.original_attn_style_fidelity, adain_style_fidelity=self.original_adain_style_fidelity,
207
+ attn_ref_weight=self.attn_ref_weight, adain_ref_weight=self.adain_ref_weight,
208
+ attn_strength=self.attn_strength, adain_strength=self.adain_strength,
209
+ ref_with_other_cns=self.ref_with_other_cns)
210
+
211
+ @staticmethod
212
+ def create_combo(reference_type: str, style_fidelity: float, ref_weight: float, ref_with_other_cns: bool=False):
213
+ return ReferenceOptions(reference_type=reference_type,
214
+ attn_style_fidelity=style_fidelity, adain_style_fidelity=style_fidelity,
215
+ attn_ref_weight=ref_weight, adain_ref_weight=ref_weight,
216
+ ref_with_other_cns=ref_with_other_cns)
217
+
218
+
219
+
220
+ class ReferencePreprocWrapper(AbstractPreprocWrapper):
221
+ error_msg = error_msg = "Invalid use of Reference Preprocess output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply Advanced ControlNet node. It cannot be used for anything else that accepts IMAGE input."
222
+ def __init__(self, condhint: Tensor):
223
+ super().__init__(condhint)
224
+
225
+
226
+ class ReferenceAdvanced(ControlBase, AdvancedControlBase):
227
+ CHANNEL_TO_MULT = {320: 1, 640: 2, 1280: 4}
228
+
229
+ def __init__(self, ref_opts: ReferenceOptions, timestep_keyframes: TimestepKeyframeGroup, device=None):
230
+ super().__init__(device)
231
+ AdvancedControlBase.__init__(self, super(), timestep_keyframes=timestep_keyframes, weights_default=ControlWeights.controllllite())
232
+ self.ref_opts = ref_opts
233
+ self.order = 0
234
+ self.latent_format = None
235
+ self.model_sampling_current = None
236
+ self.should_apply_attn_effective_strength = False
237
+ self.should_apply_adain_effective_strength = False
238
+ self.should_apply_effective_masks = False
239
+ self.latent_shape = None
240
+
241
+ def any_attn_strength_to_apply(self):
242
+ return self.should_apply_attn_effective_strength or self.should_apply_effective_masks
243
+
244
+ def any_adain_strength_to_apply(self):
245
+ return self.should_apply_adain_effective_strength or self.should_apply_effective_masks
246
+
247
+ def get_effective_strength(self):
248
+ effective_strength = self.strength
249
+ if self._current_timestep_keyframe is not None:
250
+ effective_strength = effective_strength * self._current_timestep_keyframe.strength
251
+ return effective_strength
252
+
253
+ def get_effective_attn_mask_or_float(self, x: Tensor, channels: int, is_mid: bool):
254
+ if not self.should_apply_effective_masks:
255
+ return self.get_effective_strength() * self.ref_opts.attn_strength
256
+ if is_mid:
257
+ div = 8
258
+ else:
259
+ div = self.CHANNEL_TO_MULT[channels]
260
+ real_mask = torch.ones([self.latent_shape[0], 1, self.latent_shape[2]//div, self.latent_shape[3]//div]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.attn_strength
261
+ self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
262
+ # mask is now shape [b, 1, h ,w]; need to turn into [b, h*w, 1]
263
+ b, c, h, w = real_mask.shape
264
+ real_mask = real_mask.permute(0, 2, 3, 1).reshape(b, h*w, c)
265
+ return real_mask
266
+
267
+ def get_effective_adain_mask_or_float(self, x: Tensor):
268
+ if not self.should_apply_effective_masks:
269
+ return self.get_effective_strength() * self.ref_opts.adain_strength
270
+ b, c, h, w = x.shape
271
+ real_mask = torch.ones([b, 1, h, w]).to(dtype=x.dtype, device=x.device) * self.strength * self.ref_opts.adain_strength
272
+ self.apply_advanced_strengths_and_masks(x=real_mask, batched_number=self.batched_number)
273
+ return real_mask
274
+
275
+ def should_run(self):
276
+ running = super().should_run()
277
+ if not running:
278
+ return running
279
+ attn_run = False
280
+ adain_run = False
281
+ if ReferenceType.is_attn(self.ref_opts.reference_type):
282
+ # attn will run as long as neither weight or strength is zero
283
+ attn_run = not (math.isclose(self.ref_opts.attn_ref_weight, 0.0) or math.isclose(self.ref_opts.attn_strength, 0.0))
284
+ if ReferenceType.is_adain(self.ref_opts.reference_type):
285
+ # adain will run as long as neither weight or strength is zero
286
+ adain_run = not (math.isclose(self.ref_opts.adain_ref_weight, 0.0) or math.isclose(self.ref_opts.adain_strength, 0.0))
287
+ return attn_run or adain_run
288
+
289
+ def pre_run_advanced(self, model, percent_to_timestep_function):
290
+ AdvancedControlBase.pre_run_advanced(self, model, percent_to_timestep_function)
291
+ if type(self.cond_hint_original) == ReferencePreprocWrapper:
292
+ self.cond_hint_original = self.cond_hint_original.condhint
293
+ self.latent_format = model.latent_format # LatentFormat object, used to process_in latent cond_hint
294
+ self.model_sampling_current = model.model_sampling
295
+ # SDXL is more sensitive to style_fidelity according to sd-webui-controlnet comments
296
+ if type(model).__name__ == "SDXL":
297
+ self.ref_opts.attn_style_fidelity = self.ref_opts.original_attn_style_fidelity ** 3.0
298
+ self.ref_opts.adain_style_fidelity = self.ref_opts.original_adain_style_fidelity ** 3.0
299
+ else:
300
+ self.ref_opts.attn_style_fidelity = self.ref_opts.original_attn_style_fidelity
301
+ self.ref_opts.adain_style_fidelity = self.ref_opts.original_adain_style_fidelity
302
+
303
+ def get_control_advanced(self, x_noisy: Tensor, t, cond, batched_number: int):
304
+ # normal ControlNet stuff
305
+ control_prev = None
306
+ if self.previous_controlnet is not None:
307
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
308
+
309
+ if self.timestep_range is not None:
310
+ if t[0] > self.timestep_range[0] or t[0] < self.timestep_range[1]:
311
+ return control_prev
312
+
313
+ dtype = x_noisy.dtype
314
+ # prepare cond_hint - it is a latent, NOT an image
315
+ #if self.sub_idxs is not None or self.cond_hint is None or x_noisy.shape[2] != self.cond_hint.shape[2] or x_noisy.shape[3] != self.cond_hint.shape[3]:
316
+ if self.cond_hint is not None:
317
+ del self.cond_hint
318
+ self.cond_hint = None
319
+ # if self.cond_hint_original length greater or equal to real latent count, subdivide it before scaling
320
+ if self.sub_idxs is not None and self.cond_hint_original.size(0) >= self.full_latent_length:
321
+ self.cond_hint = comfy.utils.common_upscale(
322
+ self.cond_hint_original[self.sub_idxs],
323
+ x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
324
+ else:
325
+ self.cond_hint = comfy.utils.common_upscale(
326
+ self.cond_hint_original,
327
+ x_noisy.shape[3], x_noisy.shape[2], 'nearest-exact', "center").to(dtype).to(self.device)
328
+ if x_noisy.shape[0] != self.cond_hint.shape[0]:
329
+ self.cond_hint = broadcast_image_to_full(self.cond_hint, x_noisy.shape[0], batched_number, except_one=False)
330
+ # noise cond_hint based on sigma (current step)
331
+ self.cond_hint = self.latent_format.process_in(self.cond_hint)
332
+ self.cond_hint = ref_noise_latents(self.cond_hint, sigma=t, noise=None)
333
+ timestep = self.model_sampling_current.timestep(t)
334
+ self.should_apply_attn_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.attn_strength, 1.0))
335
+ self.should_apply_adain_effective_strength = not (math.isclose(self.strength, 1.0) and math.isclose(self._current_timestep_keyframe.strength, 1.0) and math.isclose(self.ref_opts.adain_strength, 1.0))
336
+ # prepare mask - use direct_attn, so the mask dims will match source latents (and be smaller)
337
+ self.prepare_mask_cond_hint(x_noisy=x_noisy, t=t, cond=cond, batched_number=batched_number, direct_attn=True)
338
+ self.should_apply_effective_masks = self.latent_keyframes is not None or self.mask_cond_hint is not None or self.tk_mask_cond_hint is not None
339
+ self.latent_shape = list(x_noisy.shape)
340
+ # done preparing; model patches will take care of everything now.
341
+ # return normal controlnet stuff
342
+ return control_prev
343
+
344
+ def cleanup_advanced(self):
345
+ super().cleanup_advanced()
346
+ del self.latent_format
347
+ self.latent_format = None
348
+ del self.model_sampling_current
349
+ self.model_sampling_current = None
350
+ self.should_apply_attn_effective_strength = False
351
+ self.should_apply_adain_effective_strength = False
352
+ self.should_apply_effective_masks = False
353
+
354
+ def copy(self):
355
+ c = ReferenceAdvanced(self.ref_opts, self.timestep_keyframes)
356
+ c.order = self.order
357
+ self.copy_to(c)
358
+ self.copy_to_advanced(c)
359
+ return c
360
+
361
+ # avoid deepcopy shenanigans by making deepcopy not do anything to the reference
362
+ # TODO: do the bookkeeping to do this in a proper way for all Adv-ControlNets
363
+ def __deepcopy__(self, memo):
364
+ return self
365
+
366
+
367
+ def ref_noise_latents(latents: Tensor, sigma: Tensor, noise: Tensor=None):
368
+ sigma = sigma.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
369
+ alpha_cumprod = 1 / ((sigma * sigma) + 1)
370
+ sqrt_alpha_prod = alpha_cumprod ** 0.5
371
+ sqrt_one_minus_alpha_prod = (1. - alpha_cumprod) ** 0.5
372
+ if noise is None:
373
+ # generator = torch.Generator(device="cuda")
374
+ # generator.manual_seed(0)
375
+ # noise = torch.empty_like(latents).normal_(generator=generator)
376
+ # generator = torch.Generator()
377
+ # generator.manual_seed(0)
378
+ # noise = torch.randn(latents.size(), generator=generator).to(latents.device)
379
+ noise = torch.randn_like(latents).to(latents.device)
380
+ return sqrt_alpha_prod * latents + sqrt_one_minus_alpha_prod * noise
381
+
382
+
383
+ def simple_noise_latents(latents: Tensor, sigma: float, noise: Tensor=None):
384
+ if noise is None:
385
+ noise = torch.rand_like(latents)
386
+ return latents + noise * sigma
387
+
388
+
389
+ class BankStylesBasicTransformerBlock:
390
+ def __init__(self):
391
+ self.bank = []
392
+ self.style_cfgs = []
393
+ self.cn_idx: list[int] = []
394
+
395
+ def get_avg_style_fidelity(self):
396
+ return sum(self.style_cfgs) / float(len(self.style_cfgs))
397
+
398
+ def clean(self):
399
+ del self.bank
400
+ self.bank = []
401
+ del self.style_cfgs
402
+ self.style_cfgs = []
403
+ del self.cn_idx
404
+ self.cn_idx = []
405
+
406
+
407
+ class BankStylesTimestepEmbedSequential:
408
+ def __init__(self):
409
+ self.var_bank = []
410
+ self.mean_bank = []
411
+ self.style_cfgs = []
412
+ self.cn_idx: list[int] = []
413
+
414
+ def get_avg_var_bank(self):
415
+ return sum(self.var_bank) / float(len(self.var_bank))
416
+
417
+ def get_avg_mean_bank(self):
418
+ return sum(self.mean_bank) / float(len(self.mean_bank))
419
+
420
+ def get_avg_style_fidelity(self):
421
+ return sum(self.style_cfgs) / float(len(self.style_cfgs))
422
+
423
+ def clean(self):
424
+ del self.mean_bank
425
+ self.mean_bank = []
426
+ del self.var_bank
427
+ self.var_bank = []
428
+ del self.style_cfgs
429
+ self.style_cfgs = []
430
+ del self.cn_idx
431
+ self.cn_idx = []
432
+
433
+
434
+ class InjectionBasicTransformerBlockHolder:
435
+ def __init__(self, block: BasicTransformerBlock, idx=None):
436
+ if hasattr(block, "_forward"): # backward compatibility
437
+ self.original_forward = block._forward
438
+ else:
439
+ self.original_forward = block.forward
440
+ self.idx = idx
441
+ self.attn_weight = 1.0
442
+ self.is_middle = False
443
+ self.bank_styles = BankStylesBasicTransformerBlock()
444
+
445
+ def restore(self, block: BasicTransformerBlock):
446
+ if hasattr(block, "_forward"): # backward compatibility
447
+ block._forward = self.original_forward
448
+ else:
449
+ block.forward = self.original_forward
450
+
451
+ def clean(self):
452
+ self.bank_styles.clean()
453
+
454
+
455
+ class InjectionTimestepEmbedSequentialHolder:
456
+ def __init__(self, block: openaimodel.TimestepEmbedSequential, idx=None, is_middle=False, is_input=False, is_output=False):
457
+ self.original_forward = block.forward
458
+ self.idx = idx
459
+ self.gn_weight = 1.0
460
+ self.is_middle = is_middle
461
+ self.is_input = is_input
462
+ self.is_output = is_output
463
+ self.bank_styles = BankStylesTimestepEmbedSequential()
464
+
465
+ def restore(self, block: openaimodel.TimestepEmbedSequential):
466
+ block.forward = self.original_forward
467
+
468
+ def clean(self):
469
+ self.bank_styles.clean()
470
+
471
+
472
+ class ReferenceInjections:
473
+ def __init__(self, attn_modules: list['RefBasicTransformerBlock']=None, gn_modules: list['RefTimestepEmbedSequential']=None):
474
+ self.attn_modules = attn_modules if attn_modules else []
475
+ self.gn_modules = gn_modules if gn_modules else []
476
+ self.diffusion_model_orig_forward: Callable = None
477
+
478
+ def clean_module_mem(self):
479
+ for attn_module in self.attn_modules:
480
+ try:
481
+ attn_module.injection_holder.clean()
482
+ except Exception:
483
+ pass
484
+ for gn_module in self.gn_modules:
485
+ try:
486
+ gn_module.injection_holder.clean()
487
+ except Exception:
488
+ pass
489
+
490
+ def cleanup(self):
491
+ self.clean_module_mem()
492
+ del self.attn_modules
493
+ self.attn_modules = []
494
+ del self.gn_modules
495
+ self.gn_modules = []
496
+ self.diffusion_model_orig_forward = None
497
+
498
+
499
+ def factory_forward_inject_UNetModel(reference_injections: ReferenceInjections):
500
+ def forward_inject_UNetModel(self, x: Tensor, *args, **kwargs):
501
+ # get control and transformer_options from kwargs
502
+ real_args = list(args)
503
+ real_kwargs = list(kwargs.keys())
504
+ control = kwargs.get("control", None)
505
+ transformer_options = kwargs.get("transformer_options", None)
506
+ # look for ReferenceAttnPatch objects to get ReferenceAdvanced objects
507
+ ref_controlnets: list[ReferenceAdvanced] = transformer_options[REF_CONTROL_LIST_ALL]
508
+ # discard any controlnets that should not run
509
+ ref_controlnets = [x for x in ref_controlnets if x.should_run()]
510
+ # if nothing related to reference controlnets, do nothing special
511
+ if len(ref_controlnets) == 0:
512
+ return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
513
+ try:
514
+ # assign cond and uncond idxs
515
+ batched_number = len(transformer_options["cond_or_uncond"])
516
+ per_batch = x.shape[0] // batched_number
517
+ indiv_conds = []
518
+ for cond_type in transformer_options["cond_or_uncond"]:
519
+ indiv_conds.extend([cond_type] * per_batch)
520
+ transformer_options[REF_UNCOND_IDXS] = [i for i, x in enumerate(indiv_conds) if x == 1]
521
+ transformer_options[REF_COND_IDXS] = [i for i, x in enumerate(indiv_conds) if x == 0]
522
+ # check which controlnets do which thing
523
+ attn_controlnets = []
524
+ adain_controlnets = []
525
+ for control in ref_controlnets:
526
+ if ReferenceType.is_attn(control.ref_opts.reference_type):
527
+ attn_controlnets.append(control)
528
+ if ReferenceType.is_adain(control.ref_opts.reference_type):
529
+ adain_controlnets.append(control)
530
+ if len(adain_controlnets) > 0:
531
+ # ComfyUI uses forward_timestep_embed with the TimestepEmbedSequential passed into it
532
+ orig_forward_timestep_embed = openaimodel.forward_timestep_embed
533
+ openaimodel.forward_timestep_embed = forward_timestep_embed_ref_inject_factory(orig_forward_timestep_embed)
534
+ # handle running diffusion with ref cond hints
535
+ for control in ref_controlnets:
536
+ if ReferenceType.is_attn(control.ref_opts.reference_type):
537
+ transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.WRITE
538
+ else:
539
+ transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.OFF
540
+ if ReferenceType.is_adain(control.ref_opts.reference_type):
541
+ transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.WRITE
542
+ else:
543
+ transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.OFF
544
+ transformer_options[REF_ATTN_CONTROL_LIST] = [control]
545
+ transformer_options[REF_ADAIN_CONTROL_LIST] = [control]
546
+
547
+ orig_kwargs = kwargs
548
+ if not control.ref_opts.ref_with_other_cns:
549
+ kwargs = kwargs.copy()
550
+ kwargs["control"] = None
551
+ reference_injections.diffusion_model_orig_forward(control.cond_hint.to(dtype=x.dtype).to(device=x.device), *args, **kwargs)
552
+ kwargs = orig_kwargs
553
+ # run diffusion for real now
554
+ transformer_options[REF_ATTN_MACHINE_STATE] = MachineState.READ
555
+ transformer_options[REF_ADAIN_MACHINE_STATE] = MachineState.READ
556
+ transformer_options[REF_ATTN_CONTROL_LIST] = attn_controlnets
557
+ transformer_options[REF_ADAIN_CONTROL_LIST] = adain_controlnets
558
+ return reference_injections.diffusion_model_orig_forward(x, *args, **kwargs)
559
+ finally:
560
+ # make sure banks are cleared no matter what happens - otherwise, RIP VRAM
561
+ reference_injections.clean_module_mem()
562
+ if len(adain_controlnets) > 0:
563
+ openaimodel.forward_timestep_embed = orig_forward_timestep_embed
564
+
565
+ return forward_inject_UNetModel
566
+
567
+
568
+ # dummy class just to help IDE keep track of injected variables
569
+ class RefBasicTransformerBlock(BasicTransformerBlock):
570
+ injection_holder: InjectionBasicTransformerBlockHolder = None
571
+
572
+ def _forward_inject_BasicTransformerBlock(self: RefBasicTransformerBlock, x: Tensor, context: Tensor=None, transformer_options: dict[str]={}):
573
+ extra_options = {}
574
+ block = transformer_options.get("block", None)
575
+ block_index = transformer_options.get("block_index", 0)
576
+ transformer_patches = {}
577
+ transformer_patches_replace = {}
578
+
579
+ for k in transformer_options:
580
+ if k == "patches":
581
+ transformer_patches = transformer_options[k]
582
+ elif k == "patches_replace":
583
+ transformer_patches_replace = transformer_options[k]
584
+ else:
585
+ extra_options[k] = transformer_options[k]
586
+
587
+ extra_options["n_heads"] = self.n_heads
588
+ extra_options["dim_head"] = self.d_head
589
+
590
+ if self.ff_in:
591
+ x_skip = x
592
+ x = self.ff_in(self.norm_in(x))
593
+ if self.is_res:
594
+ x += x_skip
595
+
596
+ n: Tensor = self.norm1(x)
597
+ if self.disable_self_attn:
598
+ context_attn1 = context
599
+ else:
600
+ context_attn1 = None
601
+ value_attn1 = None
602
+
603
+ # Reference CN stuff
604
+ uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
605
+ c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
606
+ # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
607
+ ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_ATTN_CONTROL_LIST, None)
608
+ ref_machine_state: str = transformer_options.get(REF_ATTN_MACHINE_STATE, None)
609
+ # if in WRITE mode, save n and style_fidelity
610
+ if ref_controlnets and ref_machine_state == MachineState.WRITE:
611
+ if ref_controlnets[0].ref_opts.attn_ref_weight > self.injection_holder.attn_weight:
612
+ self.injection_holder.bank_styles.bank.append(n.detach().clone())
613
+ self.injection_holder.bank_styles.style_cfgs.append(ref_controlnets[0].ref_opts.attn_style_fidelity)
614
+ self.injection_holder.bank_styles.cn_idx.append(ref_controlnets[0].order)
615
+
616
+ if "attn1_patch" in transformer_patches:
617
+ patch = transformer_patches["attn1_patch"]
618
+ if context_attn1 is None:
619
+ context_attn1 = n
620
+ value_attn1 = context_attn1
621
+ for p in patch:
622
+ n, context_attn1, value_attn1 = p(n, context_attn1, value_attn1, extra_options)
623
+
624
+ if block is not None:
625
+ transformer_block = (block[0], block[1], block_index)
626
+ else:
627
+ transformer_block = None
628
+ attn1_replace_patch = transformer_patches_replace.get("attn1", {})
629
+ block_attn1 = transformer_block
630
+ if block_attn1 not in attn1_replace_patch:
631
+ block_attn1 = block
632
+
633
+ if block_attn1 in attn1_replace_patch:
634
+ if context_attn1 is None:
635
+ context_attn1 = n
636
+ value_attn1 = n
637
+ n = self.attn1.to_q(n)
638
+ # Reference CN READ - use attn1_replace_patch appropriately
639
+ if ref_machine_state == MachineState.READ and len(self.injection_holder.bank_styles.bank) > 0:
640
+ bank_styles = self.injection_holder.bank_styles
641
+ style_fidelity = bank_styles.get_avg_style_fidelity()
642
+ real_bank = bank_styles.bank.copy()
643
+ cn_idx = 0
644
+ for idx, order in enumerate(bank_styles.cn_idx):
645
+ # make sure matching ref cn is selected
646
+ for i in range(cn_idx, len(ref_controlnets)):
647
+ if ref_controlnets[i].order == order:
648
+ cn_idx = i
649
+ break
650
+ assert order == ref_controlnets[cn_idx].order
651
+ if ref_controlnets[cn_idx].any_attn_strength_to_apply():
652
+ effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
653
+ real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
654
+ n_uc = self.attn1.to_out(attn1_replace_patch[block_attn1](
655
+ n,
656
+ self.attn1.to_k(torch.cat([context_attn1] + real_bank, dim=1)),
657
+ self.attn1.to_v(torch.cat([value_attn1] + real_bank, dim=1)),
658
+ extra_options))
659
+ n_c = n_uc.clone()
660
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
661
+ n_c[uc_idx_mask] = self.attn1.to_out(attn1_replace_patch[block_attn1](
662
+ n[uc_idx_mask],
663
+ self.attn1.to_k(context_attn1[uc_idx_mask]),
664
+ self.attn1.to_v(value_attn1[uc_idx_mask]),
665
+ extra_options))
666
+ n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
667
+ bank_styles.clean()
668
+ else:
669
+ context_attn1 = self.attn1.to_k(context_attn1)
670
+ value_attn1 = self.attn1.to_v(value_attn1)
671
+ n = attn1_replace_patch[block_attn1](n, context_attn1, value_attn1, extra_options)
672
+ n = self.attn1.to_out(n)
673
+ else:
674
+ # Reference CN READ - no attn1_replace_patch
675
+ if ref_machine_state == MachineState.READ and len(self.injection_holder.bank_styles.bank) > 0:
676
+ if context_attn1 is None:
677
+ context_attn1 = n
678
+ bank_styles = self.injection_holder.bank_styles
679
+ style_fidelity = bank_styles.get_avg_style_fidelity()
680
+ real_bank = bank_styles.bank.copy()
681
+ cn_idx = 0
682
+ for idx, order in enumerate(bank_styles.cn_idx):
683
+ # make sure matching ref cn is selected
684
+ for i in range(cn_idx, len(ref_controlnets)):
685
+ if ref_controlnets[i].order == order:
686
+ cn_idx = i
687
+ break
688
+ assert order == ref_controlnets[cn_idx].order
689
+ if ref_controlnets[cn_idx].any_attn_strength_to_apply():
690
+ effective_strength = ref_controlnets[cn_idx].get_effective_attn_mask_or_float(x=n, channels=n.shape[2], is_mid=self.injection_holder.is_middle)
691
+ real_bank[idx] = real_bank[idx] * effective_strength + context_attn1 * (1-effective_strength)
692
+ n_uc: Tensor = self.attn1(
693
+ n,
694
+ context=torch.cat([context_attn1] + real_bank, dim=1),
695
+ value=torch.cat([value_attn1] + real_bank, dim=1) if value_attn1 is not None else value_attn1)
696
+ n_c = n_uc.clone()
697
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
698
+ n_c[uc_idx_mask] = self.attn1(
699
+ n[uc_idx_mask],
700
+ context=context_attn1[uc_idx_mask],
701
+ value=value_attn1[uc_idx_mask] if value_attn1 is not None else value_attn1)
702
+ n = style_fidelity * n_c + (1.0-style_fidelity) * n_uc
703
+ bank_styles.clean()
704
+ else:
705
+ n = self.attn1(n, context=context_attn1, value=value_attn1)
706
+
707
+ if "attn1_output_patch" in transformer_patches:
708
+ patch = transformer_patches["attn1_output_patch"]
709
+ for p in patch:
710
+ n = p(n, extra_options)
711
+
712
+ x += n
713
+ if "middle_patch" in transformer_patches:
714
+ patch = transformer_patches["middle_patch"]
715
+ for p in patch:
716
+ x = p(x, extra_options)
717
+
718
+ if self.attn2 is not None:
719
+ n = self.norm2(x)
720
+ if self.switch_temporal_ca_to_sa:
721
+ context_attn2 = n
722
+ else:
723
+ context_attn2 = context
724
+ value_attn2 = None
725
+ if "attn2_patch" in transformer_patches:
726
+ patch = transformer_patches["attn2_patch"]
727
+ value_attn2 = context_attn2
728
+ for p in patch:
729
+ n, context_attn2, value_attn2 = p(n, context_attn2, value_attn2, extra_options)
730
+
731
+ attn2_replace_patch = transformer_patches_replace.get("attn2", {})
732
+ block_attn2 = transformer_block
733
+ if block_attn2 not in attn2_replace_patch:
734
+ block_attn2 = block
735
+
736
+ if block_attn2 in attn2_replace_patch:
737
+ if value_attn2 is None:
738
+ value_attn2 = context_attn2
739
+ n = self.attn2.to_q(n)
740
+ context_attn2 = self.attn2.to_k(context_attn2)
741
+ value_attn2 = self.attn2.to_v(value_attn2)
742
+ n = attn2_replace_patch[block_attn2](n, context_attn2, value_attn2, extra_options)
743
+ n = self.attn2.to_out(n)
744
+ else:
745
+ n = self.attn2(n, context=context_attn2, value=value_attn2)
746
+
747
+ if "attn2_output_patch" in transformer_patches:
748
+ patch = transformer_patches["attn2_output_patch"]
749
+ for p in patch:
750
+ n = p(n, extra_options)
751
+
752
+ x += n
753
+ if self.is_res:
754
+ x_skip = x
755
+ x = self.ff(self.norm3(x))
756
+ if self.is_res:
757
+ x += x_skip
758
+
759
+ return x
760
+
761
+
762
+ class RefTimestepEmbedSequential(openaimodel.TimestepEmbedSequential):
763
+ injection_holder: InjectionTimestepEmbedSequentialHolder = None
764
+
765
+ def forward_timestep_embed_ref_inject_factory(orig_timestep_embed_inject_factory: Callable):
766
+ def forward_timestep_embed_ref_inject(*args, **kwargs):
767
+ ts: RefTimestepEmbedSequential = args[0]
768
+ if not hasattr(ts, "injection_holder"):
769
+ return orig_timestep_embed_inject_factory(*args, **kwargs)
770
+ eps = 1e-6
771
+ x: Tensor = orig_timestep_embed_inject_factory(*args, **kwargs)
772
+ y: Tensor = None
773
+ transformer_options: dict[str] = args[4]
774
+ # Reference CN stuff
775
+ uc_idx_mask = transformer_options.get(REF_UNCOND_IDXS, [])
776
+ c_idx_mask = transformer_options.get(REF_COND_IDXS, [])
777
+ # WRITE mode will only have one ReferenceAdvanced, other modes will have all ReferenceAdvanced
778
+ ref_controlnets: list[ReferenceAdvanced] = transformer_options.get(REF_ADAIN_CONTROL_LIST, None)
779
+ ref_machine_state: str = transformer_options.get(REF_ADAIN_MACHINE_STATE, None)
780
+
781
+ # if in WRITE mode, save var, mean, and style_cfg
782
+ if ref_machine_state == MachineState.WRITE:
783
+ if ref_controlnets[0].ref_opts.adain_ref_weight > ts.injection_holder.gn_weight:
784
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
785
+ ts.injection_holder.bank_styles.var_bank.append(var)
786
+ ts.injection_holder.bank_styles.mean_bank.append(mean)
787
+ ts.injection_holder.bank_styles.style_cfgs.append(ref_controlnets[0].ref_opts.adain_style_fidelity)
788
+ ts.injection_holder.bank_styles.cn_idx.append(ref_controlnets[0].order)
789
+ # if in READ mode, do math with saved var, mean, and style_cfg
790
+ if ref_machine_state == MachineState.READ:
791
+ if len(ts.injection_holder.bank_styles.var_bank) > 0:
792
+ bank_styles = ts.injection_holder.bank_styles
793
+ var, mean = torch.var_mean(x, dim=(2, 3), keepdim=True, correction=0)
794
+ std = torch.maximum(var, torch.zeros_like(var) + eps) ** 0.5
795
+ y_uc = torch.zeros_like(x)
796
+ cn_idx = 0
797
+ for idx, order in enumerate(bank_styles.cn_idx):
798
+ # make sure matching ref cn is selected
799
+ for i in range(cn_idx, len(ref_controlnets)):
800
+ if ref_controlnets[i].order == order:
801
+ cn_idx = i
802
+ break
803
+ assert order == ref_controlnets[cn_idx].order
804
+ style_fidelity = bank_styles.style_cfgs[idx]
805
+ var_acc = bank_styles.var_bank[idx]
806
+ mean_acc = bank_styles.mean_bank[idx]
807
+ std_acc = torch.maximum(var_acc, torch.zeros_like(var_acc) + eps) ** 0.5
808
+ sub_y_uc = (((x - mean) / std) * std_acc) + mean_acc
809
+ if ref_controlnets[cn_idx].any_adain_strength_to_apply():
810
+ effective_strength = ref_controlnets[cn_idx].get_effective_adain_mask_or_float(x=x)
811
+ sub_y_uc = sub_y_uc * effective_strength + x * (1-effective_strength)
812
+ y_uc += sub_y_uc
813
+ # get average, if more than one
814
+ if len(bank_styles.cn_idx) > 1:
815
+ y_uc /= len(bank_styles.cn_idx)
816
+ y_c = y_uc.clone()
817
+ if len(uc_idx_mask) > 0 and not math.isclose(style_fidelity, 0.0):
818
+ y_c[uc_idx_mask] = x.to(y_c.dtype)[uc_idx_mask]
819
+ y = style_fidelity * y_c + (1.0 - style_fidelity) * y_uc
820
+ ts.injection_holder.bank_styles.clean()
821
+
822
+ if y is None:
823
+ y = x
824
+ return y.to(x.dtype)
825
+
826
+ return forward_timestep_embed_ref_inject
827
+
828
+ # DFS Search for Torch.nn.Module, Written by Lvmin
829
+ def torch_dfs(model: torch.nn.Module):
830
+ result = [model]
831
+ for child in model.children():
832
+ result += torch_dfs(child)
833
+ return result
ComfyUI-Advanced-ControlNet/adv_control/control_sparsectrl.py ADDED
@@ -0,0 +1,949 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #taken from: https://github.com/lllyasviel/ControlNet
2
+ #and modified
3
+ #and then taken from comfy/cldm/cldm.py and modified again
4
+
5
+ from abc import ABC, abstractmethod
6
+ import math
7
+ import numpy as np
8
+ from typing import Iterable, Union
9
+ import torch
10
+ import torch as th
11
+ import torch.nn as nn
12
+ from torch import Tensor
13
+ from einops import rearrange, repeat
14
+
15
+ from comfy.ldm.modules.diffusionmodules.util import (
16
+ zero_module,
17
+ timestep_embedding,
18
+ )
19
+
20
+ from comfy.cli_args import args
21
+ from comfy.cldm.cldm import ControlNet as ControlNetCLDM
22
+ from comfy.ldm.modules.attention import SpatialTransformer
23
+ from comfy.ldm.modules.attention import attention_basic, attention_pytorch, attention_split, attention_sub_quad, default
24
+ from comfy.ldm.modules.attention import FeedForward, SpatialTransformer
25
+ from comfy.ldm.modules.diffusionmodules.openaimodel import TimestepEmbedSequential, ResBlock, Downsample
26
+ from comfy.model_patcher import ModelPatcher
27
+ from comfy.controlnet import broadcast_image_to
28
+ from comfy.utils import repeat_to_batch_size
29
+ import comfy.ops
30
+ import comfy.model_management
31
+
32
+ from .utils import TimestepKeyframeGroup, disable_weight_init_clean_groupnorm, prepare_mask_batch
33
+
34
+
35
+ # until xformers bug is fixed, do not use xformers for VersatileAttention! TODO: change this when fix is out
36
+ # logic for choosing optimized_attention method taken from comfy/ldm/modules/attention.py
37
+ optimized_attention_mm = attention_basic
38
+ if comfy.model_management.xformers_enabled():
39
+ pass
40
+ #optimized_attention_mm = attention_xformers
41
+ if comfy.model_management.pytorch_attention_enabled():
42
+ optimized_attention_mm = attention_pytorch
43
+ else:
44
+ if args.use_split_cross_attention:
45
+ optimized_attention_mm = attention_split
46
+ else:
47
+ optimized_attention_mm = attention_sub_quad
48
+
49
+
50
+ class SparseControlNet(ControlNetCLDM):
51
+ def __init__(self, *args,**kwargs):
52
+ super().__init__(*args, **kwargs)
53
+ hint_channels = kwargs.get("hint_channels")
54
+ operations: disable_weight_init_clean_groupnorm = kwargs.get("operations", disable_weight_init_clean_groupnorm)
55
+ device = kwargs.get("device", None)
56
+ self.use_simplified_conditioning_embedding = kwargs.get("use_simplified_conditioning_embedding", False)
57
+ if self.use_simplified_conditioning_embedding:
58
+ self.input_hint_block = TimestepEmbedSequential(
59
+ zero_module(operations.conv_nd(self.dims, hint_channels, self.model_channels, 3, padding=1, dtype=self.dtype, device=device)),
60
+ )
61
+ self.motion_wrapper: SparseCtrlMotionWrapper = None
62
+
63
+ def set_actual_length(self, actual_length: int, full_length: int):
64
+ if self.motion_wrapper is not None:
65
+ self.motion_wrapper.set_video_length(video_length=actual_length, full_length=full_length)
66
+
67
+ def forward(self, x: Tensor, hint: Tensor, timesteps, context, y=None, **kwargs):
68
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
69
+ emb = self.time_embed(t_emb)
70
+
71
+ # SparseCtrl sets noisy input to zeros
72
+ x = torch.zeros_like(x)
73
+ guided_hint = self.input_hint_block(hint, emb, context)
74
+
75
+ outs = []
76
+
77
+ hs = []
78
+ if self.num_classes is not None:
79
+ assert y.shape[0] == x.shape[0]
80
+ emb = emb + self.label_emb(y)
81
+
82
+ h = x
83
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
84
+ if guided_hint is not None:
85
+ h = module(h, emb, context)
86
+ h += guided_hint
87
+ guided_hint = None
88
+ else:
89
+ h = module(h, emb, context)
90
+ outs.append(zero_conv(h, emb, context))
91
+
92
+ h = self.middle_block(h, emb, context)
93
+ outs.append(self.middle_block_out(h, emb, context))
94
+
95
+ return outs
96
+
97
+
98
+ class SparseModelPatcher(ModelPatcher):
99
+ def __init__(self, *args, **kwargs):
100
+ self.model: SparseControlNet
101
+ super().__init__(*args, **kwargs)
102
+
103
+ def patch_model(self, device_to=None, patch_weights=True):
104
+ if patch_weights:
105
+ patched_model = super().patch_model(device_to)
106
+ else:
107
+ patched_model = super().patch_model(device_to, patch_weights)
108
+ try:
109
+ if self.model.motion_wrapper is not None:
110
+ self.model.motion_wrapper.to(device=device_to)
111
+ except Exception:
112
+ pass
113
+ return patched_model
114
+
115
+ def unpatch_model(self, device_to=None, unpatch_weights=True):
116
+ try:
117
+ if self.model.motion_wrapper is not None:
118
+ self.model.motion_wrapper.to(device=device_to)
119
+ except Exception:
120
+ pass
121
+ if unpatch_weights:
122
+ return super().unpatch_model(device_to)
123
+ else:
124
+ return super().unpatch_model(device_to, unpatch_weights)
125
+
126
+ def clone(self):
127
+ # normal ModelPatcher clone actions
128
+ n = SparseModelPatcher(self.model, self.load_device, self.offload_device, self.size, self.current_device, weight_inplace_update=self.weight_inplace_update)
129
+ n.patches = {}
130
+ for k in self.patches:
131
+ n.patches[k] = self.patches[k][:]
132
+ if hasattr(n, "patches_uuid"):
133
+ self.patches_uuid = n.patches_uuid
134
+
135
+ n.object_patches = self.object_patches.copy()
136
+ n.model_options = copy.deepcopy(self.model_options)
137
+ n.model_keys = self.model_keys
138
+ if hasattr(n, "backup"):
139
+ self.backup = n.backup
140
+ if hasattr(n, "object_patches_backup"):
141
+ self.object_patches_backup = n.object_patches_backup
142
+
143
+
144
+ class PreprocSparseRGBWrapper:
145
+ error_msg = "Invalid use of RGB SparseCtrl output. The output of RGB SparseCtrl preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
146
+ def __init__(self, condhint: Tensor):
147
+ self.condhint = condhint
148
+
149
+ def movedim(self, *args, **kwargs):
150
+ return self
151
+
152
+ def __getattr__(self, *args, **kwargs):
153
+ raise AttributeError(self.error_msg)
154
+
155
+ def __setattr__(self, name, value):
156
+ if name != "condhint":
157
+ raise AttributeError(self.error_msg)
158
+ super().__setattr__(name, value)
159
+
160
+ def __iter__(self, *args, **kwargs):
161
+ raise AttributeError(self.error_msg)
162
+
163
+ def __next__(self, *args, **kwargs):
164
+ raise AttributeError(self.error_msg)
165
+
166
+ def __len__(self, *args, **kwargs):
167
+ raise AttributeError(self.error_msg)
168
+
169
+ def __getitem__(self, *args, **kwargs):
170
+ raise AttributeError(self.error_msg)
171
+
172
+ def __setitem__(self, *args, **kwargs):
173
+ raise AttributeError(self.error_msg)
174
+
175
+
176
+ class SparseSettings:
177
+ def __init__(self, sparse_method: 'SparseMethod', use_motion: bool=True, motion_strength=1.0, motion_scale=1.0, merged=False):
178
+ self.sparse_method = sparse_method
179
+ self.use_motion = use_motion
180
+ self.motion_strength = motion_strength
181
+ self.motion_scale = motion_scale
182
+ self.merged = merged
183
+
184
+ @classmethod
185
+ def default(cls):
186
+ return SparseSettings(sparse_method=SparseSpreadMethod(), use_motion=True)
187
+
188
+
189
+ class SparseMethod(ABC):
190
+ SPREAD = "spread"
191
+ INDEX = "index"
192
+ def __init__(self, method: str):
193
+ self.method = method
194
+
195
+ @abstractmethod
196
+ def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
197
+ pass
198
+
199
+
200
+ class SparseSpreadMethod(SparseMethod):
201
+ UNIFORM = "uniform"
202
+ STARTING = "starting"
203
+ ENDING = "ending"
204
+ CENTER = "center"
205
+
206
+ LIST = [UNIFORM, STARTING, ENDING, CENTER]
207
+
208
+ def __init__(self, spread=UNIFORM):
209
+ super().__init__(self.SPREAD)
210
+ self.spread = spread
211
+
212
+ def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
213
+ # if hint_length >= full_length, limit hints to full_length
214
+ if hint_length >= full_length:
215
+ return list(range(full_length))
216
+ # handle special case of 1 hint image
217
+ if hint_length == 1:
218
+ if self.spread in [self.UNIFORM, self.STARTING]:
219
+ return [0]
220
+ elif self.spread == self.ENDING:
221
+ return [full_length-1]
222
+ elif self.spread == self.CENTER:
223
+ # return second (of three) values as the center
224
+ return [np.linspace(0, full_length-1, 3, endpoint=True, dtype=int)[1]]
225
+ else:
226
+ raise ValueError(f"Unrecognized spread: {self.spread}")
227
+ # otherwise, handle other cases
228
+ if self.spread == self.UNIFORM:
229
+ return list(np.linspace(0, full_length-1, hint_length, endpoint=True, dtype=int))
230
+ elif self.spread == self.STARTING:
231
+ # make split 1 larger, remove last element
232
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
233
+ elif self.spread == self.ENDING:
234
+ # make split 1 larger, remove first element
235
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[1:]
236
+ elif self.spread == self.CENTER:
237
+ # if hint length is not 3 greater than full length, do STARTING behavior
238
+ if full_length-hint_length < 3:
239
+ return list(np.linspace(0, full_length-1, hint_length+1, endpoint=True, dtype=int))[:-1]
240
+ # otherwise, get linspace of 2 greater than needed, then cut off first and last
241
+ return list(np.linspace(0, full_length-1, hint_length+2, endpoint=True, dtype=int))[1:-1]
242
+ return ValueError(f"Unrecognized spread: {self.spread}")
243
+
244
+
245
+ class SparseIndexMethod(SparseMethod):
246
+ def __init__(self, idxs: list[int]):
247
+ super().__init__(self.INDEX)
248
+ self.idxs = idxs
249
+
250
+ def get_indexes(self, hint_length: int, full_length: int) -> list[int]:
251
+ orig_hint_length = hint_length
252
+ if hint_length > full_length:
253
+ hint_length = full_length
254
+ # if idxs is less than hint_length, throw error
255
+ if len(self.idxs) < hint_length:
256
+ err_msg = f"There are not enough indexes ({len(self.idxs)}) provided to fit the usable {hint_length} input images."
257
+ if orig_hint_length != hint_length:
258
+ err_msg = f"{err_msg} (original input images: {orig_hint_length})"
259
+ raise ValueError(err_msg)
260
+ # cap idxs to hint_length
261
+ idxs = self.idxs[:hint_length]
262
+ new_idxs = []
263
+ real_idxs = set()
264
+ for idx in idxs:
265
+ if idx < 0:
266
+ real_idx = full_length+idx
267
+ if real_idx in real_idxs:
268
+ raise ValueError(f"Index '{idx}' maps to '{real_idx}' and is duplicate - indexes in Sparse Index Method must be unique.")
269
+ else:
270
+ real_idx = idx
271
+ if real_idx in real_idxs:
272
+ raise ValueError(f"Index '{idx}' is duplicate (or a negative index is equivalent) - indexes in Sparse Index Method must be unique.")
273
+ real_idxs.add(real_idx)
274
+ new_idxs.append(real_idx)
275
+ return new_idxs
276
+
277
+
278
+ #########################################
279
+ # motion-related portion of controlnet
280
+ class BlockType:
281
+ UP = "up"
282
+ DOWN = "down"
283
+ MID = "mid"
284
+
285
+ def get_down_block_max(mm_state_dict: dict[str, Tensor]) -> int:
286
+ return get_block_max(mm_state_dict, "down_blocks")
287
+
288
+ def get_up_block_max(mm_state_dict: dict[str, Tensor]) -> int:
289
+ return get_block_max(mm_state_dict, "up_blocks")
290
+
291
+ def get_block_max(mm_state_dict: dict[str, Tensor], block_name: str) -> int:
292
+ # keep track of biggest down_block count in module
293
+ biggest_block = -1
294
+ for key in mm_state_dict.keys():
295
+ if block_name in key:
296
+ try:
297
+ block_int = key.split(".")[1]
298
+ block_num = int(block_int)
299
+ if block_num > biggest_block:
300
+ biggest_block = block_num
301
+ except ValueError:
302
+ pass
303
+ return biggest_block
304
+
305
+ def has_mid_block(mm_state_dict: dict[str, Tensor]):
306
+ # check if keys contain mid_block
307
+ for key in mm_state_dict.keys():
308
+ if key.startswith("mid_block."):
309
+ return True
310
+ return False
311
+
312
+ def get_position_encoding_max_len(mm_state_dict: dict[str, Tensor], mm_name: str=None) -> int:
313
+ # use pos_encoder.pe entries to determine max length - [1, {max_length}, {320|640|1280}]
314
+ for key in mm_state_dict.keys():
315
+ if key.endswith("pos_encoder.pe"):
316
+ return mm_state_dict[key].size(1) # get middle dim
317
+ raise ValueError(f"No pos_encoder.pe found in SparseCtrl state_dict - {mm_name} is not a valid SparseCtrl model!")
318
+
319
+
320
+ class SparseCtrlMotionWrapper(nn.Module):
321
+ def __init__(self, mm_state_dict: dict[str, Tensor]):
322
+ super().__init__()
323
+ self.down_blocks: Iterable[MotionModule] = None
324
+ self.up_blocks: Iterable[MotionModule] = None
325
+ self.mid_block: MotionModule = None
326
+ self.encoding_max_len = get_position_encoding_max_len(mm_state_dict, "")
327
+ layer_channels = (320, 640, 1280, 1280)
328
+ if get_down_block_max(mm_state_dict) > -1:
329
+ self.down_blocks = nn.ModuleList([])
330
+ for c in layer_channels:
331
+ self.down_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.DOWN))
332
+ if get_up_block_max(mm_state_dict) > -1:
333
+ self.up_blocks = nn.ModuleList([])
334
+ for c in reversed(layer_channels):
335
+ self.up_blocks.append(MotionModule(c, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.UP))
336
+ if has_mid_block(mm_state_dict):
337
+ self.mid_block = MotionModule(1280, temporal_position_encoding_max_len=self.encoding_max_len, block_type=BlockType.MID)
338
+
339
+ def inject(self, unet: SparseControlNet):
340
+ # inject input (down) blocks
341
+ self._inject(unet.input_blocks, self.down_blocks)
342
+ # inject mid block, if present
343
+ if self.mid_block is not None:
344
+ self._inject([unet.middle_block], [self.mid_block])
345
+ unet.motion_wrapper = self
346
+
347
+ def _inject(self, unet_blocks: nn.ModuleList, mm_blocks: nn.ModuleList):
348
+ # Rules for injection:
349
+ # For each component list in a unet block:
350
+ # if SpatialTransformer exists in list, place next block after last occurrence
351
+ # elif ResBlock exists in list, place next block after first occurrence
352
+ # else don't place block
353
+ injection_count = 0
354
+ unet_idx = 0
355
+ # details about blocks passed in
356
+ per_block = len(mm_blocks[0].motion_modules)
357
+ injection_goal = len(mm_blocks) * per_block
358
+ # only stop injecting when modules exhausted
359
+ while injection_count < injection_goal:
360
+ # figure out which VanillaTemporalModule from mm to inject
361
+ mm_blk_idx, mm_vtm_idx = injection_count // per_block, injection_count % per_block
362
+ # figure out layout of unet block components
363
+ st_idx = -1 # SpatialTransformer index
364
+ res_idx = -1 # first ResBlock index
365
+ # first, figure out indeces of relevant blocks
366
+ for idx, component in enumerate(unet_blocks[unet_idx]):
367
+ if type(component) == SpatialTransformer:
368
+ st_idx = idx
369
+ elif type(component).__name__ == "ResBlock" and res_idx < 0:
370
+ res_idx = idx
371
+ # if SpatialTransformer exists, inject right after
372
+ if st_idx >= 0:
373
+ unet_blocks[unet_idx].insert(st_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
374
+ injection_count += 1
375
+ # otherwise, if only ResBlock exists, inject right after
376
+ elif res_idx >= 0:
377
+ unet_blocks[unet_idx].insert(res_idx+1, mm_blocks[mm_blk_idx].motion_modules[mm_vtm_idx])
378
+ injection_count += 1
379
+ # increment unet_idx
380
+ unet_idx += 1
381
+
382
+ def eject(self, unet: SparseControlNet):
383
+ # remove from input blocks (downblocks)
384
+ self._eject(unet.input_blocks)
385
+ # remove from middle block (encapsulate in list to make compatible)
386
+ self._eject([unet.middle_block])
387
+ del unet.motion_wrapper
388
+ unet.motion_wrapper = None
389
+
390
+ def _eject(self, unet_blocks: nn.ModuleList):
391
+ # eject all VanillaTemporalModule objects from all blocks
392
+ for block in unet_blocks:
393
+ idx_to_pop = []
394
+ for idx, component in enumerate(block):
395
+ if type(component) == VanillaTemporalModule:
396
+ idx_to_pop.append(idx)
397
+ # pop in backwards order, as to not disturb what the indeces refer to
398
+ for idx in sorted(idx_to_pop, reverse=True):
399
+ block.pop(idx)
400
+
401
+ def set_video_length(self, video_length: int, full_length: int):
402
+ self.AD_video_length = video_length
403
+ if self.down_blocks is not None:
404
+ for block in self.down_blocks:
405
+ block.set_video_length(video_length, full_length)
406
+ if self.up_blocks is not None:
407
+ for block in self.up_blocks:
408
+ block.set_video_length(video_length, full_length)
409
+ if self.mid_block is not None:
410
+ self.mid_block.set_video_length(video_length, full_length)
411
+
412
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
413
+ if self.down_blocks is not None:
414
+ for block in self.down_blocks:
415
+ block.set_scale_multiplier(multiplier)
416
+ if self.up_blocks is not None:
417
+ for block in self.up_blocks:
418
+ block.set_scale_multiplier(multiplier)
419
+ if self.mid_block is not None:
420
+ self.mid_block.set_scale_multiplier(multiplier)
421
+
422
+ def set_strength(self, strength: float):
423
+ if self.down_blocks is not None:
424
+ for block in self.down_blocks:
425
+ block.set_strength(strength)
426
+ if self.up_blocks is not None:
427
+ for block in self.up_blocks:
428
+ block.set_strength(strength)
429
+ if self.mid_block is not None:
430
+ self.mid_block.set_strength(strength)
431
+
432
+ def reset_temp_vars(self):
433
+ if self.down_blocks is not None:
434
+ for block in self.down_blocks:
435
+ block.reset_temp_vars()
436
+ if self.up_blocks is not None:
437
+ for block in self.up_blocks:
438
+ block.reset_temp_vars()
439
+ if self.mid_block is not None:
440
+ self.mid_block.reset_temp_vars()
441
+
442
+ def reset_scale_multiplier(self):
443
+ self.set_scale_multiplier(None)
444
+
445
+ def reset(self):
446
+ self.reset_scale_multiplier()
447
+ self.reset_temp_vars()
448
+
449
+
450
+ class MotionModule(nn.Module):
451
+ def __init__(self, in_channels, temporal_position_encoding_max_len=24, block_type: str=BlockType.DOWN):
452
+ super().__init__()
453
+ if block_type == BlockType.MID:
454
+ # mid blocks contain only a single VanillaTemporalModule
455
+ self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList([get_motion_module(in_channels, temporal_position_encoding_max_len)])
456
+ else:
457
+ # down blocks contain two VanillaTemporalModules
458
+ self.motion_modules: Iterable[VanillaTemporalModule] = nn.ModuleList(
459
+ [
460
+ get_motion_module(in_channels, temporal_position_encoding_max_len),
461
+ get_motion_module(in_channels, temporal_position_encoding_max_len)
462
+ ]
463
+ )
464
+ # up blocks contain one additional VanillaTemporalModule
465
+ if block_type == BlockType.UP:
466
+ self.motion_modules.append(get_motion_module(in_channels, temporal_position_encoding_max_len))
467
+
468
+ def set_video_length(self, video_length: int, full_length: int):
469
+ for motion_module in self.motion_modules:
470
+ motion_module.set_video_length(video_length, full_length)
471
+
472
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
473
+ for motion_module in self.motion_modules:
474
+ motion_module.set_scale_multiplier(multiplier)
475
+
476
+ def set_masks(self, masks: Tensor, min_val: float, max_val: float):
477
+ for motion_module in self.motion_modules:
478
+ motion_module.set_masks(masks, min_val, max_val)
479
+
480
+ def set_sub_idxs(self, sub_idxs: list[int]):
481
+ for motion_module in self.motion_modules:
482
+ motion_module.set_sub_idxs(sub_idxs)
483
+
484
+ def set_strength(self, strength: float):
485
+ for motion_module in self.motion_modules:
486
+ motion_module.set_strength(strength)
487
+
488
+ def reset_temp_vars(self):
489
+ for motion_module in self.motion_modules:
490
+ motion_module.reset_temp_vars()
491
+
492
+
493
+ def get_motion_module(in_channels, temporal_position_encoding_max_len):
494
+ # unlike normal AD, there is only one attention block expected in SparseCtrl models
495
+ return VanillaTemporalModule(in_channels=in_channels, attention_block_types=("Temporal_Self",), temporal_position_encoding_max_len=temporal_position_encoding_max_len)
496
+
497
+
498
+ class VanillaTemporalModule(nn.Module):
499
+ def __init__(
500
+ self,
501
+ in_channels,
502
+ num_attention_heads=8,
503
+ num_transformer_block=1,
504
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
505
+ cross_frame_attention_mode=None,
506
+ temporal_position_encoding=True,
507
+ temporal_position_encoding_max_len=24,
508
+ temporal_attention_dim_div=1,
509
+ zero_initialize=True,
510
+ ):
511
+ super().__init__()
512
+ self.strength = 1.0
513
+ self.temporal_transformer = TemporalTransformer3DModel(
514
+ in_channels=in_channels,
515
+ num_attention_heads=num_attention_heads,
516
+ attention_head_dim=in_channels
517
+ // num_attention_heads
518
+ // temporal_attention_dim_div,
519
+ num_layers=num_transformer_block,
520
+ attention_block_types=attention_block_types,
521
+ cross_frame_attention_mode=cross_frame_attention_mode,
522
+ temporal_position_encoding=temporal_position_encoding,
523
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
524
+ )
525
+
526
+ if zero_initialize:
527
+ self.temporal_transformer.proj_out = zero_module(
528
+ self.temporal_transformer.proj_out
529
+ )
530
+
531
+ def set_video_length(self, video_length: int, full_length: int):
532
+ self.temporal_transformer.set_video_length(video_length, full_length)
533
+
534
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
535
+ self.temporal_transformer.set_scale_multiplier(multiplier)
536
+
537
+ def set_masks(self, masks: Tensor, min_val: float, max_val: float):
538
+ self.temporal_transformer.set_masks(masks, min_val, max_val)
539
+
540
+ def set_sub_idxs(self, sub_idxs: list[int]):
541
+ self.temporal_transformer.set_sub_idxs(sub_idxs)
542
+
543
+ def set_strength(self, strength: float):
544
+ self.strength = strength
545
+
546
+ def reset_temp_vars(self):
547
+ self.set_strength(1.0)
548
+ self.temporal_transformer.reset_temp_vars()
549
+
550
+ def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None):
551
+ if math.isclose(self.strength, 1.0):
552
+ return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)
553
+ elif math.isclose(self.strength, 0.0):
554
+ return input_tensor
555
+ elif self.strength > 1.0:
556
+ return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength
557
+ else:
558
+ return self.temporal_transformer(input_tensor, encoder_hidden_states, attention_mask)*self.strength + input_tensor*(1.0-self.strength)
559
+
560
+
561
+ class TemporalTransformer3DModel(nn.Module):
562
+ def __init__(
563
+ self,
564
+ in_channels,
565
+ num_attention_heads,
566
+ attention_head_dim,
567
+ num_layers,
568
+ attention_block_types=(
569
+ "Temporal_Self",
570
+ "Temporal_Self",
571
+ ),
572
+ dropout=0.0,
573
+ norm_num_groups=32,
574
+ cross_attention_dim=768,
575
+ activation_fn="geglu",
576
+ attention_bias=False,
577
+ upcast_attention=False,
578
+ cross_frame_attention_mode=None,
579
+ temporal_position_encoding=False,
580
+ temporal_position_encoding_max_len=24,
581
+ ):
582
+ super().__init__()
583
+ self.video_length = 16
584
+ self.full_length = 16
585
+ self.scale_min = 1.0
586
+ self.scale_max = 1.0
587
+ self.raw_scale_mask: Union[Tensor, None] = None
588
+ self.temp_scale_mask: Union[Tensor, None] = None
589
+ self.sub_idxs: Union[list[int], None] = None
590
+ self.prev_hidden_states_batch = 0
591
+
592
+
593
+ inner_dim = num_attention_heads * attention_head_dim
594
+
595
+ self.norm = disable_weight_init_clean_groupnorm.GroupNorm(
596
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
597
+ )
598
+ self.proj_in = nn.Linear(in_channels, inner_dim)
599
+
600
+ self.transformer_blocks: Iterable[TemporalTransformerBlock] = nn.ModuleList(
601
+ [
602
+ TemporalTransformerBlock(
603
+ dim=inner_dim,
604
+ num_attention_heads=num_attention_heads,
605
+ attention_head_dim=attention_head_dim,
606
+ attention_block_types=attention_block_types,
607
+ dropout=dropout,
608
+ norm_num_groups=norm_num_groups,
609
+ cross_attention_dim=cross_attention_dim,
610
+ activation_fn=activation_fn,
611
+ attention_bias=attention_bias,
612
+ upcast_attention=upcast_attention,
613
+ cross_frame_attention_mode=cross_frame_attention_mode,
614
+ temporal_position_encoding=temporal_position_encoding,
615
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
616
+ )
617
+ for d in range(num_layers)
618
+ ]
619
+ )
620
+ self.proj_out = nn.Linear(inner_dim, in_channels)
621
+
622
+ def set_video_length(self, video_length: int, full_length: int):
623
+ self.video_length = video_length
624
+ self.full_length = full_length
625
+
626
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
627
+ for block in self.transformer_blocks:
628
+ block.set_scale_multiplier(multiplier)
629
+
630
+ def set_masks(self, masks: Tensor, min_val: float, max_val: float):
631
+ self.scale_min = min_val
632
+ self.scale_max = max_val
633
+ self.raw_scale_mask = masks
634
+
635
+ def set_sub_idxs(self, sub_idxs: list[int]):
636
+ self.sub_idxs = sub_idxs
637
+ for block in self.transformer_blocks:
638
+ block.set_sub_idxs(sub_idxs)
639
+
640
+ def reset_temp_vars(self):
641
+ del self.temp_scale_mask
642
+ self.temp_scale_mask = None
643
+ self.prev_hidden_states_batch = 0
644
+
645
+ def get_scale_mask(self, hidden_states: Tensor) -> Union[Tensor, None]:
646
+ # if no raw mask, return None
647
+ if self.raw_scale_mask is None:
648
+ return None
649
+ shape = hidden_states.shape
650
+ batch, channel, height, width = shape
651
+ # if temp mask already calculated, return it
652
+ if self.temp_scale_mask != None:
653
+ # check if hidden_states batch matches
654
+ if batch == self.prev_hidden_states_batch:
655
+ if self.sub_idxs is not None:
656
+ return self.temp_scale_mask[:, self.sub_idxs, :]
657
+ return self.temp_scale_mask
658
+ # if does not match, reset cached temp_scale_mask and recalculate it
659
+ del self.temp_scale_mask
660
+ self.temp_scale_mask = None
661
+ # otherwise, calculate temp mask
662
+ self.prev_hidden_states_batch = batch
663
+ mask = prepare_mask_batch(self.raw_scale_mask, shape=(self.full_length, 1, height, width))
664
+ mask = repeat_to_batch_size(mask, self.full_length)
665
+ # if mask not the same amount length as full length, make it match
666
+ if self.full_length != mask.shape[0]:
667
+ mask = broadcast_image_to(mask, self.full_length, 1)
668
+ # reshape mask to attention K shape (h*w, latent_count, 1)
669
+ batch, channel, height, width = mask.shape
670
+ # first, perform same operations as on hidden_states,
671
+ # turning (b, c, h, w) -> (b, h*w, c)
672
+ mask = mask.permute(0, 2, 3, 1).reshape(batch, height*width, channel)
673
+ # then, make it the same shape as attention's k, (h*w, b, c)
674
+ mask = mask.permute(1, 0, 2)
675
+ # make masks match the expected length of h*w
676
+ batched_number = shape[0] // self.video_length
677
+ if batched_number > 1:
678
+ mask = torch.cat([mask] * batched_number, dim=0)
679
+ # cache mask and set to proper device
680
+ self.temp_scale_mask = mask
681
+ # move temp_scale_mask to proper dtype + device
682
+ self.temp_scale_mask = self.temp_scale_mask.to(dtype=hidden_states.dtype, device=hidden_states.device)
683
+ # return subset of masks, if needed
684
+ if self.sub_idxs is not None:
685
+ return self.temp_scale_mask[:, self.sub_idxs, :]
686
+ return self.temp_scale_mask
687
+
688
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
689
+ batch, channel, height, width = hidden_states.shape
690
+ residual = hidden_states
691
+ scale_mask = self.get_scale_mask(hidden_states)
692
+ # add some casts for fp8 purposes - does not affect speed otherwise
693
+ hidden_states = self.norm(hidden_states).to(hidden_states.dtype)
694
+ inner_dim = hidden_states.shape[1]
695
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
696
+ batch, height * width, inner_dim
697
+ )
698
+ hidden_states = self.proj_in(hidden_states).to(hidden_states.dtype)
699
+
700
+ # Transformer Blocks
701
+ for block in self.transformer_blocks:
702
+ hidden_states = block(
703
+ hidden_states,
704
+ encoder_hidden_states=encoder_hidden_states,
705
+ attention_mask=attention_mask,
706
+ video_length=self.video_length,
707
+ scale_mask=scale_mask
708
+ )
709
+
710
+ # output
711
+ hidden_states = self.proj_out(hidden_states)
712
+ hidden_states = (
713
+ hidden_states.reshape(batch, height, width, inner_dim)
714
+ .permute(0, 3, 1, 2)
715
+ .contiguous()
716
+ )
717
+
718
+ output = hidden_states + residual
719
+
720
+ return output
721
+
722
+
723
+ class TemporalTransformerBlock(nn.Module):
724
+ def __init__(
725
+ self,
726
+ dim,
727
+ num_attention_heads,
728
+ attention_head_dim,
729
+ attention_block_types=(
730
+ "Temporal_Self",
731
+ "Temporal_Self",
732
+ ),
733
+ dropout=0.0,
734
+ norm_num_groups=32,
735
+ cross_attention_dim=768,
736
+ activation_fn="geglu",
737
+ attention_bias=False,
738
+ upcast_attention=False,
739
+ cross_frame_attention_mode=None,
740
+ temporal_position_encoding=False,
741
+ temporal_position_encoding_max_len=24,
742
+ ):
743
+ super().__init__()
744
+
745
+ attention_blocks = []
746
+ norms = []
747
+
748
+ for block_name in attention_block_types:
749
+ attention_blocks.append(
750
+ VersatileAttention(
751
+ attention_mode=block_name.split("_")[0],
752
+ context_dim=cross_attention_dim # called context_dim for ComfyUI impl
753
+ if block_name.endswith("_Cross")
754
+ else None,
755
+ query_dim=dim,
756
+ heads=num_attention_heads,
757
+ dim_head=attention_head_dim,
758
+ dropout=dropout,
759
+ #bias=attention_bias, # remove for Comfy CrossAttention
760
+ #upcast_attention=upcast_attention, # remove for Comfy CrossAttention
761
+ cross_frame_attention_mode=cross_frame_attention_mode,
762
+ temporal_position_encoding=temporal_position_encoding,
763
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
764
+ )
765
+ )
766
+ norms.append(nn.LayerNorm(dim))
767
+
768
+ self.attention_blocks: Iterable[VersatileAttention] = nn.ModuleList(attention_blocks)
769
+ self.norms = nn.ModuleList(norms)
770
+
771
+ self.ff = FeedForward(dim, dropout=dropout, glu=(activation_fn == "geglu"))
772
+ self.ff_norm = nn.LayerNorm(dim)
773
+
774
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
775
+ for block in self.attention_blocks:
776
+ block.set_scale_multiplier(multiplier)
777
+
778
+ def set_sub_idxs(self, sub_idxs: list[int]):
779
+ for block in self.attention_blocks:
780
+ block.set_sub_idxs(sub_idxs)
781
+
782
+ def forward(
783
+ self,
784
+ hidden_states,
785
+ encoder_hidden_states=None,
786
+ attention_mask=None,
787
+ video_length=None,
788
+ scale_mask=None
789
+ ):
790
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
791
+ norm_hidden_states = norm(hidden_states).to(hidden_states.dtype)
792
+ hidden_states = (
793
+ attention_block(
794
+ norm_hidden_states,
795
+ encoder_hidden_states=encoder_hidden_states
796
+ if attention_block.is_cross_attention
797
+ else None,
798
+ attention_mask=attention_mask,
799
+ video_length=video_length,
800
+ scale_mask=scale_mask
801
+ )
802
+ + hidden_states
803
+ )
804
+
805
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
806
+
807
+ output = hidden_states
808
+ return output
809
+
810
+
811
+ class PositionalEncoding(nn.Module):
812
+ def __init__(self, d_model, dropout=0.0, max_len=24):
813
+ super().__init__()
814
+ self.dropout = nn.Dropout(p=dropout)
815
+ position = torch.arange(max_len).unsqueeze(1)
816
+ div_term = torch.exp(
817
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
818
+ )
819
+ pe = torch.zeros(1, max_len, d_model)
820
+ pe[0, :, 0::2] = torch.sin(position * div_term)
821
+ pe[0, :, 1::2] = torch.cos(position * div_term)
822
+ self.register_buffer("pe", pe)
823
+ self.sub_idxs = None
824
+
825
+ def set_sub_idxs(self, sub_idxs: list[int]):
826
+ self.sub_idxs = sub_idxs
827
+
828
+ def forward(self, x):
829
+ #if self.sub_idxs is not None:
830
+ # x = x + self.pe[:, self.sub_idxs]
831
+ #else:
832
+ x = x + self.pe[:, : x.size(1)]
833
+ return self.dropout(x)
834
+
835
+
836
+ class CrossAttentionMM(nn.Module):
837
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None,
838
+ operations=comfy.ops.disable_weight_init):
839
+ super().__init__()
840
+ inner_dim = dim_head * heads
841
+ context_dim = default(context_dim, query_dim)
842
+
843
+ self.heads = heads
844
+ self.dim_head = dim_head
845
+ self.scale = None
846
+
847
+ self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device)
848
+ self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
849
+ self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device)
850
+
851
+ self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout))
852
+
853
+ def forward(self, x, context=None, value=None, mask=None, scale_mask=None):
854
+ q = self.to_q(x)
855
+ context = default(context, x)
856
+ k: Tensor = self.to_k(context)
857
+ if value is not None:
858
+ v = self.to_v(value)
859
+ del value
860
+ else:
861
+ v = self.to_v(context)
862
+
863
+ # apply custom scale by multiplying k by scale factor
864
+ if self.scale is not None:
865
+ k *= self.scale
866
+
867
+ # apply scale mask, if present
868
+ if scale_mask is not None:
869
+ k *= scale_mask
870
+
871
+ out = optimized_attention_mm(q, k, v, self.heads, mask)
872
+ return self.to_out(out)
873
+
874
+
875
+ class VersatileAttention(CrossAttentionMM):
876
+ def __init__(
877
+ self,
878
+ attention_mode=None,
879
+ cross_frame_attention_mode=None,
880
+ temporal_position_encoding=False,
881
+ temporal_position_encoding_max_len=24,
882
+ *args,
883
+ **kwargs,
884
+ ):
885
+ super().__init__(*args, **kwargs)
886
+ assert attention_mode == "Temporal"
887
+
888
+ self.attention_mode = attention_mode
889
+ self.is_cross_attention = kwargs["context_dim"] is not None
890
+
891
+ self.pos_encoder = (
892
+ PositionalEncoding(
893
+ kwargs["query_dim"],
894
+ dropout=0.0,
895
+ max_len=temporal_position_encoding_max_len,
896
+ )
897
+ if (temporal_position_encoding and attention_mode == "Temporal")
898
+ else None
899
+ )
900
+
901
+ def extra_repr(self):
902
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
903
+
904
+ def set_scale_multiplier(self, multiplier: Union[float, None]):
905
+ if multiplier is None or math.isclose(multiplier, 1.0):
906
+ self.scale = None
907
+ else:
908
+ self.scale = multiplier
909
+
910
+ def set_sub_idxs(self, sub_idxs: list[int]):
911
+ if self.pos_encoder != None:
912
+ self.pos_encoder.set_sub_idxs(sub_idxs)
913
+
914
+ def forward(
915
+ self,
916
+ hidden_states: Tensor,
917
+ encoder_hidden_states=None,
918
+ attention_mask=None,
919
+ video_length=None,
920
+ scale_mask=None,
921
+ ):
922
+ if self.attention_mode != "Temporal":
923
+ raise NotImplementedError
924
+
925
+ d = hidden_states.shape[1]
926
+ hidden_states = rearrange(
927
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
928
+ )
929
+
930
+ if self.pos_encoder is not None:
931
+ hidden_states = self.pos_encoder(hidden_states).to(hidden_states.dtype)
932
+
933
+ encoder_hidden_states = (
934
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
935
+ if encoder_hidden_states is not None
936
+ else encoder_hidden_states
937
+ )
938
+
939
+ hidden_states = super().forward(
940
+ hidden_states,
941
+ encoder_hidden_states,
942
+ value=None,
943
+ mask=attention_mask,
944
+ scale_mask=scale_mask,
945
+ )
946
+
947
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
948
+
949
+ return hidden_states
ComfyUI-Advanced-ControlNet/adv_control/control_svd.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+
5
+ import comfy.model_detection
6
+ from comfy.utils import UNET_MAP_BASIC, UNET_MAP_RESNET, UNET_MAP_ATTENTIONS, TRANSFORMER_BLOCKS
7
+
8
+ import torch
9
+
10
+
11
+ from comfy.ldm.modules.diffusionmodules.util import (
12
+ zero_module,
13
+ timestep_embedding,
14
+ )
15
+
16
+ from comfy.ldm.modules.attention import SpatialVideoTransformer
17
+ from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, TimestepEmbedSequential, VideoResBlock, Downsample
18
+ from comfy.ldm.util import exists
19
+ import comfy.ops
20
+
21
+
22
+ class SVDControlNet(nn.Module):
23
+ def __init__(
24
+ self,
25
+ image_size,
26
+ in_channels,
27
+ model_channels,
28
+ hint_channels,
29
+ num_res_blocks,
30
+ dropout=0,
31
+ channel_mult=(1, 2, 4, 8),
32
+ conv_resample=True,
33
+ dims=2,
34
+ num_classes=None,
35
+ use_checkpoint=False,
36
+ dtype=torch.float32,
37
+ num_heads=-1,
38
+ num_head_channels=-1,
39
+ num_heads_upsample=-1,
40
+ use_scale_shift_norm=False,
41
+ resblock_updown=False,
42
+ use_new_attention_order=False,
43
+ use_spatial_transformer=False, # custom transformer support
44
+ transformer_depth=1, # custom transformer support
45
+ context_dim=None, # custom transformer support
46
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
47
+ legacy=True,
48
+ disable_self_attentions=None,
49
+ num_attention_blocks=None,
50
+ disable_middle_self_attn=False,
51
+ use_linear_in_transformer=False,
52
+ adm_in_channels=None,
53
+ transformer_depth_middle=None,
54
+ transformer_depth_output=None,
55
+ use_spatial_context=False,
56
+ extra_ff_mix_layer=False,
57
+ merge_strategy="fixed",
58
+ merge_factor=0.5,
59
+ video_kernel_size=3,
60
+ device=None,
61
+ operations=comfy.ops.disable_weight_init,
62
+ **kwargs,
63
+ ):
64
+ super().__init__()
65
+ assert use_spatial_transformer == True, "use_spatial_transformer has to be true"
66
+ if use_spatial_transformer:
67
+ assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
68
+
69
+ if context_dim is not None:
70
+ assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
71
+ # from omegaconf.listconfig import ListConfig
72
+ # if type(context_dim) == ListConfig:
73
+ # context_dim = list(context_dim)
74
+
75
+ if num_heads_upsample == -1:
76
+ num_heads_upsample = num_heads
77
+
78
+ if num_heads == -1:
79
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
80
+
81
+ if num_head_channels == -1:
82
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
83
+
84
+ self.dims = dims
85
+ self.image_size = image_size
86
+ self.in_channels = in_channels
87
+ self.model_channels = model_channels
88
+
89
+ if isinstance(num_res_blocks, int):
90
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
91
+ else:
92
+ if len(num_res_blocks) != len(channel_mult):
93
+ raise ValueError("provide num_res_blocks either as an int (globally constant) or "
94
+ "as a list/tuple (per-level) with the same length as channel_mult")
95
+ self.num_res_blocks = num_res_blocks
96
+
97
+ if disable_self_attentions is not None:
98
+ # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
99
+ assert len(disable_self_attentions) == len(channel_mult)
100
+ if num_attention_blocks is not None:
101
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
102
+ assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
103
+
104
+ transformer_depth = transformer_depth[:]
105
+
106
+ self.dropout = dropout
107
+ self.channel_mult = channel_mult
108
+ self.conv_resample = conv_resample
109
+ self.num_classes = num_classes
110
+ self.use_checkpoint = use_checkpoint
111
+ self.dtype = dtype
112
+ self.num_heads = num_heads
113
+ self.num_head_channels = num_head_channels
114
+ self.num_heads_upsample = num_heads_upsample
115
+ self.predict_codebook_ids = n_embed is not None
116
+
117
+ time_embed_dim = model_channels * 4
118
+ self.time_embed = nn.Sequential(
119
+ operations.Linear(model_channels, time_embed_dim, dtype=self.dtype, device=device),
120
+ nn.SiLU(),
121
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
122
+ )
123
+
124
+ if self.num_classes is not None:
125
+ if isinstance(self.num_classes, int):
126
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
127
+ elif self.num_classes == "continuous":
128
+ print("setting up linear c_adm embedding layer")
129
+ self.label_emb = nn.Linear(1, time_embed_dim)
130
+ elif self.num_classes == "sequential":
131
+ assert adm_in_channels is not None
132
+ self.label_emb = nn.Sequential(
133
+ nn.Sequential(
134
+ operations.Linear(adm_in_channels, time_embed_dim, dtype=self.dtype, device=device),
135
+ nn.SiLU(),
136
+ operations.Linear(time_embed_dim, time_embed_dim, dtype=self.dtype, device=device),
137
+ )
138
+ )
139
+ else:
140
+ raise ValueError()
141
+
142
+ self.input_blocks = nn.ModuleList(
143
+ [
144
+ TimestepEmbedSequential(
145
+ operations.conv_nd(dims, in_channels, model_channels, 3, padding=1, dtype=self.dtype, device=device)
146
+ )
147
+ ]
148
+ )
149
+ self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels, operations=operations, dtype=self.dtype, device=device)])
150
+
151
+ self.input_hint_block = TimestepEmbedSequential(
152
+ operations.conv_nd(dims, hint_channels, 16, 3, padding=1, dtype=self.dtype, device=device),
153
+ nn.SiLU(),
154
+ operations.conv_nd(dims, 16, 16, 3, padding=1, dtype=self.dtype, device=device),
155
+ nn.SiLU(),
156
+ operations.conv_nd(dims, 16, 32, 3, padding=1, stride=2, dtype=self.dtype, device=device),
157
+ nn.SiLU(),
158
+ operations.conv_nd(dims, 32, 32, 3, padding=1, dtype=self.dtype, device=device),
159
+ nn.SiLU(),
160
+ operations.conv_nd(dims, 32, 96, 3, padding=1, stride=2, dtype=self.dtype, device=device),
161
+ nn.SiLU(),
162
+ operations.conv_nd(dims, 96, 96, 3, padding=1, dtype=self.dtype, device=device),
163
+ nn.SiLU(),
164
+ operations.conv_nd(dims, 96, 256, 3, padding=1, stride=2, dtype=self.dtype, device=device),
165
+ nn.SiLU(),
166
+ operations.conv_nd(dims, 256, model_channels, 3, padding=1, dtype=self.dtype, device=device)
167
+ )
168
+
169
+ self._feature_size = model_channels
170
+ input_block_chans = [model_channels]
171
+ ch = model_channels
172
+ ds = 1
173
+ for level, mult in enumerate(channel_mult):
174
+ for nr in range(self.num_res_blocks[level]):
175
+ layers = [
176
+ VideoResBlock(
177
+ ch,
178
+ time_embed_dim,
179
+ dropout,
180
+ out_channels=mult * model_channels,
181
+ dims=dims,
182
+ use_checkpoint=use_checkpoint,
183
+ use_scale_shift_norm=use_scale_shift_norm,
184
+ dtype=self.dtype,
185
+ device=device,
186
+ operations=operations,
187
+ video_kernel_size=video_kernel_size,
188
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
189
+ )
190
+ ]
191
+ ch = mult * model_channels
192
+ num_transformers = transformer_depth.pop(0)
193
+ if num_transformers > 0:
194
+ if num_head_channels == -1:
195
+ dim_head = ch // num_heads
196
+ else:
197
+ num_heads = ch // num_head_channels
198
+ dim_head = num_head_channels
199
+ if legacy:
200
+ #num_heads = 1
201
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
202
+ if exists(disable_self_attentions):
203
+ disabled_sa = disable_self_attentions[level]
204
+ else:
205
+ disabled_sa = False
206
+
207
+ if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
208
+ layers.append(
209
+ SpatialVideoTransformer(
210
+ ch, num_heads, dim_head, depth=num_transformers, context_dim=context_dim,
211
+ disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
212
+ checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
213
+ use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
214
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
215
+ )
216
+ )
217
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
218
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
219
+ self._feature_size += ch
220
+ input_block_chans.append(ch)
221
+ if level != len(channel_mult) - 1:
222
+ out_ch = ch
223
+ self.input_blocks.append(
224
+ TimestepEmbedSequential(
225
+ VideoResBlock(
226
+ ch,
227
+ time_embed_dim,
228
+ dropout,
229
+ out_channels=out_ch,
230
+ dims=dims,
231
+ use_checkpoint=use_checkpoint,
232
+ use_scale_shift_norm=use_scale_shift_norm,
233
+ down=True,
234
+ dtype=self.dtype,
235
+ device=device,
236
+ operations=operations,
237
+ video_kernel_size=video_kernel_size,
238
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
239
+ )
240
+ if resblock_updown
241
+ else Downsample(
242
+ ch, conv_resample, dims=dims, out_channels=out_ch, dtype=self.dtype, device=device, operations=operations
243
+ )
244
+ )
245
+ )
246
+ ch = out_ch
247
+ input_block_chans.append(ch)
248
+ self.zero_convs.append(self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device))
249
+ ds *= 2
250
+ self._feature_size += ch
251
+
252
+ if num_head_channels == -1:
253
+ dim_head = ch // num_heads
254
+ else:
255
+ num_heads = ch // num_head_channels
256
+ dim_head = num_head_channels
257
+ if legacy:
258
+ #num_heads = 1
259
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
260
+ mid_block = [
261
+ VideoResBlock(
262
+ ch,
263
+ time_embed_dim,
264
+ dropout,
265
+ dims=dims,
266
+ use_checkpoint=use_checkpoint,
267
+ use_scale_shift_norm=use_scale_shift_norm,
268
+ dtype=self.dtype,
269
+ device=device,
270
+ operations=operations,
271
+ video_kernel_size=video_kernel_size,
272
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
273
+ )]
274
+ if transformer_depth_middle >= 0:
275
+ mid_block += [SpatialVideoTransformer( # always uses a self-attn
276
+ ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
277
+ disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
278
+ checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations,
279
+ use_spatial_context=use_spatial_context, ff_in=extra_ff_mix_layer,
280
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
281
+ ),
282
+ VideoResBlock(
283
+ ch,
284
+ time_embed_dim,
285
+ dropout,
286
+ dims=dims,
287
+ use_checkpoint=use_checkpoint,
288
+ use_scale_shift_norm=use_scale_shift_norm,
289
+ dtype=self.dtype,
290
+ device=device,
291
+ operations=operations,
292
+ video_kernel_size=video_kernel_size,
293
+ merge_strategy=merge_strategy, merge_factor=merge_factor,
294
+ )]
295
+ self.middle_block = TimestepEmbedSequential(*mid_block)
296
+ self.middle_block_out = self.make_zero_conv(ch, operations=operations, dtype=self.dtype, device=device)
297
+ self._feature_size += ch
298
+
299
+ def make_zero_conv(self, channels, operations=None, dtype=None, device=None):
300
+ return TimestepEmbedSequential(operations.conv_nd(self.dims, channels, channels, 1, padding=0, dtype=dtype, device=device))
301
+
302
+ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
303
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
304
+ emb = self.time_embed(t_emb)
305
+
306
+ cond = kwargs["cond"]
307
+ num_video_frames = cond["num_video_frames"]
308
+ image_only_indicator = cond.get("image_only_indicator", None)
309
+ time_context = cond.get("time_context", None)
310
+ del cond
311
+
312
+ guided_hint = self.input_hint_block(hint, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
313
+
314
+ outs = []
315
+
316
+ hs = []
317
+ if self.num_classes is not None:
318
+ assert y.shape[0] == x.shape[0]
319
+ emb = emb + self.label_emb(y)
320
+
321
+ h = x
322
+ for module, zero_conv in zip(self.input_blocks, self.zero_convs):
323
+ if guided_hint is not None:
324
+ h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
325
+ h += guided_hint
326
+ guided_hint = None
327
+ else:
328
+ h = module(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
329
+ outs.append(zero_conv(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
330
+
331
+ h = self.middle_block(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
332
+ outs.append(self.middle_block_out(h, emb, context, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator))
333
+
334
+ return outs
335
+
336
+
337
+ TEMPORAL_TRANSFORMER_BLOCKS = {
338
+ "norm_in.weight",
339
+ "norm_in.bias",
340
+ "ff_in.net.0.proj.weight",
341
+ "ff_in.net.0.proj.bias",
342
+ "ff_in.net.2.weight",
343
+ "ff_in.net.2.bias",
344
+ }
345
+ TEMPORAL_TRANSFORMER_BLOCKS.update(TRANSFORMER_BLOCKS)
346
+
347
+
348
+ TEMPORAL_UNET_MAP_ATTENTIONS = {
349
+ "time_mixer.mix_factor",
350
+ }
351
+ TEMPORAL_UNET_MAP_ATTENTIONS.update(UNET_MAP_ATTENTIONS)
352
+
353
+
354
+ TEMPORAL_TRANSFORMER_MAP = {
355
+ "time_pos_embed.0.weight": "time_pos_embed.linear_1.weight",
356
+ "time_pos_embed.0.bias": "time_pos_embed.linear_1.bias",
357
+ "time_pos_embed.2.weight": "time_pos_embed.linear_2.weight",
358
+ "time_pos_embed.2.bias": "time_pos_embed.linear_2.bias",
359
+ }
360
+
361
+
362
+ TEMPORAL_RESNET = {
363
+ "time_mixer.mix_factor",
364
+ }
365
+
366
+
367
+ def svd_unet_config_from_diffusers_unet(state_dict: dict[str, Tensor], dtype):
368
+ match = {}
369
+ transformer_depth = []
370
+
371
+ attn_res = 1
372
+ down_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}")
373
+ for i in range(down_blocks):
374
+ attn_blocks = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
375
+ for ab in range(attn_blocks):
376
+ transformer_count = comfy.model_detection.count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
377
+ transformer_depth.append(transformer_count)
378
+ if transformer_count > 0:
379
+ match["context_dim"] = state_dict["down_blocks.{}.attentions.{}.transformer_blocks.0.attn2.to_k.weight".format(i, ab)].shape[1]
380
+
381
+ attn_res *= 2
382
+ if attn_blocks == 0:
383
+ transformer_depth.append(0)
384
+ transformer_depth.append(0)
385
+
386
+ match["transformer_depth"] = transformer_depth
387
+
388
+ match["model_channels"] = state_dict["conv_in.weight"].shape[0]
389
+ match["in_channels"] = state_dict["conv_in.weight"].shape[1]
390
+ match["adm_in_channels"] = None
391
+ if "class_embedding.linear_1.weight" in state_dict:
392
+ match["adm_in_channels"] = state_dict["class_embedding.linear_1.weight"].shape[1]
393
+ elif "add_embedding.linear_1.weight" in state_dict:
394
+ match["adm_in_channels"] = state_dict["add_embedding.linear_1.weight"].shape[1]
395
+
396
+ # based on unet_config of SVD
397
+ SVD = {
398
+ 'use_checkpoint': False,
399
+ 'image_size': 32,
400
+ 'use_spatial_transformer': True,
401
+ 'legacy': False,
402
+ 'num_classes': 'sequential',
403
+ 'adm_in_channels': 768,
404
+ 'dtype': dtype,
405
+ 'in_channels': 8,
406
+ 'out_channels': 4,
407
+ 'model_channels': 320,
408
+ 'num_res_blocks': [2, 2, 2, 2],
409
+ 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
410
+ 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
411
+ 'channel_mult': [1, 2, 4, 4],
412
+ 'transformer_depth_middle': 1,
413
+ 'use_linear_in_transformer': True,
414
+ 'context_dim': 1024,
415
+ 'extra_ff_mix_layer': True,
416
+ 'use_spatial_context': True,
417
+ 'merge_strategy': 'learned_with_images',
418
+ 'merge_factor': 0.0,
419
+ 'video_kernel_size': [3, 1, 1],
420
+ 'use_temporal_attention': True,
421
+ 'use_temporal_resblock': True,
422
+ 'num_heads': -1,
423
+ 'num_head_channels': 64,
424
+ }
425
+
426
+ supported_models = [SVD]
427
+
428
+ for unet_config in supported_models:
429
+ matches = True
430
+ for k in match:
431
+ if match[k] != unet_config[k]:
432
+ matches = False
433
+ break
434
+ if matches:
435
+ return comfy.model_detection.convert_config(unet_config)
436
+ return None
437
+
438
+
439
+ def svd_unet_to_diffusers(unet_config):
440
+ num_res_blocks = unet_config["num_res_blocks"]
441
+ channel_mult = unet_config["channel_mult"]
442
+ transformer_depth = unet_config["transformer_depth"][:]
443
+ transformer_depth_output = unet_config["transformer_depth_output"][:]
444
+ num_blocks = len(channel_mult)
445
+
446
+ transformers_mid = unet_config.get("transformer_depth_middle", None)
447
+
448
+ diffusers_unet_map = {}
449
+ for x in range(num_blocks):
450
+ n = 1 + (num_res_blocks[x] + 1) * x
451
+ for i in range(num_res_blocks[x]):
452
+ for b in TEMPORAL_RESNET:
453
+ diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, b)] = "input_blocks.{}.0.{}".format(n, b)
454
+ for b in UNET_MAP_RESNET:
455
+ diffusers_unet_map["down_blocks.{}.resnets.{}.spatial_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
456
+ diffusers_unet_map["down_blocks.{}.resnets.{}.temporal_res_block.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.time_stack.{}".format(n, b)
457
+ #diffusers_unet_map["down_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "input_blocks.{}.0.{}".format(n, b)
458
+ num_transformers = transformer_depth.pop(0)
459
+ if num_transformers > 0:
460
+ for b in TEMPORAL_UNET_MAP_ATTENTIONS:
461
+ diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, b)] = "input_blocks.{}.1.{}".format(n, b)
462
+ for b in TEMPORAL_TRANSFORMER_MAP:
463
+ diffusers_unet_map["down_blocks.{}.attentions.{}.{}".format(x, i, TEMPORAL_TRANSFORMER_MAP[b])] = "input_blocks.{}.1.{}".format(n, b)
464
+ for t in range(num_transformers):
465
+ for b in TRANSFORMER_BLOCKS:
466
+ diffusers_unet_map["down_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
467
+ for b in TEMPORAL_TRANSFORMER_BLOCKS:
468
+ diffusers_unet_map["down_blocks.{}.attentions.{}.temporal_transformer_blocks.{}.{}".format(x, i, t, b)] = "input_blocks.{}.1.time_stack.{}.{}".format(n, t, b)
469
+ n += 1
470
+ for k in ["weight", "bias"]:
471
+ diffusers_unet_map["down_blocks.{}.downsamplers.0.conv.{}".format(x, k)] = "input_blocks.{}.0.op.{}".format(n, k)
472
+
473
+ i = 0
474
+ for b in TEMPORAL_UNET_MAP_ATTENTIONS:
475
+ diffusers_unet_map["mid_block.attentions.{}.{}".format(i, b)] = "middle_block.1.{}".format(b)
476
+ for b in TEMPORAL_TRANSFORMER_MAP:
477
+ diffusers_unet_map["mid_block.attentions.{}.{}".format(i, TEMPORAL_TRANSFORMER_MAP[b])] = "middle_block.1.{}".format(b)
478
+ for t in range(transformers_mid):
479
+ for b in TRANSFORMER_BLOCKS:
480
+ diffusers_unet_map["mid_block.attentions.{}.transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.transformer_blocks.{}.{}".format(t, b)
481
+ for b in TEMPORAL_TRANSFORMER_BLOCKS:
482
+ diffusers_unet_map["mid_block.attentions.{}.temporal_transformer_blocks.{}.{}".format(i, t, b)] = "middle_block.1.time_stack.{}.{}".format(t, b)
483
+
484
+ for i, n in enumerate([0, 2]):
485
+ for b in TEMPORAL_RESNET:
486
+ diffusers_unet_map["mid_block.resnets.{}.{}".format(i, b)] = "middle_block.{}.{}".format(n, b)
487
+ for b in UNET_MAP_RESNET:
488
+ diffusers_unet_map["mid_block.resnets.{}.spatial_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
489
+ diffusers_unet_map["mid_block.resnets.{}.temporal_res_block.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.time_stack.{}".format(n, b)
490
+ #diffusers_unet_map["mid_block.resnets.{}.{}".format(i, UNET_MAP_RESNET[b])] = "middle_block.{}.{}".format(n, b)
491
+
492
+ num_res_blocks = list(reversed(num_res_blocks))
493
+ for x in range(num_blocks):
494
+ n = (num_res_blocks[x] + 1) * x
495
+ l = num_res_blocks[x] + 1
496
+ for i in range(l):
497
+ c = 0
498
+ for b in UNET_MAP_RESNET:
499
+ diffusers_unet_map["up_blocks.{}.resnets.{}.{}".format(x, i, UNET_MAP_RESNET[b])] = "output_blocks.{}.0.{}".format(n, b)
500
+ c += 1
501
+ num_transformers = transformer_depth_output.pop()
502
+ if num_transformers > 0:
503
+ c += 1
504
+ for b in UNET_MAP_ATTENTIONS:
505
+ diffusers_unet_map["up_blocks.{}.attentions.{}.{}".format(x, i, b)] = "output_blocks.{}.1.{}".format(n, b)
506
+ for t in range(num_transformers):
507
+ for b in TRANSFORMER_BLOCKS:
508
+ diffusers_unet_map["up_blocks.{}.attentions.{}.transformer_blocks.{}.{}".format(x, i, t, b)] = "output_blocks.{}.1.transformer_blocks.{}.{}".format(n, t, b)
509
+ if i == l - 1:
510
+ for k in ["weight", "bias"]:
511
+ diffusers_unet_map["up_blocks.{}.upsamplers.0.conv.{}".format(x, k)] = "output_blocks.{}.{}.conv.{}".format(n, c, k)
512
+ n += 1
513
+
514
+ for k in UNET_MAP_BASIC:
515
+ diffusers_unet_map[k[1]] = k[0]
516
+
517
+ return diffusers_unet_map
ComfyUI-Advanced-ControlNet/adv_control/logger.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import copy
3
+ import logging
4
+
5
+
6
+ class ColoredFormatter(logging.Formatter):
7
+ COLORS = {
8
+ "DEBUG": "\033[0;36m", # CYAN
9
+ "INFO": "\033[0;32m", # GREEN
10
+ "WARNING": "\033[0;33m", # YELLOW
11
+ "ERROR": "\033[0;31m", # RED
12
+ "CRITICAL": "\033[0;37;41m", # WHITE ON RED
13
+ "RESET": "\033[0m", # RESET COLOR
14
+ }
15
+
16
+ def format(self, record):
17
+ colored_record = copy.copy(record)
18
+ levelname = colored_record.levelname
19
+ seq = self.COLORS.get(levelname, self.COLORS["RESET"])
20
+ colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
21
+ return super().format(colored_record)
22
+
23
+
24
+ # Create a new logger
25
+ logger = logging.getLogger("Advanced-ControlNet")
26
+ logger.propagate = False
27
+
28
+ # Add handler if we don't have one.
29
+ if not logger.handlers:
30
+ handler = logging.StreamHandler(sys.stdout)
31
+ handler.setFormatter(ColoredFormatter("[%(name)s] - %(levelname)s - %(message)s"))
32
+ logger.addHandler(handler)
33
+
34
+ # Configure logger
35
+ loglevel = logging.INFO
36
+ logger.setLevel(loglevel)
ComfyUI-Advanced-ControlNet/adv_control/nodes.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from torch import Tensor
3
+
4
+ import folder_paths
5
+ from comfy.model_patcher import ModelPatcher
6
+
7
+ from .control import load_controlnet, convert_to_advanced, is_advanced_controlnet
8
+ from .utils import ControlWeights, LatentKeyframeGroup, TimestepKeyframeGroup, BIGMAX
9
+ from .nodes_weight import (DefaultWeights, ScaledSoftMaskedUniversalWeights, ScaledSoftUniversalWeights, SoftControlNetWeights, CustomControlNetWeights,
10
+ SoftT2IAdapterWeights, CustomT2IAdapterWeights)
11
+ from .nodes_keyframes import (LatentKeyframeGroupNode, LatentKeyframeInterpolationNode, LatentKeyframeBatchedGroupNode, LatentKeyframeNode,
12
+ TimestepKeyframeNode, TimestepKeyframeInterpolationNode, TimestepKeyframeFromStrengthListNode)
13
+ from .nodes_sparsectrl import SparseCtrlMergedLoaderAdvanced, SparseCtrlLoaderAdvanced, SparseIndexMethodNode, SparseSpreadMethodNode, RgbSparseCtrlPreprocessor
14
+ from .nodes_reference import ReferenceControlNetNode, ReferenceControlFinetune, ReferencePreprocessorNode
15
+ from .nodes_loosecontrol import ControlNetLoaderWithLoraAdvanced
16
+ from .nodes_deprecated import LoadImagesFromDirectory
17
+ from .logger import logger
18
+
19
+
20
+ class ControlNetLoaderAdvanced:
21
+ @classmethod
22
+ def INPUT_TYPES(s):
23
+ return {
24
+ "required": {
25
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
26
+ },
27
+ "optional": {
28
+ "timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
29
+ }
30
+ }
31
+
32
+ RETURN_TYPES = ("CONTROL_NET", )
33
+ FUNCTION = "load_controlnet"
34
+
35
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
36
+
37
+ def load_controlnet(self, control_net_name,
38
+ timestep_keyframe: TimestepKeyframeGroup=None
39
+ ):
40
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
41
+ controlnet = load_controlnet(controlnet_path, timestep_keyframe)
42
+ return (controlnet,)
43
+
44
+
45
+ class DiffControlNetLoaderAdvanced:
46
+ @classmethod
47
+ def INPUT_TYPES(s):
48
+ return {
49
+ "required": {
50
+ "model": ("MODEL",),
51
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), )
52
+ },
53
+ "optional": {
54
+ "timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
55
+ }
56
+ }
57
+
58
+ RETURN_TYPES = ("CONTROL_NET", )
59
+ FUNCTION = "load_controlnet"
60
+
61
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
62
+
63
+ def load_controlnet(self, control_net_name, model,
64
+ timestep_keyframe: TimestepKeyframeGroup=None
65
+ ):
66
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
67
+ controlnet = load_controlnet(controlnet_path, timestep_keyframe, model)
68
+ if is_advanced_controlnet(controlnet):
69
+ controlnet.verify_all_weights()
70
+ return (controlnet,)
71
+
72
+
73
+ class AdvancedControlNetApply:
74
+ @classmethod
75
+ def INPUT_TYPES(s):
76
+ return {
77
+ "required": {
78
+ "positive": ("CONDITIONING", ),
79
+ "negative": ("CONDITIONING", ),
80
+ "control_net": ("CONTROL_NET", ),
81
+ "image": ("IMAGE", ),
82
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
83
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
84
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
85
+ },
86
+ "optional": {
87
+ "mask_optional": ("MASK", ),
88
+ "timestep_kf": ("TIMESTEP_KEYFRAME", ),
89
+ "latent_kf_override": ("LATENT_KEYFRAME", ),
90
+ "weights_override": ("CONTROL_NET_WEIGHTS", ),
91
+ "model_optional": ("MODEL",),
92
+ }
93
+ }
94
+
95
+ RETURN_TYPES = ("CONDITIONING","CONDITIONING","MODEL",)
96
+ RETURN_NAMES = ("positive", "negative", "model_opt")
97
+ FUNCTION = "apply_controlnet"
98
+
99
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝"
100
+
101
+ def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent,
102
+ mask_optional: Tensor=None, model_optional: ModelPatcher=None,
103
+ timestep_kf: TimestepKeyframeGroup=None, latent_kf_override: LatentKeyframeGroup=None,
104
+ weights_override: ControlWeights=None):
105
+ if strength == 0:
106
+ return (positive, negative, model_optional)
107
+ if model_optional:
108
+ model_optional = model_optional.clone()
109
+
110
+ control_hint = image.movedim(-1,1)
111
+ cnets = {}
112
+
113
+ out = []
114
+ for conditioning in [positive, negative]:
115
+ c = []
116
+ for t in conditioning:
117
+ d = t[1].copy()
118
+
119
+ prev_cnet = d.get('control', None)
120
+ if prev_cnet in cnets:
121
+ c_net = cnets[prev_cnet]
122
+ else:
123
+ # copy, convert to advanced if needed, and set cond
124
+ c_net = convert_to_advanced(control_net.copy()).set_cond_hint(control_hint, strength, (start_percent, end_percent))
125
+ if is_advanced_controlnet(c_net):
126
+ # disarm node check
127
+ c_net.disarm()
128
+ # if model required, verify model is passed in, and if so patch it
129
+ if c_net.require_model:
130
+ if not model_optional:
131
+ raise Exception(f"Type '{type(c_net).__name__}' requires model_optional input, but got None.")
132
+ c_net.patch_model(model=model_optional)
133
+ # apply optional parameters and overrides, if provided
134
+ if timestep_kf is not None:
135
+ c_net.set_timestep_keyframes(timestep_kf)
136
+ if latent_kf_override is not None:
137
+ c_net.latent_keyframe_override = latent_kf_override
138
+ if weights_override is not None:
139
+ c_net.weights_override = weights_override
140
+ # verify weights are compatible
141
+ c_net.verify_all_weights()
142
+ # set cond hint mask
143
+ if mask_optional is not None:
144
+ mask_optional = mask_optional.clone()
145
+ # if not in the form of a batch, make it so
146
+ if len(mask_optional.shape) < 3:
147
+ mask_optional = mask_optional.unsqueeze(0)
148
+ c_net.set_cond_hint_mask(mask_optional)
149
+ c_net.set_previous_controlnet(prev_cnet)
150
+ cnets[prev_cnet] = c_net
151
+
152
+ d['control'] = c_net
153
+ d['control_apply_to_uncond'] = False
154
+ n = [t[0], d]
155
+ c.append(n)
156
+ out.append(c)
157
+ return (out[0], out[1], model_optional)
158
+
159
+
160
+ # NODE MAPPING
161
+ NODE_CLASS_MAPPINGS = {
162
+ # Keyframes
163
+ "TimestepKeyframe": TimestepKeyframeNode,
164
+ "ACN_TimestepKeyframeInterpolation": TimestepKeyframeInterpolationNode,
165
+ "ACN_TimestepKeyframeFromStrengthList": TimestepKeyframeFromStrengthListNode,
166
+ "LatentKeyframe": LatentKeyframeNode,
167
+ "LatentKeyframeTiming": LatentKeyframeInterpolationNode,
168
+ "LatentKeyframeBatchedGroup": LatentKeyframeBatchedGroupNode,
169
+ "LatentKeyframeGroup": LatentKeyframeGroupNode,
170
+ # Conditioning
171
+ "ACN_AdvancedControlNetApply": AdvancedControlNetApply,
172
+ # Loaders
173
+ "ControlNetLoaderAdvanced": ControlNetLoaderAdvanced,
174
+ "DiffControlNetLoaderAdvanced": DiffControlNetLoaderAdvanced,
175
+ # Weights
176
+ "ScaledSoftControlNetWeights": ScaledSoftUniversalWeights,
177
+ "ScaledSoftMaskedUniversalWeights": ScaledSoftMaskedUniversalWeights,
178
+ "SoftControlNetWeights": SoftControlNetWeights,
179
+ "CustomControlNetWeights": CustomControlNetWeights,
180
+ "SoftT2IAdapterWeights": SoftT2IAdapterWeights,
181
+ "CustomT2IAdapterWeights": CustomT2IAdapterWeights,
182
+ "ACN_DefaultUniversalWeights": DefaultWeights,
183
+ # SparseCtrl
184
+ "ACN_SparseCtrlRGBPreprocessor": RgbSparseCtrlPreprocessor,
185
+ "ACN_SparseCtrlLoaderAdvanced": SparseCtrlLoaderAdvanced,
186
+ "ACN_SparseCtrlMergedLoaderAdvanced": SparseCtrlMergedLoaderAdvanced,
187
+ "ACN_SparseCtrlIndexMethodNode": SparseIndexMethodNode,
188
+ "ACN_SparseCtrlSpreadMethodNode": SparseSpreadMethodNode,
189
+ # Reference
190
+ "ACN_ReferencePreprocessor": ReferencePreprocessorNode,
191
+ "ACN_ReferenceControlNet": ReferenceControlNetNode,
192
+ "ACN_ReferenceControlNetFinetune": ReferenceControlFinetune,
193
+ # LOOSEControl
194
+ #"ACN_ControlNetLoaderWithLoraAdvanced": ControlNetLoaderWithLoraAdvanced,
195
+ # Deprecated
196
+ "LoadImagesFromDirectory": LoadImagesFromDirectory,
197
+ }
198
+
199
+ NODE_DISPLAY_NAME_MAPPINGS = {
200
+ # Keyframes
201
+ "TimestepKeyframe": "Timestep Keyframe 🛂🅐🅒🅝",
202
+ "ACN_TimestepKeyframeInterpolation": "Timestep Keyframe Interpolation 🛂🅐🅒🅝",
203
+ "ACN_TimestepKeyframeFromStrengthList": "Timestep Keyframe From List 🛂🅐🅒🅝",
204
+ "LatentKeyframe": "Latent Keyframe 🛂🅐🅒🅝",
205
+ "LatentKeyframeTiming": "Latent Keyframe Interpolation 🛂🅐🅒🅝",
206
+ "LatentKeyframeBatchedGroup": "Latent Keyframe From List 🛂🅐🅒🅝",
207
+ "LatentKeyframeGroup": "Latent Keyframe Group 🛂🅐🅒🅝",
208
+ # Conditioning
209
+ "ACN_AdvancedControlNetApply": "Apply Advanced ControlNet 🛂🅐🅒🅝",
210
+ # Loaders
211
+ "ControlNetLoaderAdvanced": "Load Advanced ControlNet Model 🛂🅐🅒🅝",
212
+ "DiffControlNetLoaderAdvanced": "Load Advanced ControlNet Model (diff) 🛂🅐🅒🅝",
213
+ # Weights
214
+ "ScaledSoftControlNetWeights": "Scaled Soft Weights 🛂🅐🅒🅝",
215
+ "ScaledSoftMaskedUniversalWeights": "Scaled Soft Masked Weights 🛂🅐🅒🅝",
216
+ "SoftControlNetWeights": "ControlNet Soft Weights 🛂🅐🅒🅝",
217
+ "CustomControlNetWeights": "ControlNet Custom Weights 🛂🅐🅒🅝",
218
+ "SoftT2IAdapterWeights": "T2IAdapter Soft Weights 🛂🅐🅒🅝",
219
+ "CustomT2IAdapterWeights": "T2IAdapter Custom Weights 🛂🅐🅒🅝",
220
+ "ACN_DefaultUniversalWeights": "Force Default Weights 🛂🅐🅒🅝",
221
+ # SparseCtrl
222
+ "ACN_SparseCtrlRGBPreprocessor": "RGB SparseCtrl 🛂🅐🅒🅝",
223
+ "ACN_SparseCtrlLoaderAdvanced": "Load SparseCtrl Model 🛂🅐🅒🅝",
224
+ "ACN_SparseCtrlMergedLoaderAdvanced": "🧪Load Merged SparseCtrl Model 🛂🅐🅒🅝",
225
+ "ACN_SparseCtrlIndexMethodNode": "SparseCtrl Index Method 🛂🅐🅒🅝",
226
+ "ACN_SparseCtrlSpreadMethodNode": "SparseCtrl Spread Method 🛂🅐🅒🅝",
227
+ # Reference
228
+ "ACN_ReferencePreprocessor": "Reference Preproccessor 🛂🅐🅒🅝",
229
+ "ACN_ReferenceControlNet": "Reference ControlNet 🛂🅐🅒🅝",
230
+ "ACN_ReferenceControlNetFinetune": "Reference ControlNet (Finetune) 🛂🅐🅒🅝",
231
+ # LOOSEControl
232
+ #"ACN_ControlNetLoaderWithLoraAdvanced": "Load Adv. ControlNet Model w/ LoRA 🛂🅐🅒🅝",
233
+ # Deprecated
234
+ "LoadImagesFromDirectory": "🚫Load Images [DEPRECATED] 🛂🅐🅒🅝",
235
+ }
ComfyUI-Advanced-ControlNet/adv_control/nodes_deprecated.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+
5
+ import numpy as np
6
+ from PIL import Image, ImageOps
7
+ from .utils import BIGMAX
8
+ from .logger import logger
9
+
10
+
11
+ class LoadImagesFromDirectory:
12
+ @classmethod
13
+ def INPUT_TYPES(s):
14
+ return {
15
+ "required": {
16
+ "directory": ("STRING", {"default": ""}),
17
+ },
18
+ "optional": {
19
+ "image_load_cap": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
20
+ "start_index": ("INT", {"default": 0, "min": 0, "max": BIGMAX, "step": 1}),
21
+ }
22
+ }
23
+
24
+ RETURN_TYPES = ("IMAGE", "MASK", "INT")
25
+ FUNCTION = "load_images"
26
+
27
+ CATEGORY = ""
28
+
29
+ def load_images(self, directory: str, image_load_cap: int = 0, start_index: int = 0):
30
+ if not os.path.isdir(directory):
31
+ raise FileNotFoundError(f"Directory '{directory} cannot be found.'")
32
+ dir_files = os.listdir(directory)
33
+ if len(dir_files) == 0:
34
+ raise FileNotFoundError(f"No files in directory '{directory}'.")
35
+
36
+ dir_files = sorted(dir_files)
37
+ dir_files = [os.path.join(directory, x) for x in dir_files]
38
+ # start at start_index
39
+ dir_files = dir_files[start_index:]
40
+
41
+ images = []
42
+ masks = []
43
+
44
+ limit_images = False
45
+ if image_load_cap > 0:
46
+ limit_images = True
47
+ image_count = 0
48
+
49
+ for image_path in dir_files:
50
+ if os.path.isdir(image_path):
51
+ continue
52
+ if limit_images and image_count >= image_load_cap:
53
+ break
54
+ i = Image.open(image_path)
55
+ i = ImageOps.exif_transpose(i)
56
+ image = i.convert("RGB")
57
+ image = np.array(image).astype(np.float32) / 255.0
58
+ image = torch.from_numpy(image)[None,]
59
+ if 'A' in i.getbands():
60
+ mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
61
+ mask = 1. - torch.from_numpy(mask)
62
+ else:
63
+ mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
64
+ images.append(image)
65
+ masks.append(mask)
66
+ image_count += 1
67
+
68
+ if len(images) == 0:
69
+ raise FileNotFoundError(f"No images could be loaded from directory '{directory}'.")
70
+
71
+ return (torch.cat(images, dim=0), torch.stack(masks, dim=0), image_count)
ComfyUI-Advanced-ControlNet/adv_control/nodes_keyframes.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ import numpy as np
3
+ from collections.abc import Iterable
4
+
5
+ from .utils import ControlWeights, TimestepKeyframe, TimestepKeyframeGroup, LatentKeyframe, LatentKeyframeGroup, BIGMIN, BIGMAX
6
+ from .utils import StrengthInterpolation as SI
7
+ from .logger import logger
8
+
9
+
10
+ class TimestepKeyframeNode:
11
+ OUTDATED_DUMMY = -39
12
+
13
+ @classmethod
14
+ def INPUT_TYPES(s):
15
+ return {
16
+ "required": {
17
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
18
+ },
19
+ "optional": {
20
+ "prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
21
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
22
+ "cn_weights": ("CONTROL_NET_WEIGHTS", ),
23
+ "latent_keyframe": ("LATENT_KEYFRAME", ),
24
+ "null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
25
+ "inherit_missing": ("BOOLEAN", {"default": True}, ),
26
+ "guarantee_steps": ("INT", {"default": 1, "min": 0, "max": BIGMAX}),
27
+ "mask_optional": ("MASK", ),
28
+ }
29
+ }
30
+
31
+ RETURN_NAMES = ("TIMESTEP_KF", )
32
+ RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
33
+ FUNCTION = "load_keyframe"
34
+
35
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
36
+
37
+ def load_keyframe(self,
38
+ start_percent: float,
39
+ strength: float=1.0,
40
+ cn_weights: ControlWeights=None, control_net_weights: ControlWeights=None, # old name
41
+ latent_keyframe: LatentKeyframeGroup=None,
42
+ prev_timestep_kf: TimestepKeyframeGroup=None, prev_timestep_keyframe: TimestepKeyframeGroup=None, # old name
43
+ null_latent_kf_strength: float=0.0,
44
+ inherit_missing=True,
45
+ guarantee_steps=OUTDATED_DUMMY,
46
+ guarantee_usage=True, # old input
47
+ mask_optional=None,):
48
+ # if using outdated dummy value, means node on workflow is outdated and should appropriately convert behavior
49
+ if guarantee_steps == self.OUTDATED_DUMMY:
50
+ guarantee_steps = int(guarantee_usage)
51
+ control_net_weights = control_net_weights if control_net_weights else cn_weights
52
+ prev_timestep_keyframe = prev_timestep_keyframe if prev_timestep_keyframe else prev_timestep_kf
53
+ if not prev_timestep_keyframe:
54
+ prev_timestep_keyframe = TimestepKeyframeGroup()
55
+ else:
56
+ prev_timestep_keyframe = prev_timestep_keyframe.clone()
57
+ keyframe = TimestepKeyframe(start_percent=start_percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
58
+ control_weights=control_net_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
59
+ guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional)
60
+ prev_timestep_keyframe.add(keyframe)
61
+ return (prev_timestep_keyframe,)
62
+
63
+
64
+ class TimestepKeyframeInterpolationNode:
65
+ @classmethod
66
+ def INPUT_TYPES(s):
67
+ return {
68
+ "required": {
69
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
70
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
71
+ "strength_start": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
72
+ "strength_end": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001},),
73
+ "interpolation": (SI._LIST, ),
74
+ "intervals": ("INT", {"default": 50, "min": 2, "max": 100, "step": 1}),
75
+ },
76
+ "optional": {
77
+ "prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
78
+ "cn_weights": ("CONTROL_NET_WEIGHTS", ),
79
+ "latent_keyframe": ("LATENT_KEYFRAME", ),
80
+ "null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
81
+ "inherit_missing": ("BOOLEAN", {"default": True},),
82
+ "mask_optional": ("MASK", ),
83
+ "print_keyframes": ("BOOLEAN", {"default": False}),
84
+ }
85
+ }
86
+
87
+ RETURN_NAMES = ("TIMESTEP_KF", )
88
+ RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
89
+ FUNCTION = "load_keyframe"
90
+
91
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
92
+
93
+ def load_keyframe(self,
94
+ start_percent: float, end_percent: float,
95
+ strength_start: float, strength_end: float, interpolation: str, intervals: int,
96
+ cn_weights: ControlWeights=None,
97
+ latent_keyframe: LatentKeyframeGroup=None,
98
+ prev_timestep_kf: TimestepKeyframeGroup=None,
99
+ null_latent_kf_strength: float=0.0,
100
+ inherit_missing=True,
101
+ guarantee_steps=1,
102
+ mask_optional=None, print_keyframes=False):
103
+ if not prev_timestep_kf:
104
+ prev_timestep_kf = TimestepKeyframeGroup()
105
+ else:
106
+ prev_timestep_kf = prev_timestep_kf.clone()
107
+
108
+ percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=intervals, method=SI.LINEAR)
109
+ strengths = SI.get_weights(num_from=strength_start, num_to=strength_end, length=intervals, method=interpolation)
110
+
111
+ is_first = True
112
+ for percent, strength in zip(percents, strengths):
113
+ guarantee_steps = 0
114
+ if is_first:
115
+ guarantee_steps = 1
116
+ is_first = False
117
+ prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
118
+ control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
119
+ guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
120
+ if print_keyframes:
121
+ logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
122
+ return (prev_timestep_kf,)
123
+
124
+
125
+ class TimestepKeyframeFromStrengthListNode:
126
+ @classmethod
127
+ def INPUT_TYPES(s):
128
+ return {
129
+ "required": {
130
+ "float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
131
+ "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001},),
132
+ "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
133
+ },
134
+ "optional": {
135
+ "prev_timestep_kf": ("TIMESTEP_KEYFRAME", ),
136
+ "cn_weights": ("CONTROL_NET_WEIGHTS", ),
137
+ "latent_keyframe": ("LATENT_KEYFRAME", ),
138
+ "null_latent_kf_strength": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 10.0, "step": 0.001},),
139
+ "inherit_missing": ("BOOLEAN", {"default": True},),
140
+ "mask_optional": ("MASK", ),
141
+ "print_keyframes": ("BOOLEAN", {"default": False}),
142
+ }
143
+ }
144
+
145
+ RETURN_NAMES = ("TIMESTEP_KF", )
146
+ RETURN_TYPES = ("TIMESTEP_KEYFRAME", )
147
+ FUNCTION = "load_keyframe"
148
+
149
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
150
+
151
+ def load_keyframe(self,
152
+ start_percent: float, end_percent: float,
153
+ float_strengths: float,
154
+ cn_weights: ControlWeights=None,
155
+ latent_keyframe: LatentKeyframeGroup=None,
156
+ prev_timestep_kf: TimestepKeyframeGroup=None,
157
+ null_latent_kf_strength: float=0.0,
158
+ inherit_missing=True,
159
+ guarantee_steps=1,
160
+ mask_optional=None, print_keyframes=False):
161
+ if not prev_timestep_kf:
162
+ prev_timestep_kf = TimestepKeyframeGroup()
163
+ else:
164
+ prev_timestep_kf = prev_timestep_kf.clone()
165
+
166
+ if type(float_strengths) in (float, int):
167
+ float_strengths = [float(float_strengths)]
168
+ elif isinstance(float_strengths, Iterable):
169
+ pass
170
+ else:
171
+ raise Exception(f"strengths_float must be either an iterable input or a float, but was {type(float_strengths).__repr__}.")
172
+ percents = SI.get_weights(num_from=start_percent, num_to=end_percent, length=len(float_strengths), method=SI.LINEAR)
173
+
174
+ is_first = True
175
+ for percent, strength in zip(percents, float_strengths):
176
+ guarantee_steps = 0
177
+ if is_first:
178
+ guarantee_steps = 1
179
+ is_first = False
180
+ prev_timestep_kf.add(TimestepKeyframe(start_percent=percent, strength=strength, null_latent_kf_strength=null_latent_kf_strength,
181
+ control_weights=cn_weights, latent_keyframes=latent_keyframe, inherit_missing=inherit_missing,
182
+ guarantee_steps=guarantee_steps, mask_hint_orig=mask_optional))
183
+ if print_keyframes:
184
+ logger.info(f"TimestepKeyframe - start_percent:{percent} = {strength}")
185
+ return (prev_timestep_kf,)
186
+
187
+
188
+ class LatentKeyframeNode:
189
+ @classmethod
190
+ def INPUT_TYPES(s):
191
+ return {
192
+ "required": {
193
+ "batch_index": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
194
+ "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
195
+ },
196
+ "optional": {
197
+ "prev_latent_kf": ("LATENT_KEYFRAME", ),
198
+ }
199
+ }
200
+
201
+ RETURN_NAMES = ("LATENT_KF", )
202
+ RETURN_TYPES = ("LATENT_KEYFRAME", )
203
+ FUNCTION = "load_keyframe"
204
+
205
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
206
+
207
+ def load_keyframe(self,
208
+ batch_index: int,
209
+ strength: float,
210
+ prev_latent_kf: LatentKeyframeGroup=None,
211
+ prev_latent_keyframe: LatentKeyframeGroup=None, # old name
212
+ ):
213
+ prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
214
+ if not prev_latent_keyframe:
215
+ prev_latent_keyframe = LatentKeyframeGroup()
216
+ else:
217
+ prev_latent_keyframe = prev_latent_keyframe.clone()
218
+ keyframe = LatentKeyframe(batch_index, strength)
219
+ prev_latent_keyframe.add(keyframe)
220
+ return (prev_latent_keyframe,)
221
+
222
+
223
+ class LatentKeyframeGroupNode:
224
+ @classmethod
225
+ def INPUT_TYPES(s):
226
+ return {
227
+ "required": {
228
+ "index_strengths": ("STRING", {"multiline": True, "default": ""}),
229
+ },
230
+ "optional": {
231
+ "prev_latent_kf": ("LATENT_KEYFRAME", ),
232
+ "latent_optional": ("LATENT", ),
233
+ "print_keyframes": ("BOOLEAN", {"default": False})
234
+ }
235
+ }
236
+
237
+ RETURN_NAMES = ("LATENT_KF", )
238
+ RETURN_TYPES = ("LATENT_KEYFRAME", )
239
+ FUNCTION = "load_keyframes"
240
+
241
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
242
+
243
+ def validate_index(self, index: int, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
244
+ # if part of range, do nothing
245
+ if is_range:
246
+ return index
247
+ # otherwise, validate index
248
+ # validate not out of range - only when latent_count is passed in
249
+ if latent_count > 0 and index > latent_count-1:
250
+ raise IndexError(f"Index '{index}' out of range for the total {latent_count} latents.")
251
+ # if negative, validate not out of range
252
+ if index < 0:
253
+ if not allow_negative:
254
+ raise IndexError(f"Negative indeces not allowed, but was {index}.")
255
+ conv_index = latent_count+index
256
+ if conv_index < 0:
257
+ raise IndexError(f"Index '{index}', converted to '{conv_index}' out of range for the total {latent_count} latents.")
258
+ index = conv_index
259
+ return index
260
+
261
+ def convert_to_index_int(self, raw_index: str, latent_count: int = 0, is_range: bool = False, allow_negative = False) -> int:
262
+ try:
263
+ return self.validate_index(int(raw_index), latent_count=latent_count, is_range=is_range, allow_negative=allow_negative)
264
+ except ValueError as e:
265
+ raise ValueError(f"index '{raw_index}' must be an integer.", e)
266
+
267
+ def convert_to_latent_keyframes(self, latent_indeces: str, latent_count: int) -> set[LatentKeyframe]:
268
+ if not latent_indeces:
269
+ return set()
270
+ int_latent_indeces = [i for i in range(0, latent_count)]
271
+ allow_negative = latent_count > 0
272
+ chosen_indeces = set()
273
+ # parse string - allow positive ints, negative ints, and ranges separated by ':'
274
+ groups = latent_indeces.split(",")
275
+ groups = [g.strip() for g in groups]
276
+ for g in groups:
277
+ # parse strengths - default to 1.0 if no strength given
278
+ strength = 1.0
279
+ if '=' in g:
280
+ g, strength_str = g.split("=", 1)
281
+ g = g.strip()
282
+ try:
283
+ strength = float(strength_str.strip())
284
+ except ValueError as e:
285
+ raise ValueError(f"strength '{strength_str}' must be a float.", e)
286
+ if strength < 0:
287
+ raise ValueError(f"Strength '{strength}' cannot be negative.")
288
+ # parse range of indeces (e.g. 2:16)
289
+ if ':' in g:
290
+ index_range = g.split(":", 1)
291
+ index_range = [r.strip() for r in index_range]
292
+ start_index = self.convert_to_index_int(index_range[0], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
293
+ end_index = self.convert_to_index_int(index_range[1], latent_count=latent_count, is_range=True, allow_negative=allow_negative)
294
+ # if latents were passed in, base indeces on known latent count
295
+ if len(int_latent_indeces) > 0:
296
+ for i in int_latent_indeces[start_index:end_index]:
297
+ chosen_indeces.add(LatentKeyframe(i, strength))
298
+ # otherwise, assume indeces are valid
299
+ else:
300
+ for i in range(start_index, end_index):
301
+ chosen_indeces.add(LatentKeyframe(i, strength))
302
+ # parse individual indeces
303
+ else:
304
+ chosen_indeces.add(LatentKeyframe(self.convert_to_index_int(g, latent_count=latent_count, allow_negative=allow_negative), strength))
305
+ return chosen_indeces
306
+
307
+ def load_keyframes(self,
308
+ index_strengths: str,
309
+ prev_latent_kf: LatentKeyframeGroup=None,
310
+ prev_latent_keyframe: LatentKeyframeGroup=None, # old name
311
+ latent_image_opt=None,
312
+ print_keyframes=False):
313
+ prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
314
+ if not prev_latent_keyframe:
315
+ prev_latent_keyframe = LatentKeyframeGroup()
316
+ else:
317
+ prev_latent_keyframe = prev_latent_keyframe.clone()
318
+ curr_latent_keyframe = LatentKeyframeGroup()
319
+
320
+ latent_count = -1
321
+ if latent_image_opt:
322
+ latent_count = latent_image_opt['samples'].size()[0]
323
+ latent_keyframes = self.convert_to_latent_keyframes(index_strengths, latent_count=latent_count)
324
+
325
+ for latent_keyframe in latent_keyframes:
326
+ curr_latent_keyframe.add(latent_keyframe)
327
+
328
+ if print_keyframes:
329
+ for keyframe in curr_latent_keyframe.keyframes:
330
+ logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
331
+
332
+ # replace values with prev_latent_keyframes
333
+ for latent_keyframe in prev_latent_keyframe.keyframes:
334
+ curr_latent_keyframe.add(latent_keyframe)
335
+
336
+ return (curr_latent_keyframe,)
337
+
338
+
339
+ class LatentKeyframeInterpolationNode:
340
+ @classmethod
341
+ def INPUT_TYPES(s):
342
+ return {
343
+ "required": {
344
+ "batch_index_from": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
345
+ "batch_index_to_excl": ("INT", {"default": 0, "min": BIGMIN, "max": BIGMAX, "step": 1}),
346
+ "strength_from": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
347
+ "strength_to": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
348
+ "interpolation": (SI._LIST, ),
349
+ },
350
+ "optional": {
351
+ "prev_latent_kf": ("LATENT_KEYFRAME", ),
352
+ "print_keyframes": ("BOOLEAN", {"default": False})
353
+ }
354
+ }
355
+
356
+ RETURN_NAMES = ("LATENT_KF", )
357
+ RETURN_TYPES = ("LATENT_KEYFRAME", )
358
+ FUNCTION = "load_keyframe"
359
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
360
+
361
+ def load_keyframe(self,
362
+ batch_index_from: int,
363
+ strength_from: float,
364
+ batch_index_to_excl: int,
365
+ strength_to: float,
366
+ interpolation: str,
367
+ prev_latent_kf: LatentKeyframeGroup=None,
368
+ prev_latent_keyframe: LatentKeyframeGroup=None, # old name
369
+ print_keyframes=False):
370
+
371
+ if (batch_index_from > batch_index_to_excl):
372
+ raise ValueError("batch_index_from must be less than or equal to batch_index_to.")
373
+
374
+ if (batch_index_from < 0 and batch_index_to_excl >= 0):
375
+ raise ValueError("batch_index_from and batch_index_to must be either both positive or both negative.")
376
+
377
+ prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
378
+ if not prev_latent_keyframe:
379
+ prev_latent_keyframe = LatentKeyframeGroup()
380
+ else:
381
+ prev_latent_keyframe = prev_latent_keyframe.clone()
382
+ curr_latent_keyframe = LatentKeyframeGroup()
383
+
384
+ steps = batch_index_to_excl - batch_index_from
385
+ diff = strength_to - strength_from
386
+ if interpolation == SI.LINEAR:
387
+ weights = np.linspace(strength_from, strength_to, steps)
388
+ elif interpolation == SI.EASE_IN:
389
+ index = np.linspace(0, 1, steps)
390
+ weights = diff * np.power(index, 2) + strength_from
391
+ elif interpolation == SI.EASE_OUT:
392
+ index = np.linspace(0, 1, steps)
393
+ weights = diff * (1 - np.power(1 - index, 2)) + strength_from
394
+ elif interpolation == SI.EASE_IN_OUT:
395
+ index = np.linspace(0, 1, steps)
396
+ weights = diff * ((1 - np.cos(index * np.pi)) / 2) + strength_from
397
+
398
+ for i in range(steps):
399
+ keyframe = LatentKeyframe(batch_index_from + i, float(weights[i]))
400
+ curr_latent_keyframe.add(keyframe)
401
+
402
+ if print_keyframes:
403
+ for keyframe in curr_latent_keyframe.keyframes:
404
+ logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
405
+
406
+ # replace values with prev_latent_keyframes
407
+ for latent_keyframe in prev_latent_keyframe.keyframes:
408
+ curr_latent_keyframe.add(latent_keyframe)
409
+
410
+ return (curr_latent_keyframe,)
411
+
412
+
413
+ class LatentKeyframeBatchedGroupNode:
414
+ @classmethod
415
+ def INPUT_TYPES(s):
416
+ return {
417
+ "required": {
418
+ "float_strengths": ("FLOAT", {"default": -1, "min": -1, "step": 0.001, "forceInput": True}),
419
+ },
420
+ "optional": {
421
+ "prev_latent_kf": ("LATENT_KEYFRAME", ),
422
+ "print_keyframes": ("BOOLEAN", {"default": False})
423
+ }
424
+ }
425
+
426
+ RETURN_NAMES = ("LATENT_KF", )
427
+ RETURN_TYPES = ("LATENT_KEYFRAME", )
428
+ FUNCTION = "load_keyframe"
429
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/keyframes"
430
+
431
+ def load_keyframe(self, float_strengths: Union[float, list[float]],
432
+ prev_latent_kf: LatentKeyframeGroup=None,
433
+ prev_latent_keyframe: LatentKeyframeGroup=None, # old name
434
+ print_keyframes=False):
435
+ prev_latent_keyframe = prev_latent_keyframe if prev_latent_keyframe else prev_latent_kf
436
+ if not prev_latent_keyframe:
437
+ prev_latent_keyframe = LatentKeyframeGroup()
438
+ else:
439
+ prev_latent_keyframe = prev_latent_keyframe.clone()
440
+ curr_latent_keyframe = LatentKeyframeGroup()
441
+
442
+ # if received a normal float input, do nothing
443
+ if type(float_strengths) in (float, int):
444
+ logger.info("No batched float_strengths passed into Latent Keyframe Batch Group node; will not create any new keyframes.")
445
+ # if iterable, attempt to create LatentKeyframes with chosen strengths
446
+ elif isinstance(float_strengths, Iterable):
447
+ for idx, strength in enumerate(float_strengths):
448
+ keyframe = LatentKeyframe(idx, strength)
449
+ curr_latent_keyframe.add(keyframe)
450
+ else:
451
+ raise ValueError(f"Expected strengths to be an iterable input, but was {type(float_strengths).__repr__}.")
452
+
453
+ if print_keyframes:
454
+ for keyframe in curr_latent_keyframe.keyframes:
455
+ logger.info(f"LatentKeyframe {keyframe.batch_index}={keyframe.strength}")
456
+
457
+ # replace values with prev_latent_keyframes
458
+ for latent_keyframe in prev_latent_keyframe.keyframes:
459
+ curr_latent_keyframe.add(latent_keyframe)
460
+
461
+ return (curr_latent_keyframe,)
ComfyUI-Advanced-ControlNet/adv_control/nodes_loosecontrol.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import folder_paths
2
+ import comfy.utils
3
+ import comfy.model_detection
4
+ import comfy.model_management
5
+ import comfy.lora
6
+ from comfy.model_patcher import ModelPatcher
7
+
8
+ from .utils import TimestepKeyframeGroup
9
+ from .control import ControlNetAdvanced, load_controlnet
10
+
11
+
12
+
13
+
14
+ def convert_cn_lora_from_diffusers(cn_model: ModelPatcher, lora_path: str):
15
+ lora_data = comfy.utils.load_torch_file(lora_path, safe_load=True)
16
+ unet_dtype = comfy.model_management.unet_dtype()
17
+ for key, value in lora_data.items():
18
+ lora_data[key] = value.to(unet_dtype)
19
+ diffusers_keys = comfy.utils.unet_to_diffusers(cn_model.model.state_dict())
20
+
21
+ #lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, dtype=unet_dtype)
22
+
23
+
24
+
25
+ #key_map = comfy.lora.model_lora_keys_unet(cn_model.model, key_map)
26
+ lora_data = comfy.lora.load_lora(lora_data, to_load=diffusers_keys)
27
+
28
+ # TODO: detect if diffusers for sure? not sure if needed at this time, since cn loras are
29
+ # only used currently for LOOSEControl, and those are all in diffusers format
30
+ #unet_dtype = comfy.model_management.unet_dtype()
31
+ #lora_data = comfy.model_detection.unet_config_from_diffusers_unet(lora_data, unet_dtype)
32
+ return lora_data
33
+
34
+
35
+ class ControlNetLoaderWithLoraAdvanced:
36
+ @classmethod
37
+ def INPUT_TYPES(s):
38
+ return {
39
+ "required": {
40
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
41
+ "cn_lora_name": (folder_paths.get_filename_list("controlnet"), ),
42
+ "cn_lora_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}),
43
+ },
44
+ "optional": {
45
+ "timestep_keyframe": ("TIMESTEP_KEYFRAME", ),
46
+ }
47
+ }
48
+
49
+ RETURN_TYPES = ("CONTROL_NET", )
50
+ FUNCTION = "load_controlnet"
51
+
52
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/LOOSEControl"
53
+
54
+ def load_controlnet(self, control_net_name, cn_lora_name, cn_lora_strength: float,
55
+ timestep_keyframe: TimestepKeyframeGroup=None
56
+ ):
57
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
58
+ controlnet: ControlNetAdvanced = load_controlnet(controlnet_path, timestep_keyframe)
59
+ if not isinstance(controlnet, ControlNetAdvanced):
60
+ raise ValueError("Type {} is not compatible with CN LoRA features at this time.")
61
+ # now, try to load CN LoRA
62
+ lora_path = folder_paths.get_full_path("controlnet", cn_lora_name)
63
+ lora_data = convert_cn_lora_from_diffusers(cn_model=controlnet.control_model_wrapped, lora_path=lora_path)
64
+ # apply patches to wrapped control_model
65
+ controlnet.control_model_wrapped.add_patches(lora_data, strength_patch=cn_lora_strength)
66
+ # all done
67
+ return (controlnet,)
ComfyUI-Advanced-ControlNet/adv_control/nodes_reference.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+
3
+ from nodes import VAEEncode
4
+ import comfy.utils
5
+ from comfy.sd import VAE
6
+
7
+ from .control_reference import ReferenceAdvanced, ReferenceOptions, ReferenceType, ReferencePreprocWrapper
8
+
9
+
10
+ # node for ReferenceCN
11
+ class ReferenceControlNetNode:
12
+ @classmethod
13
+ def INPUT_TYPES(s):
14
+ return {
15
+ "required": {
16
+ "reference_type": (ReferenceType._LIST,),
17
+ "style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
18
+ "ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
19
+ },
20
+ }
21
+
22
+ RETURN_TYPES = ("CONTROL_NET", )
23
+ FUNCTION = "load_controlnet"
24
+
25
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
26
+
27
+ def load_controlnet(self, reference_type: str, style_fidelity: float, ref_weight: float):
28
+ ref_opts = ReferenceOptions.create_combo(reference_type=reference_type, style_fidelity=style_fidelity, ref_weight=ref_weight)
29
+ controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
30
+ return (controlnet,)
31
+
32
+
33
+ class ReferenceControlFinetune:
34
+ @classmethod
35
+ def INPUT_TYPES(s):
36
+ return {
37
+ "required": {
38
+ "attn_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
39
+ "attn_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
40
+ "attn_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
41
+ "adain_style_fidelity": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
42
+ "adain_ref_weight": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
43
+ "adain_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
44
+ },
45
+ }
46
+
47
+ RETURN_TYPES = ("CONTROL_NET", )
48
+ FUNCTION = "load_controlnet"
49
+
50
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference"
51
+
52
+ def load_controlnet(self,
53
+ attn_style_fidelity: float, attn_ref_weight: float, attn_strength: float,
54
+ adain_style_fidelity: float, adain_ref_weight: float, adain_strength: float):
55
+ ref_opts = ReferenceOptions(reference_type=ReferenceType.ATTN_ADAIN,
56
+ attn_style_fidelity=attn_style_fidelity, attn_ref_weight=attn_ref_weight, attn_strength=attn_strength,
57
+ adain_style_fidelity=adain_style_fidelity, adain_ref_weight=adain_ref_weight, adain_strength=adain_strength)
58
+ controlnet = ReferenceAdvanced(ref_opts=ref_opts, timestep_keyframes=None)
59
+ return (controlnet,)
60
+
61
+
62
+ class ReferencePreprocessorNode:
63
+ @classmethod
64
+ def INPUT_TYPES(s):
65
+ return {
66
+ "required": {
67
+ "image": ("IMAGE", ),
68
+ "vae": ("VAE", ),
69
+ "latent_size": ("LATENT", ),
70
+ }
71
+ }
72
+
73
+ RETURN_TYPES = ("IMAGE",)
74
+ RETURN_NAMES = ("proc_IMAGE",)
75
+ FUNCTION = "preprocess_images"
76
+
77
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/Reference/preprocess"
78
+
79
+ def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
80
+ # first, resize image to match latents
81
+ image = image.movedim(-1,1)
82
+ image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
83
+ image = image.movedim(1,-1)
84
+ # then, vae encode
85
+ try:
86
+ image = vae.vae_encode_crop_pixels(image)
87
+ except Exception:
88
+ image = VAEEncode.vae_encode_crop_pixels(image)
89
+ encoded = vae.encode(image[:,:,:,:3])
90
+ return (ReferencePreprocWrapper(condhint=encoded),)
ComfyUI-Advanced-ControlNet/adv_control/nodes_sparsectrl.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+
3
+ import folder_paths
4
+ from nodes import VAEEncode
5
+ import comfy.utils
6
+ from comfy.sd import VAE
7
+
8
+ from .utils import TimestepKeyframeGroup
9
+ from .control_sparsectrl import SparseMethod, SparseIndexMethod, SparseSettings, SparseSpreadMethod, PreprocSparseRGBWrapper
10
+ from .control import load_sparsectrl, load_controlnet, ControlNetAdvanced, SparseCtrlAdvanced
11
+
12
+
13
+ # node for SparseCtrl loading
14
+ class SparseCtrlLoaderAdvanced:
15
+ @classmethod
16
+ def INPUT_TYPES(s):
17
+ return {
18
+ "required": {
19
+ "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
20
+ "use_motion": ("BOOLEAN", {"default": True}, ),
21
+ "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
22
+ "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
23
+ },
24
+ "optional": {
25
+ "sparse_method": ("SPARSE_METHOD", ),
26
+ "tk_optional": ("TIMESTEP_KEYFRAME", ),
27
+ }
28
+ }
29
+
30
+ RETURN_TYPES = ("CONTROL_NET", )
31
+ FUNCTION = "load_controlnet"
32
+
33
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
34
+
35
+ def load_controlnet(self, sparsectrl_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
36
+ sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
37
+ sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale)
38
+ sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
39
+ return (sparsectrl,)
40
+
41
+
42
+ class SparseCtrlMergedLoaderAdvanced:
43
+ @classmethod
44
+ def INPUT_TYPES(s):
45
+ return {
46
+ "required": {
47
+ "sparsectrl_name": (folder_paths.get_filename_list("controlnet"), ),
48
+ "control_net_name": (folder_paths.get_filename_list("controlnet"), ),
49
+ "use_motion": ("BOOLEAN", {"default": True}, ),
50
+ "motion_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
51
+ "motion_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
52
+ },
53
+ "optional": {
54
+ "sparse_method": ("SPARSE_METHOD", ),
55
+ "tk_optional": ("TIMESTEP_KEYFRAME", ),
56
+ }
57
+ }
58
+
59
+ RETURN_TYPES = ("CONTROL_NET", )
60
+ FUNCTION = "load_controlnet"
61
+
62
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/experimental"
63
+
64
+ def load_controlnet(self, sparsectrl_name: str, control_net_name: str, use_motion: bool, motion_strength: float, motion_scale: float, sparse_method: SparseMethod=SparseSpreadMethod(), tk_optional: TimestepKeyframeGroup=None):
65
+ sparsectrl_path = folder_paths.get_full_path("controlnet", sparsectrl_name)
66
+ controlnet_path = folder_paths.get_full_path("controlnet", control_net_name)
67
+ sparse_settings = SparseSettings(sparse_method=sparse_method, use_motion=use_motion, motion_strength=motion_strength, motion_scale=motion_scale, merged=True)
68
+ # first, load normal controlnet
69
+ controlnet = load_controlnet(controlnet_path, timestep_keyframe=tk_optional)
70
+ # confirm that controlnet is ControlNetAdvanced
71
+ if controlnet is None or type(controlnet) != ControlNetAdvanced:
72
+ raise ValueError(f"controlnet_path must point to a normal ControlNet, but instead: {type(controlnet).__name__}")
73
+ # next, load sparsectrl, making sure to load motion portion
74
+ sparsectrl = load_sparsectrl(sparsectrl_path, timestep_keyframe=tk_optional, sparse_settings=SparseSettings.default())
75
+ # now, combine state dicts
76
+ new_state_dict = controlnet.control_model.state_dict()
77
+ for key, value in sparsectrl.control_model.motion_holder.motion_wrapper.state_dict().items():
78
+ new_state_dict[key] = value
79
+ # now, reload sparsectrl with real settings
80
+ sparsectrl = load_sparsectrl(sparsectrl_path, controlnet_data=new_state_dict, timestep_keyframe=tk_optional, sparse_settings=sparse_settings)
81
+ return (sparsectrl,)
82
+
83
+
84
+ class SparseIndexMethodNode:
85
+ @classmethod
86
+ def INPUT_TYPES(s):
87
+ return {
88
+ "required": {
89
+ "indexes": ("STRING", {"default": "0"}),
90
+ }
91
+ }
92
+
93
+ RETURN_TYPES = ("SPARSE_METHOD",)
94
+ FUNCTION = "get_method"
95
+
96
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
97
+
98
+ def get_method(self, indexes: str):
99
+ idxs = []
100
+ unique_idxs = set()
101
+ # get indeces from string
102
+ str_idxs = [x.strip() for x in indexes.strip().split(",")]
103
+ for str_idx in str_idxs:
104
+ try:
105
+ idx = int(str_idx)
106
+ if idx in unique_idxs:
107
+ raise ValueError(f"'{idx}' is duplicated; indexes must be unique.")
108
+ idxs.append(idx)
109
+ unique_idxs.add(idx)
110
+ except ValueError:
111
+ raise ValueError(f"'{str_idx}' is not a valid integer index.")
112
+ if len(idxs) == 0:
113
+ raise ValueError(f"No indexes were listed in Sparse Index Method.")
114
+ return (SparseIndexMethod(idxs),)
115
+
116
+
117
+ class SparseSpreadMethodNode:
118
+ @classmethod
119
+ def INPUT_TYPES(s):
120
+ return {
121
+ "required": {
122
+ "spread": (SparseSpreadMethod.LIST,),
123
+ }
124
+ }
125
+
126
+ RETURN_TYPES = ("SPARSE_METHOD",)
127
+ FUNCTION = "get_method"
128
+
129
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl"
130
+
131
+ def get_method(self, spread: str):
132
+ return (SparseSpreadMethod(spread=spread),)
133
+
134
+
135
+ class RgbSparseCtrlPreprocessor:
136
+ @classmethod
137
+ def INPUT_TYPES(s):
138
+ return {
139
+ "required": {
140
+ "image": ("IMAGE", ),
141
+ "vae": ("VAE", ),
142
+ "latent_size": ("LATENT", ),
143
+ }
144
+ }
145
+
146
+ RETURN_TYPES = ("IMAGE",)
147
+ RETURN_NAMES = ("proc_IMAGE",)
148
+ FUNCTION = "preprocess_images"
149
+
150
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/SparseCtrl/preprocess"
151
+
152
+ def preprocess_images(self, vae: VAE, image: Tensor, latent_size: Tensor):
153
+ # first, resize image to match latents
154
+ image = image.movedim(-1,1)
155
+ image = comfy.utils.common_upscale(image, latent_size["samples"].shape[3] * 8, latent_size["samples"].shape[2] * 8, 'nearest-exact', "center")
156
+ image = image.movedim(1,-1)
157
+ # then, vae encode
158
+ try:
159
+ image = vae.vae_encode_crop_pixels(image)
160
+ except Exception:
161
+ image = VAEEncode.vae_encode_crop_pixels(image)
162
+ encoded = vae.encode(image[:,:,:,:3])
163
+ return (PreprocSparseRGBWrapper(condhint=encoded),)
ComfyUI-Advanced-ControlNet/adv_control/nodes_weight.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import Tensor
2
+ import torch
3
+ from .utils import TimestepKeyframe, TimestepKeyframeGroup, ControlWeights, get_properly_arranged_t2i_weights, linear_conversion
4
+ from .logger import logger
5
+
6
+
7
+ WEIGHTS_RETURN_NAMES = ("CN_WEIGHTS", "TK_SHORTCUT")
8
+
9
+
10
+ class DefaultWeights:
11
+ @classmethod
12
+ def INPUT_TYPES(s):
13
+ return {
14
+ }
15
+
16
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
17
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
18
+ FUNCTION = "load_weights"
19
+
20
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
21
+
22
+ def load_weights(self):
23
+ weights = ControlWeights.default()
24
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
25
+
26
+
27
+ class ScaledSoftMaskedUniversalWeights:
28
+ @classmethod
29
+ def INPUT_TYPES(s):
30
+ return {
31
+ "required": {
32
+ "mask": ("MASK", ),
33
+ "min_base_multiplier": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
34
+ "max_base_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}, ),
35
+ #"lock_min": ("BOOLEAN", {"default": False}, ),
36
+ #"lock_max": ("BOOLEAN", {"default": False}, ),
37
+ },
38
+ "optional": {
39
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
40
+ }
41
+ }
42
+
43
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
44
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
45
+ FUNCTION = "load_weights"
46
+
47
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
48
+
49
+ def load_weights(self, mask: Tensor, min_base_multiplier: float, max_base_multiplier: float, lock_min=False, lock_max=False,
50
+ uncond_multiplier: float=1.0):
51
+ # normalize mask
52
+ mask = mask.clone()
53
+ x_min = 0.0 if lock_min else mask.min()
54
+ x_max = 1.0 if lock_max else mask.max()
55
+ if x_min == x_max:
56
+ mask = torch.ones_like(mask) * max_base_multiplier
57
+ else:
58
+ mask = linear_conversion(mask, x_min, x_max, min_base_multiplier, max_base_multiplier)
59
+ weights = ControlWeights.universal_mask(weight_mask=mask, uncond_multiplier=uncond_multiplier)
60
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
61
+
62
+
63
+ class ScaledSoftUniversalWeights:
64
+ @classmethod
65
+ def INPUT_TYPES(s):
66
+ return {
67
+ "required": {
68
+ "base_multiplier": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 1.0, "step": 0.001}, ),
69
+ "flip_weights": ("BOOLEAN", {"default": False}),
70
+ },
71
+ "optional": {
72
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
73
+ }
74
+ }
75
+
76
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
77
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
78
+ FUNCTION = "load_weights"
79
+
80
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights"
81
+
82
+ def load_weights(self, base_multiplier, flip_weights, uncond_multiplier: float=1.0):
83
+ weights = ControlWeights.universal(base_multiplier=base_multiplier, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
84
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
85
+
86
+
87
+ class SoftControlNetWeights:
88
+ @classmethod
89
+ def INPUT_TYPES(s):
90
+ return {
91
+ "required": {
92
+ "weight_00": ("FLOAT", {"default": 0.09941396206337118, "min": 0.0, "max": 10.0, "step": 0.001}, ),
93
+ "weight_01": ("FLOAT", {"default": 0.12050177219802567, "min": 0.0, "max": 10.0, "step": 0.001}, ),
94
+ "weight_02": ("FLOAT", {"default": 0.14606275417942507, "min": 0.0, "max": 10.0, "step": 0.001}, ),
95
+ "weight_03": ("FLOAT", {"default": 0.17704576264172736, "min": 0.0, "max": 10.0, "step": 0.001}, ),
96
+ "weight_04": ("FLOAT", {"default": 0.214600924414215, "min": 0.0, "max": 10.0, "step": 0.001}, ),
97
+ "weight_05": ("FLOAT", {"default": 0.26012233262329093, "min": 0.0, "max": 10.0, "step": 0.001}, ),
98
+ "weight_06": ("FLOAT", {"default": 0.3152997971191405, "min": 0.0, "max": 10.0, "step": 0.001}, ),
99
+ "weight_07": ("FLOAT", {"default": 0.3821815722656249, "min": 0.0, "max": 10.0, "step": 0.001}, ),
100
+ "weight_08": ("FLOAT", {"default": 0.4632503906249999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
101
+ "weight_09": ("FLOAT", {"default": 0.561515625, "min": 0.0, "max": 10.0, "step": 0.001}, ),
102
+ "weight_10": ("FLOAT", {"default": 0.6806249999999999, "min": 0.0, "max": 10.0, "step": 0.001}, ),
103
+ "weight_11": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
104
+ "weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
105
+ "flip_weights": ("BOOLEAN", {"default": False}),
106
+ },
107
+ "optional": {
108
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
109
+ }
110
+ }
111
+
112
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
113
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
114
+ FUNCTION = "load_weights"
115
+
116
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
117
+
118
+ def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
119
+ weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
120
+ uncond_multiplier: float=1.0):
121
+ weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
122
+ weight_07, weight_08, weight_09, weight_10, weight_11, weight_12]
123
+ weights = ControlWeights.controlnet(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
124
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
125
+
126
+
127
+ class CustomControlNetWeights:
128
+ @classmethod
129
+ def INPUT_TYPES(s):
130
+ return {
131
+ "required": {
132
+ "weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
133
+ "weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
134
+ "weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
135
+ "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
136
+ "weight_04": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
137
+ "weight_05": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
138
+ "weight_06": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
139
+ "weight_07": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
140
+ "weight_08": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
141
+ "weight_09": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
142
+ "weight_10": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
143
+ "weight_11": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
144
+ "weight_12": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
145
+ "flip_weights": ("BOOLEAN", {"default": False}),
146
+ },
147
+ "optional": {
148
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
149
+ }
150
+ }
151
+
152
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
153
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
154
+ FUNCTION = "load_weights"
155
+
156
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/ControlNet"
157
+
158
+ def load_weights(self, weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
159
+ weight_07, weight_08, weight_09, weight_10, weight_11, weight_12, flip_weights,
160
+ uncond_multiplier: float=1.0):
161
+ weights = [weight_00, weight_01, weight_02, weight_03, weight_04, weight_05, weight_06,
162
+ weight_07, weight_08, weight_09, weight_10, weight_11, weight_12]
163
+ weights = ControlWeights.controlnet(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
164
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
165
+
166
+
167
+ class SoftT2IAdapterWeights:
168
+ @classmethod
169
+ def INPUT_TYPES(s):
170
+ return {
171
+ "required": {
172
+ "weight_00": ("FLOAT", {"default": 0.25, "min": 0.0, "max": 10.0, "step": 0.001}, ),
173
+ "weight_01": ("FLOAT", {"default": 0.62, "min": 0.0, "max": 10.0, "step": 0.001}, ),
174
+ "weight_02": ("FLOAT", {"default": 0.825, "min": 0.0, "max": 10.0, "step": 0.001}, ),
175
+ "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
176
+ "flip_weights": ("BOOLEAN", {"default": False}),
177
+ },
178
+ "optional": {
179
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
180
+ }
181
+ }
182
+
183
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
184
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
185
+ FUNCTION = "load_weights"
186
+
187
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
188
+
189
+ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
190
+ uncond_multiplier: float=1.0):
191
+ weights = [weight_00, weight_01, weight_02, weight_03]
192
+ weights = get_properly_arranged_t2i_weights(weights)
193
+ weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
194
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
195
+
196
+
197
+ class CustomT2IAdapterWeights:
198
+ @classmethod
199
+ def INPUT_TYPES(s):
200
+ return {
201
+ "required": {
202
+ "weight_00": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
203
+ "weight_01": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
204
+ "weight_02": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
205
+ "weight_03": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}, ),
206
+ "flip_weights": ("BOOLEAN", {"default": False}),
207
+ },
208
+ "optional": {
209
+ "uncond_multiplier": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}, ),
210
+ }
211
+ }
212
+
213
+ RETURN_TYPES = ("CONTROL_NET_WEIGHTS", "TIMESTEP_KEYFRAME",)
214
+ RETURN_NAMES = WEIGHTS_RETURN_NAMES
215
+ FUNCTION = "load_weights"
216
+
217
+ CATEGORY = "Adv-ControlNet 🛂🅐🅒🅝/weights/T2IAdapter"
218
+
219
+ def load_weights(self, weight_00, weight_01, weight_02, weight_03, flip_weights,
220
+ uncond_multiplier: float=1.0):
221
+ weights = [weight_00, weight_01, weight_02, weight_03]
222
+ weights = get_properly_arranged_t2i_weights(weights)
223
+ weights = ControlWeights.t2iadapter(weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
224
+ return (weights, TimestepKeyframeGroup.default(TimestepKeyframe(control_weights=weights)))
ComfyUI-Advanced-ControlNet/adv_control/utils.py ADDED
@@ -0,0 +1,915 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from copy import deepcopy
2
+ from typing import Callable, Union
3
+ import torch
4
+ from torch import Tensor
5
+ import torch.nn.functional
6
+ import numpy as np
7
+ import math
8
+
9
+ import comfy.ops
10
+ import comfy.utils
11
+ import comfy.sample
12
+ import comfy.samplers
13
+ import comfy.model_base
14
+
15
+ from comfy.controlnet import ControlBase, broadcast_image_to
16
+ from comfy.model_patcher import ModelPatcher
17
+
18
+ from .logger import logger
19
+
20
+ BIGMIN = -(2**53-1)
21
+ BIGMAX = (2**53-1)
22
+
23
+ def load_torch_file_with_dict_factory(controlnet_data: dict[str, Tensor], orig_load_torch_file: Callable):
24
+ def load_torch_file_with_dict(*args, **kwargs):
25
+ # immediately restore load_torch_file to original version
26
+ comfy.utils.load_torch_file = orig_load_torch_file
27
+ return controlnet_data
28
+ return load_torch_file_with_dict
29
+
30
+ # wrapping len function so that it will save the thing len is trying to get the length of;
31
+ # this will be assumed to be the cond_or_uncond variable;
32
+ # automatically restores len to original function after running
33
+ def wrapper_len_factory(orig_len: Callable) -> Callable:
34
+ def wrapper_len(*args, **kwargs):
35
+ cond_or_uncond = args[0]
36
+ real_length = orig_len(*args, **kwargs)
37
+ if real_length > 0 and type(cond_or_uncond) == list and (cond_or_uncond[0] in [0, 1]):
38
+ try:
39
+ to_return = IntWithCondOrUncond(real_length)
40
+ setattr(to_return, "cond_or_uncond", cond_or_uncond)
41
+ return to_return
42
+ finally:
43
+ __builtins__["len"] = orig_len
44
+ else:
45
+ return real_length
46
+ return wrapper_len
47
+
48
+ # wrapping cond_cat function so that it will wrap around len function to get cond_or_uncond variable value
49
+ # from comfy.samplers.calc_conds_batch
50
+ def wrapper_cond_cat_factory(orig_cond_cat: Callable):
51
+ def wrapper_cond_cat(*args, **kwargs):
52
+ __builtins__["len"] = wrapper_len_factory(__builtins__["len"])
53
+ return orig_cond_cat(*args, **kwargs)
54
+ return wrapper_cond_cat
55
+ orig_cond_cat = comfy.samplers.cond_cat
56
+ comfy.samplers.cond_cat = wrapper_cond_cat_factory(orig_cond_cat)
57
+
58
+
59
+ # wrapping apply_model so that len function will be cleaned up fairly soon after being injected
60
+ def apply_model_uncond_cleanup_factory(orig_apply_model, orig_len):
61
+ def apply_model_uncond_cleanup_wrapper(self, *args, **kwargs):
62
+ __builtins__["len"] = orig_len
63
+ return orig_apply_model(self, *args, **kwargs)
64
+ return apply_model_uncond_cleanup_wrapper
65
+ global_orig_len = __builtins__["len"]
66
+ orig_apply_model = comfy.model_base.BaseModel.apply_model
67
+ comfy.model_base.BaseModel.apply_model = apply_model_uncond_cleanup_factory(orig_apply_model, global_orig_len)
68
+
69
+
70
+ def uncond_multiplier_check_cn_sample_factory(orig_comfy_sample: Callable, is_custom=False) -> Callable:
71
+ def contains_uncond_multiplier(control: Union[ControlBase, 'AdvancedControlBase']):
72
+ if control is None:
73
+ return False
74
+ if not isinstance(control, AdvancedControlBase):
75
+ return contains_uncond_multiplier(control.previous_controlnet)
76
+ # check if weights_override has an uncond_multiplier
77
+ if control.weights_override is not None and control.weights_override.has_uncond_multiplier:
78
+ return True
79
+ # check if any timestep_keyframes have an uncond_multiplier on their weights
80
+ if control.timestep_keyframes is not None:
81
+ for tk in control.timestep_keyframes.keyframes:
82
+ if tk.has_control_weights() and tk.control_weights.has_uncond_multiplier:
83
+ return True
84
+ return contains_uncond_multiplier(control.previous_controlnet)
85
+
86
+ # check if positive or negative conds contain Adv. Cns that use multiply_negative on weights
87
+ def uncond_multiplier_check_cn_sample(model: ModelPatcher, *args, **kwargs):
88
+ positive = args[-3]
89
+ negative = args[-2]
90
+ has_uncond_multiplier = False
91
+ if positive is not None:
92
+ for cond in positive:
93
+ if "control" in cond[1]:
94
+ has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
95
+ if has_uncond_multiplier:
96
+ break
97
+ if negative is not None and not has_uncond_multiplier:
98
+ for cond in negative:
99
+ if "control" in cond[1]:
100
+ has_uncond_multiplier = contains_uncond_multiplier(cond[1]["control"])
101
+ if has_uncond_multiplier:
102
+ break
103
+ try:
104
+ # if uncond_multiplier found, continue to use wrapped version of function
105
+ if has_uncond_multiplier:
106
+ return orig_comfy_sample(model, *args, **kwargs)
107
+ # otherwise, use original version of function to prevent even the smallest of slowdowns (0.XX%)
108
+ try:
109
+ wrapped_cond_cat = comfy.samplers.cond_cat
110
+ comfy.samplers.cond_cat = orig_cond_cat
111
+ return orig_comfy_sample(model, *args, **kwargs)
112
+ finally:
113
+ comfy.samplers.cond_cat = wrapped_cond_cat
114
+ finally:
115
+ # make sure len function is unwrapped by the time sampling is done, just in case
116
+ __builtins__["len"] = global_orig_len
117
+ return uncond_multiplier_check_cn_sample
118
+ # inject sample functions
119
+ comfy.sample.sample = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample)
120
+ comfy.sample.sample_custom = uncond_multiplier_check_cn_sample_factory(comfy.sample.sample_custom, is_custom=True)
121
+
122
+
123
+ class IntWithCondOrUncond(int):
124
+ def __new__(cls, *args, **kwargs):
125
+ return super(IntWithCondOrUncond, cls).__new__(cls, *args, **kwargs)
126
+
127
+ def __init__(self, *args, **kwargs):
128
+ super().__init__()
129
+ self.cond_or_uncond = None
130
+
131
+
132
+
133
+ def get_properly_arranged_t2i_weights(initial_weights: list[float]):
134
+ new_weights = []
135
+ new_weights.extend([initial_weights[0]]*3)
136
+ new_weights.extend([initial_weights[1]]*3)
137
+ new_weights.extend([initial_weights[2]]*3)
138
+ new_weights.extend([initial_weights[3]]*3)
139
+ return new_weights
140
+
141
+
142
+ class ControlWeightType:
143
+ DEFAULT = "default"
144
+ UNIVERSAL = "universal"
145
+ T2IADAPTER = "t2iadapter"
146
+ CONTROLNET = "controlnet"
147
+ CONTROLLORA = "controllora"
148
+ CONTROLLLLITE = "controllllite"
149
+ SVD_CONTROLNET = "svd_controlnet"
150
+ SPARSECTRL = "sparsectrl"
151
+
152
+
153
+ class ControlWeights:
154
+ def __init__(self, weight_type: str, base_multiplier: float=1.0, flip_weights: bool=False, weights: list[float]=None, weight_mask: Tensor=None,
155
+ uncond_multiplier=1.0):
156
+ self.weight_type = weight_type
157
+ self.base_multiplier = base_multiplier
158
+ self.flip_weights = flip_weights
159
+ self.weights = weights
160
+ if self.weights is not None and self.flip_weights:
161
+ self.weights.reverse()
162
+ self.weight_mask = weight_mask
163
+ self.uncond_multiplier = float(uncond_multiplier)
164
+ self.has_uncond_multiplier = not math.isclose(self.uncond_multiplier, 1.0)
165
+
166
+ def get(self, idx: int, default=1.0) -> Union[float, Tensor]:
167
+ # if weights is not none, return index
168
+ if self.weights is not None:
169
+ # this implies weights list is not aligning with expectations - will need to adjust code
170
+ if idx >= len(self.weights):
171
+ return default
172
+ return self.weights[idx]
173
+ return 1.0
174
+
175
+ def copy_with_new_weights(self, new_weights: list[float]):
176
+ return ControlWeights(weight_type=self.weight_type, base_multiplier=self.base_multiplier, flip_weights=self.flip_weights,
177
+ weights=new_weights, weight_mask=self.weight_mask, uncond_multiplier=self.uncond_multiplier)
178
+
179
+ @classmethod
180
+ def default(cls):
181
+ return cls(ControlWeightType.DEFAULT)
182
+
183
+ @classmethod
184
+ def universal(cls, base_multiplier: float, flip_weights: bool=False, uncond_multiplier: float=1.0):
185
+ return cls(ControlWeightType.UNIVERSAL, base_multiplier=base_multiplier, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
186
+
187
+ @classmethod
188
+ def universal_mask(cls, weight_mask: Tensor, uncond_multiplier: float=1.0):
189
+ return cls(ControlWeightType.UNIVERSAL, weight_mask=weight_mask, uncond_multiplier=uncond_multiplier)
190
+
191
+ @classmethod
192
+ def t2iadapter(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
193
+ if weights is None:
194
+ weights = [1.0]*12
195
+ return cls(ControlWeightType.T2IADAPTER, weights=weights,flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
196
+
197
+ @classmethod
198
+ def controlnet(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
199
+ if weights is None:
200
+ weights = [1.0]*13
201
+ return cls(ControlWeightType.CONTROLNET, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
202
+
203
+ @classmethod
204
+ def controllora(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
205
+ if weights is None:
206
+ weights = [1.0]*10
207
+ return cls(ControlWeightType.CONTROLLORA, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
208
+
209
+ @classmethod
210
+ def controllllite(cls, weights: list[float]=None, flip_weights: bool=False, uncond_multiplier: float=1.0):
211
+ if weights is None:
212
+ # TODO: make this have a real value
213
+ weights = [1.0]*200
214
+ return cls(ControlWeightType.CONTROLLLLITE, weights=weights, flip_weights=flip_weights, uncond_multiplier=uncond_multiplier)
215
+
216
+
217
+ class StrengthInterpolation:
218
+ LINEAR = "linear"
219
+ EASE_IN = "ease-in"
220
+ EASE_OUT = "ease-out"
221
+ EASE_IN_OUT = "ease-in-out"
222
+ NONE = "none"
223
+
224
+ _LIST = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT]
225
+ _LIST_WITH_NONE = [LINEAR, EASE_IN, EASE_OUT, EASE_IN_OUT, NONE]
226
+
227
+ @classmethod
228
+ def get_weights(cls, num_from: float, num_to: float, length: int, method: str, reverse=False):
229
+ diff = num_to - num_from
230
+ if method == cls.LINEAR:
231
+ weights = torch.linspace(num_from, num_to, length)
232
+ elif method == cls.EASE_IN:
233
+ index = torch.linspace(0, 1, length)
234
+ weights = diff * np.power(index, 2) + num_from
235
+ elif method == cls.EASE_OUT:
236
+ index = torch.linspace(0, 1, length)
237
+ weights = diff * (1 - np.power(1 - index, 2)) + num_from
238
+ elif method == cls.EASE_IN_OUT:
239
+ index = torch.linspace(0, 1, length)
240
+ weights = diff * ((1 - np.cos(index * np.pi)) / 2) + num_from
241
+ else:
242
+ raise ValueError(f"Unrecognized interpolation method '{method}'.")
243
+ if reverse:
244
+ weights = weights.flip(dims=(0,))
245
+ return weights
246
+
247
+
248
+ class LatentKeyframe:
249
+ def __init__(self, batch_index: int, strength: float) -> None:
250
+ self.batch_index = batch_index
251
+ self.strength = strength
252
+
253
+
254
+ # always maintain sorted state (by batch_index of LatentKeyframe)
255
+ class LatentKeyframeGroup:
256
+ def __init__(self) -> None:
257
+ self.keyframes: list[LatentKeyframe] = []
258
+
259
+ def add(self, keyframe: LatentKeyframe) -> None:
260
+ added = False
261
+ # replace existing keyframe if same batch_index
262
+ for i in range(len(self.keyframes)):
263
+ if self.keyframes[i].batch_index == keyframe.batch_index:
264
+ self.keyframes[i] = keyframe
265
+ added = True
266
+ break
267
+ if not added:
268
+ self.keyframes.append(keyframe)
269
+ self.keyframes.sort(key=lambda k: k.batch_index)
270
+
271
+ def get_index(self, index: int) -> Union[LatentKeyframe, None]:
272
+ try:
273
+ return self.keyframes[index]
274
+ except IndexError:
275
+ return None
276
+
277
+ def __getitem__(self, index) -> LatentKeyframe:
278
+ return self.keyframes[index]
279
+
280
+ def is_empty(self) -> bool:
281
+ return len(self.keyframes) == 0
282
+
283
+ def clone(self) -> 'LatentKeyframeGroup':
284
+ cloned = LatentKeyframeGroup()
285
+ for tk in self.keyframes:
286
+ cloned.add(tk)
287
+ return cloned
288
+
289
+
290
+ class TimestepKeyframe:
291
+ def __init__(self,
292
+ start_percent: float = 0.0,
293
+ strength: float = 1.0,
294
+ control_weights: ControlWeights = None,
295
+ latent_keyframes: LatentKeyframeGroup = None,
296
+ null_latent_kf_strength: float = 0.0,
297
+ inherit_missing: bool = True,
298
+ guarantee_steps: int = 1,
299
+ mask_hint_orig: Tensor = None) -> None:
300
+ self.start_percent = float(start_percent)
301
+ self.start_t = 999999999.9
302
+ self.strength = strength
303
+ self.control_weights = control_weights
304
+ self.latent_keyframes = latent_keyframes
305
+ self.null_latent_kf_strength = null_latent_kf_strength
306
+ self.inherit_missing = inherit_missing
307
+ self.guarantee_steps = guarantee_steps
308
+ self.mask_hint_orig = mask_hint_orig
309
+
310
+ def has_control_weights(self):
311
+ return self.control_weights is not None
312
+
313
+ def has_latent_keyframes(self):
314
+ return self.latent_keyframes is not None
315
+
316
+ def has_mask_hint(self):
317
+ return self.mask_hint_orig is not None
318
+
319
+
320
+ @staticmethod
321
+ def default() -> 'TimestepKeyframe':
322
+ return TimestepKeyframe(start_percent=0.0, guarantee_steps=0)
323
+
324
+
325
+ # always maintain sorted state (by start_percent of TimestepKeyFrame)
326
+ class TimestepKeyframeGroup:
327
+ def __init__(self) -> None:
328
+ self.keyframes: list[TimestepKeyframe] = []
329
+ self.keyframes.append(TimestepKeyframe.default())
330
+
331
+ def add(self, keyframe: TimestepKeyframe) -> None:
332
+ # add to end of list, then sort
333
+ self.keyframes.append(keyframe)
334
+ self.keyframes = get_sorted_list_via_attr(self.keyframes, attr="start_percent")
335
+
336
+ def get_index(self, index: int) -> Union[TimestepKeyframe, None]:
337
+ try:
338
+ return self.keyframes[index]
339
+ except IndexError:
340
+ return None
341
+
342
+ def has_index(self, index: int) -> int:
343
+ return index >=0 and index < len(self.keyframes)
344
+
345
+ def __getitem__(self, index) -> TimestepKeyframe:
346
+ return self.keyframes[index]
347
+
348
+ def __len__(self) -> int:
349
+ return len(self.keyframes)
350
+
351
+ def is_empty(self) -> bool:
352
+ return len(self.keyframes) == 0
353
+
354
+ def clone(self) -> 'TimestepKeyframeGroup':
355
+ cloned = TimestepKeyframeGroup()
356
+ # already sorted, so don't use add function to make cloning quicker
357
+ for tk in self.keyframes:
358
+ cloned.keyframes.append(tk)
359
+ return cloned
360
+
361
+ @classmethod
362
+ def default(cls, keyframe: TimestepKeyframe) -> 'TimestepKeyframeGroup':
363
+ group = cls()
364
+ group.keyframes[0] = keyframe
365
+ return group
366
+
367
+
368
+ class AbstractPreprocWrapper:
369
+ error_msg = "Invalid use of [InsertHere] output. The output of [InsertHere] preprocessor is NOT a usual image, but a latent pretending to be an image - you must connect the output directly to an Apply ControlNet node (advanced or otherwise). It cannot be used for anything else that accepts IMAGE input."
370
+ def __init__(self, condhint: Tensor):
371
+ self.condhint = condhint
372
+
373
+ def movedim(self, *args, **kwargs):
374
+ return self
375
+
376
+ def __getattr__(self, *args, **kwargs):
377
+ raise AttributeError(self.error_msg)
378
+
379
+ def __setattr__(self, name, value):
380
+ if name != "condhint":
381
+ raise AttributeError(self.error_msg)
382
+ super().__setattr__(name, value)
383
+
384
+ def __iter__(self, *args, **kwargs):
385
+ raise AttributeError(self.error_msg)
386
+
387
+ def __next__(self, *args, **kwargs):
388
+ raise AttributeError(self.error_msg)
389
+
390
+ def __len__(self, *args, **kwargs):
391
+ raise AttributeError(self.error_msg)
392
+
393
+ def __getitem__(self, *args, **kwargs):
394
+ raise AttributeError(self.error_msg)
395
+
396
+ def __setitem__(self, *args, **kwargs):
397
+ raise AttributeError(self.error_msg)
398
+
399
+
400
+ # depending on model, AnimateDiff may inject into GroupNorm, so make sure GroupNorm will be clean
401
+ class disable_weight_init_clean_groupnorm(comfy.ops.disable_weight_init):
402
+ class GroupNorm(comfy.ops.disable_weight_init.GroupNorm):
403
+ def forward_comfy_cast_weights(self, input):
404
+ weight, bias = comfy.ops.cast_bias_weight(self, input)
405
+ return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
406
+
407
+ def forward(self, input):
408
+ if self.comfy_cast_weights:
409
+ return self.forward_comfy_cast_weights(input)
410
+ else:
411
+ return torch.nn.functional.group_norm(input, self.num_groups, self.weight, self.bias, self.eps)
412
+
413
+ class manual_cast_clean_groupnorm(comfy.ops.manual_cast):
414
+ class GroupNorm(disable_weight_init_clean_groupnorm.GroupNorm):
415
+ comfy_cast_weights = True
416
+
417
+
418
+ # adapted from comfy/sample.py
419
+ def prepare_mask_batch(mask: Tensor, shape: Tensor, multiplier: int=1, match_dim1=False):
420
+ mask = mask.clone()
421
+ mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[2]*multiplier, shape[3]*multiplier), mode="bilinear")
422
+ if match_dim1:
423
+ mask = torch.cat([mask] * shape[1], dim=1)
424
+ return mask
425
+
426
+
427
+ # applies min-max normalization, from:
428
+ # https://stackoverflow.com/questions/68791508/min-max-normalization-of-a-tensor-in-pytorch
429
+ def normalize_min_max(x: Tensor, new_min = 0.0, new_max = 1.0):
430
+ x_min, x_max = x.min(), x.max()
431
+ return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
432
+
433
+ def linear_conversion(x, x_min=0.0, x_max=1.0, new_min=0.0, new_max=1.0):
434
+ return (((x - x_min)/(x_max - x_min)) * (new_max - new_min)) + new_min
435
+
436
+
437
+ def broadcast_image_to_full(tensor, target_batch_size, batched_number, except_one=True):
438
+ current_batch_size = tensor.shape[0]
439
+ #print(current_batch_size, target_batch_size)
440
+ if except_one and current_batch_size == 1:
441
+ return tensor
442
+
443
+ per_batch = target_batch_size // batched_number
444
+ tensor = tensor[:per_batch]
445
+
446
+ if per_batch > tensor.shape[0]:
447
+ tensor = torch.cat([tensor] * (per_batch // tensor.shape[0]) + [tensor[:(per_batch % tensor.shape[0])]], dim=0)
448
+
449
+ current_batch_size = tensor.shape[0]
450
+ if current_batch_size == target_batch_size:
451
+ return tensor
452
+ else:
453
+ return torch.cat([tensor] * batched_number, dim=0)
454
+
455
+
456
+ # from https://stackoverflow.com/a/24621200
457
+ def deepcopy_with_sharing(obj, shared_attribute_names, memo=None):
458
+ '''
459
+ Deepcopy an object, except for a given list of attributes, which should
460
+ be shared between the original object and its copy.
461
+
462
+ obj is some object
463
+ shared_attribute_names: A list of strings identifying the attributes that
464
+ should be shared between the original and its copy.
465
+ memo is the dictionary passed into __deepcopy__. Ignore this argument if
466
+ not calling from within __deepcopy__.
467
+ '''
468
+ assert isinstance(shared_attribute_names, (list, tuple))
469
+
470
+ shared_attributes = {k: getattr(obj, k) for k in shared_attribute_names}
471
+
472
+ if hasattr(obj, '__deepcopy__'):
473
+ # Do hack to prevent infinite recursion in call to deepcopy
474
+ deepcopy_method = obj.__deepcopy__
475
+ obj.__deepcopy__ = None
476
+
477
+ for attr in shared_attribute_names:
478
+ del obj.__dict__[attr]
479
+
480
+ clone = deepcopy(obj)
481
+
482
+ for attr, val in shared_attributes.items():
483
+ setattr(obj, attr, val)
484
+ setattr(clone, attr, val)
485
+
486
+ if hasattr(obj, '__deepcopy__'):
487
+ # Undo hack
488
+ obj.__deepcopy__ = deepcopy_method
489
+ del clone.__deepcopy__
490
+
491
+ return clone
492
+
493
+
494
+ def get_sorted_list_via_attr(objects: list, attr: str) -> list:
495
+ if not objects:
496
+ return objects
497
+ elif len(objects) <= 1:
498
+ return [x for x in objects]
499
+ # now that we know we have to sort, do it following these rules:
500
+ # a) if objects have same value of attribute, maintain their relative order
501
+ # b) perform sorting of the groups of objects with same attributes
502
+ unique_attrs = {}
503
+ for o in objects:
504
+ val_attr = getattr(o, attr)
505
+ attr_list: list = unique_attrs.get(val_attr, list())
506
+ attr_list.append(o)
507
+ if val_attr not in unique_attrs:
508
+ unique_attrs[val_attr] = attr_list
509
+ # now that we have the unique attr values grouped together in relative order, sort them by key
510
+ sorted_attrs = dict(sorted(unique_attrs.items()))
511
+ # now flatten out the dict into a list to return
512
+ sorted_list = []
513
+ for object_list in sorted_attrs.values():
514
+ sorted_list.extend(object_list)
515
+ return sorted_list
516
+
517
+
518
+ class WeightTypeException(TypeError):
519
+ "Raised when weight not compatible with AdvancedControlBase object"
520
+ pass
521
+
522
+
523
+ class AdvancedControlBase:
524
+ def __init__(self, base: ControlBase, timestep_keyframes: TimestepKeyframeGroup, weights_default: ControlWeights, require_model=False):
525
+ self.base = base
526
+ self.compatible_weights = [ControlWeightType.UNIVERSAL]
527
+ self.add_compatible_weight(weights_default.weight_type)
528
+ # mask for which parts of controlnet output to keep
529
+ self.mask_cond_hint_original = None
530
+ self.mask_cond_hint = None
531
+ self.tk_mask_cond_hint_original = None
532
+ self.tk_mask_cond_hint = None
533
+ self.weight_mask_cond_hint = None
534
+ # actual index values
535
+ self.sub_idxs = None
536
+ self.full_latent_length = 0
537
+ self.context_length = 0
538
+ # timesteps
539
+ self.t: Tensor = None
540
+ self.batched_number: Union[int, IntWithCondOrUncond] = None
541
+ self.batch_size: int = 0
542
+ # weights + override
543
+ self.weights: ControlWeights = None
544
+ self.weights_default: ControlWeights = weights_default
545
+ self.weights_override: ControlWeights = None
546
+ # latent keyframe + override
547
+ self.latent_keyframes: LatentKeyframeGroup = None
548
+ self.latent_keyframe_override: LatentKeyframeGroup = None
549
+ # initialize timestep_keyframes
550
+ self.set_timestep_keyframes(timestep_keyframes)
551
+ # override some functions
552
+ self.get_control = self.get_control_inject
553
+ self.control_merge = self.control_merge_inject
554
+ self.pre_run = self.pre_run_inject
555
+ self.cleanup = self.cleanup_inject
556
+ self.set_previous_controlnet = self.set_previous_controlnet_inject
557
+ # require model to be passed into Apply Advanced ControlNet 🛂🅐🅒🅝 node
558
+ self.require_model = require_model
559
+ # disarm - when set to False, used to force usage of Apply Advanced ControlNet 🛂🅐🅒🅝 node (which will set it to True)
560
+ self.disarmed = not require_model
561
+
562
+ def patch_model(self, model: ModelPatcher):
563
+ pass
564
+
565
+ def add_compatible_weight(self, control_weight_type: str):
566
+ self.compatible_weights.append(control_weight_type)
567
+
568
+ def verify_all_weights(self, throw_error=True):
569
+ # first, check if override exists - if so, only need to check the override
570
+ if self.weights_override is not None:
571
+ if self.weights_override.weight_type not in self.compatible_weights:
572
+ msg = f"Weight override is type {self.weights_override.weight_type}, but loaded {type(self).__name__}" + \
573
+ f"only supports {self.compatible_weights} weights."
574
+ raise WeightTypeException(msg)
575
+ # otherwise, check all timestep keyframe weights
576
+ else:
577
+ for tk in self.timestep_keyframes.keyframes:
578
+ if tk.has_control_weights() and tk.control_weights.weight_type not in self.compatible_weights:
579
+ msg = f"Weight on Timestep Keyframe with start_percent={tk.start_percent} is type" + \
580
+ f"{tk.control_weights.weight_type}, but loaded {type(self).__name__} only supports {self.compatible_weights} weights."
581
+ raise WeightTypeException(msg)
582
+
583
+ def set_timestep_keyframes(self, timestep_keyframes: TimestepKeyframeGroup):
584
+ self.timestep_keyframes = timestep_keyframes if timestep_keyframes else TimestepKeyframeGroup()
585
+ # prepare first timestep_keyframe related stuff
586
+ self._current_timestep_keyframe = None
587
+ self._current_timestep_index = -1
588
+ self._current_used_steps = 0
589
+ self.weights = None
590
+ self.latent_keyframes = None
591
+
592
+ def prepare_current_timestep(self, t: Tensor, batched_number: int):
593
+ self.t = float(t[0])
594
+ self.batched_number = batched_number
595
+ self.batch_size = len(t)
596
+ # get current step percent
597
+ curr_t: float = self.t
598
+ prev_index = self._current_timestep_index
599
+ # if met guaranteed steps (or no current keyframe), look for next keyframe in case need to switch
600
+ if self._current_timestep_keyframe is None or self._current_used_steps >= self._current_timestep_keyframe.guarantee_steps:
601
+ # if has next index, loop through and see if need to switch
602
+ if self.timestep_keyframes.has_index(self._current_timestep_index+1):
603
+ for i in range(self._current_timestep_index+1, len(self.timestep_keyframes)):
604
+ eval_tk = self.timestep_keyframes[i]
605
+ # check if start percent is less or equal to curr_t
606
+ if eval_tk.start_t >= curr_t:
607
+ self._current_timestep_index = i
608
+ self._current_timestep_keyframe = eval_tk
609
+ self._current_used_steps = 0
610
+ # keep track of control weights, latent keyframes, and masks,
611
+ # accounting for inherit_missing
612
+ if self._current_timestep_keyframe.has_control_weights():
613
+ self.weights = self._current_timestep_keyframe.control_weights
614
+ elif not self._current_timestep_keyframe.inherit_missing:
615
+ self.weights = self.weights_default
616
+ if self._current_timestep_keyframe.has_latent_keyframes():
617
+ self.latent_keyframes = self._current_timestep_keyframe.latent_keyframes
618
+ elif not self._current_timestep_keyframe.inherit_missing:
619
+ self.latent_keyframes = None
620
+ if self._current_timestep_keyframe.has_mask_hint():
621
+ self.tk_mask_cond_hint_original = self._current_timestep_keyframe.mask_hint_orig
622
+ elif not self._current_timestep_keyframe.inherit_missing:
623
+ del self.tk_mask_cond_hint_original
624
+ self.tk_mask_cond_hint_original = None
625
+ # if guarantee_steps greater than zero, stop searching for other keyframes
626
+ if self._current_timestep_keyframe.guarantee_steps > 0:
627
+ break
628
+ # if eval_tk is outside of percent range, stop looking further
629
+ else:
630
+ break
631
+
632
+ # update steps current keyframe is used
633
+ self._current_used_steps += 1
634
+ # if index changed, apply overrides
635
+ if prev_index != self._current_timestep_index:
636
+ if self.weights_override is not None:
637
+ self.weights = self.weights_override
638
+ if self.latent_keyframe_override is not None:
639
+ self.latent_keyframes = self.latent_keyframe_override
640
+
641
+ # make sure weights and latent_keyframes are in a workable state
642
+ # Note: each AdvancedControlBase should create their own get_universal_weights class
643
+ self.prepare_weights()
644
+
645
+ def prepare_weights(self):
646
+ if self.weights is None or self.weights.weight_type == ControlWeightType.DEFAULT:
647
+ self.weights = self.weights_default
648
+ elif self.weights.weight_type == ControlWeightType.UNIVERSAL:
649
+ # if universal and weight_mask present, no need to convert
650
+ if self.weights.weight_mask is not None:
651
+ return
652
+ self.weights = self.get_universal_weights()
653
+
654
+ def get_universal_weights(self) -> ControlWeights:
655
+ return self.weights
656
+
657
+ def set_cond_hint_mask(self, mask_hint):
658
+ self.mask_cond_hint_original = mask_hint
659
+ return self
660
+
661
+ def pre_run_inject(self, model, percent_to_timestep_function):
662
+ self.base.pre_run(model, percent_to_timestep_function)
663
+ self.pre_run_advanced(model, percent_to_timestep_function)
664
+
665
+ def pre_run_advanced(self, model, percent_to_timestep_function):
666
+ # for each timestep keyframe, calculate the start_t
667
+ for tk in self.timestep_keyframes.keyframes:
668
+ tk.start_t = percent_to_timestep_function(tk.start_percent)
669
+ # clear variables
670
+ self.cleanup_advanced()
671
+
672
+ def set_previous_controlnet_inject(self, *args, **kwargs):
673
+ to_return = self.base.set_previous_controlnet(*args, **kwargs)
674
+ if not self.disarmed:
675
+ raise Exception(f"Type '{type(self).__name__}' must be used with Apply Advanced ControlNet 🛂🅐🅒🅝 node (with model_optional passed in); otherwise, it will not work.")
676
+ return to_return
677
+
678
+ def disarm(self):
679
+ self.disarmed = True
680
+
681
+ def should_run(self):
682
+ if math.isclose(self.strength, 0.0) or math.isclose(self._current_timestep_keyframe.strength, 0.0):
683
+ return False
684
+ if self.timestep_range is not None:
685
+ if self.t > self.timestep_range[0] or self.t < self.timestep_range[1]:
686
+ return False
687
+ return True
688
+
689
+ def get_control_inject(self, x_noisy, t, cond, batched_number):
690
+ # prepare timestep and everything related
691
+ self.prepare_current_timestep(t=t, batched_number=batched_number)
692
+ # if should not perform any actions for the controlnet, exit without doing any work
693
+ if self.strength == 0.0 or self._current_timestep_keyframe.strength == 0.0:
694
+ return self.default_control_actions(x_noisy, t, cond, batched_number)
695
+ # otherwise, perform normal function
696
+ return self.get_control_advanced(x_noisy, t, cond, batched_number)
697
+
698
+ def get_control_advanced(self, x_noisy, t, cond, batched_number):
699
+ return self.default_control_actions(x_noisy, t, cond, batched_number)
700
+
701
+ def default_control_actions(self, x_noisy, t, cond, batched_number):
702
+ control_prev = None
703
+ if self.previous_controlnet is not None:
704
+ control_prev = self.previous_controlnet.get_control(x_noisy, t, cond, batched_number)
705
+ return control_prev
706
+
707
+ def calc_weight(self, idx: int, x: Tensor, layers: int) -> Union[float, Tensor]:
708
+ if self.weights.weight_mask is not None:
709
+ # prepare weight mask
710
+ self.prepare_weight_mask_cond_hint(x, self.batched_number)
711
+ # adjust mask for current layer and return
712
+ return torch.pow(self.weight_mask_cond_hint, self.get_calc_pow(idx=idx, layers=layers))
713
+ return self.weights.get(idx=idx)
714
+
715
+ def get_calc_pow(self, idx: int, layers: int) -> int:
716
+ return (layers-1)-idx
717
+
718
+ def calc_latent_keyframe_mults(self, x: Tensor, batched_number: int) -> Tensor:
719
+ # apply strengths, and get batch indeces to null out
720
+ # AKA latents that should not be influenced by ControlNet
721
+ final_mults = [1.0] * x.shape[0]
722
+ if self.latent_keyframes:
723
+ latent_count = x.shape[0] // batched_number
724
+ indeces_to_null = set(range(latent_count))
725
+ mapped_indeces = None
726
+ # if expecting subdivision, will need to translate between subset and actual idx values
727
+ if self.sub_idxs:
728
+ mapped_indeces = {}
729
+ for i, actual in enumerate(self.sub_idxs):
730
+ mapped_indeces[actual] = i
731
+ for keyframe in self.latent_keyframes:
732
+ real_index = keyframe.batch_index
733
+ # if negative, count from end
734
+ if real_index < 0:
735
+ real_index += latent_count if self.sub_idxs is None else self.full_latent_length
736
+
737
+ # if not mapping indeces, what you see is what you get
738
+ if mapped_indeces is None:
739
+ if real_index in indeces_to_null:
740
+ indeces_to_null.remove(real_index)
741
+ # otherwise, see if batch_index is even included in this set of latents
742
+ else:
743
+ real_index = mapped_indeces.get(real_index, None)
744
+ if real_index is None:
745
+ continue
746
+ indeces_to_null.remove(real_index)
747
+
748
+ # if real_index is outside the bounds of latents, don't apply
749
+ if real_index >= latent_count or real_index < 0:
750
+ continue
751
+
752
+ # apply strength for each batched cond/uncond
753
+ for b in range(batched_number):
754
+ final_mults[(latent_count*b)+real_index] = keyframe.strength
755
+ # null them out by multiplying by null_latent_kf_strength
756
+ for batch_index in indeces_to_null:
757
+ # apply null for each batched cond/uncond
758
+ for b in range(batched_number):
759
+ final_mults[(latent_count*b)+batch_index] = self._current_timestep_keyframe.null_latent_kf_strength
760
+ # convert final_mults into tensor and match expected dimension count
761
+ final_tensor = torch.tensor(final_mults, dtype=x.dtype, device=x.device)
762
+ while len(final_tensor.shape) < len(x.shape):
763
+ final_tensor = final_tensor.unsqueeze(-1)
764
+ return final_tensor
765
+
766
+ def apply_advanced_strengths_and_masks(self, x: Tensor, batched_number: int):
767
+ # handle weight's uncond_multiplier, if applicable
768
+ if self.weights.has_uncond_multiplier:
769
+ cond_or_uncond = self.batched_number.cond_or_uncond
770
+ actual_length = x.size(0) // batched_number
771
+ for idx, cond_type in enumerate(cond_or_uncond):
772
+ # if uncond, set to weight's uncond_multiplier
773
+ if cond_type == 1:
774
+ x[actual_length*idx:actual_length*(idx+1)] *= self.weights.uncond_multiplier
775
+
776
+ if self.latent_keyframes is not None:
777
+ x[:] = x[:] * self.calc_latent_keyframe_mults(x=x, batched_number=batched_number)
778
+ # apply masks, resizing mask to required dims
779
+ if self.mask_cond_hint is not None:
780
+ masks = prepare_mask_batch(self.mask_cond_hint, x.shape)
781
+ x[:] = x[:] * masks
782
+ if self.tk_mask_cond_hint is not None:
783
+ masks = prepare_mask_batch(self.tk_mask_cond_hint, x.shape)
784
+ x[:] = x[:] * masks
785
+ # apply timestep keyframe strengths
786
+ if self._current_timestep_keyframe.strength != 1.0:
787
+ x[:] *= self._current_timestep_keyframe.strength
788
+
789
+ def control_merge_inject(self: 'AdvancedControlBase', control_input, control_output, control_prev, output_dtype):
790
+ out = {'input':[], 'middle':[], 'output': []}
791
+
792
+ if control_input is not None:
793
+ for i in range(len(control_input)):
794
+ key = 'input'
795
+ x = control_input[i]
796
+ if x is not None:
797
+ self.apply_advanced_strengths_and_masks(x, self.batched_number)
798
+
799
+ x *= self.strength * self.calc_weight(i, x, len(control_input))
800
+ if x.dtype != output_dtype:
801
+ x = x.to(output_dtype)
802
+ out[key].insert(0, x)
803
+
804
+ if control_output is not None:
805
+ for i in range(len(control_output)):
806
+ if i == (len(control_output) - 1):
807
+ key = 'middle'
808
+ index = 0
809
+ else:
810
+ key = 'output'
811
+ index = i
812
+ x = control_output[i]
813
+ if x is not None:
814
+ self.apply_advanced_strengths_and_masks(x, self.batched_number)
815
+
816
+ if self.global_average_pooling:
817
+ x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
818
+
819
+ x *= self.strength * self.calc_weight(i, x, len(control_output))
820
+ if x.dtype != output_dtype:
821
+ x = x.to(output_dtype)
822
+
823
+ out[key].append(x)
824
+ if control_prev is not None:
825
+ for x in ['input', 'middle', 'output']:
826
+ o = out[x]
827
+ for i in range(len(control_prev[x])):
828
+ prev_val = control_prev[x][i]
829
+ if i >= len(o):
830
+ o.append(prev_val)
831
+ elif prev_val is not None:
832
+ if o[i] is None:
833
+ o[i] = prev_val
834
+ else:
835
+ if o[i].shape[0] < prev_val.shape[0]:
836
+ o[i] = prev_val + o[i]
837
+ else:
838
+ o[i] += prev_val
839
+ return out
840
+
841
+ def prepare_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
842
+ self._prepare_mask("mask_cond_hint", self.mask_cond_hint_original, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
843
+ self.prepare_tk_mask_cond_hint(x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
844
+
845
+ def prepare_tk_mask_cond_hint(self, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
846
+ return self._prepare_mask("tk_mask_cond_hint", self._current_timestep_keyframe.mask_hint_orig, x_noisy, t, cond, batched_number, dtype, direct_attn=direct_attn)
847
+
848
+ def prepare_weight_mask_cond_hint(self, x_noisy: Tensor, batched_number, dtype=None):
849
+ return self._prepare_mask("weight_mask_cond_hint", self.weights.weight_mask, x_noisy, t=None, cond=None, batched_number=batched_number, dtype=dtype, direct_attn=True)
850
+
851
+ def _prepare_mask(self, attr_name, orig_mask: Tensor, x_noisy: Tensor, t, cond, batched_number, dtype=None, direct_attn=False):
852
+ # make mask appropriate dimensions, if present
853
+ if orig_mask is not None:
854
+ out_mask = getattr(self, attr_name)
855
+ multiplier = 1 if direct_attn else 8
856
+ if self.sub_idxs is not None or out_mask is None or x_noisy.shape[2] * multiplier != out_mask.shape[1] or x_noisy.shape[3] * multiplier != out_mask.shape[2]:
857
+ self._reset_attr(attr_name)
858
+ del out_mask
859
+ # TODO: perform upscale on only the sub_idxs masks at a time instead of all to conserve RAM
860
+ # resize mask and match batch count
861
+ out_mask = prepare_mask_batch(orig_mask, x_noisy.shape, multiplier=multiplier)
862
+ actual_latent_length = x_noisy.shape[0] // batched_number
863
+ out_mask = comfy.utils.repeat_to_batch_size(out_mask, actual_latent_length if self.sub_idxs is None else self.full_latent_length)
864
+ if self.sub_idxs is not None:
865
+ out_mask = out_mask[self.sub_idxs]
866
+ # make cond_hint_mask length match x_noise
867
+ if x_noisy.shape[0] != out_mask.shape[0]:
868
+ out_mask = broadcast_image_to(out_mask, x_noisy.shape[0], batched_number)
869
+ # default dtype to be same as x_noisy
870
+ if dtype is None:
871
+ dtype = x_noisy.dtype
872
+ setattr(self, attr_name, out_mask.to(dtype=dtype).to(self.device))
873
+ del out_mask
874
+
875
+ def _reset_attr(self, attr_name, new_value=None):
876
+ if hasattr(self, attr_name):
877
+ delattr(self, attr_name)
878
+ setattr(self, attr_name, new_value)
879
+
880
+ def cleanup_inject(self):
881
+ self.base.cleanup()
882
+ self.cleanup_advanced()
883
+
884
+ def cleanup_advanced(self):
885
+ self.sub_idxs = None
886
+ self.full_latent_length = 0
887
+ self.context_length = 0
888
+ self.t = None
889
+ self.batched_number = None
890
+ self.batch_size = 0
891
+ self.weights = None
892
+ self.latent_keyframes = None
893
+ # timestep stuff
894
+ self._current_timestep_keyframe = None
895
+ self._current_timestep_index = -1
896
+ self._current_used_steps = 0
897
+ # clear mask hints
898
+ if self.mask_cond_hint is not None:
899
+ del self.mask_cond_hint
900
+ self.mask_cond_hint = None
901
+ if self.tk_mask_cond_hint_original is not None:
902
+ del self.tk_mask_cond_hint_original
903
+ self.tk_mask_cond_hint_original = None
904
+ if self.tk_mask_cond_hint is not None:
905
+ del self.tk_mask_cond_hint
906
+ self.tk_mask_cond_hint = None
907
+ if self.weight_mask_cond_hint is not None:
908
+ del self.weight_mask_cond_hint
909
+ self.weight_mask_cond_hint = None
910
+
911
+ def copy_to_advanced(self, copied: 'AdvancedControlBase'):
912
+ copied.mask_cond_hint_original = self.mask_cond_hint_original
913
+ copied.weights_override = self.weights_override
914
+ copied.latent_keyframe_override = self.latent_keyframe_override
915
+ copied.disarmed = self.disarmed
ComfyUI-Advanced-ControlNet/pyproject.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "comfyui-advanced-controlnet"
3
+ description = "Nodes for scheduling ControlNet strength across timesteps and batched latents, as well as applying custom weights and attention masks."
4
+ version = "1.0.2"
5
+ license = "LICENSE"
6
+ dependencies = []
7
+
8
+ [project.urls]
9
+ Repository = "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet"
10
+
11
+ # Used by Comfy Registry https://comfyregistry.org
12
+ [tool.comfy]
13
+ PublisherId = "kosinkadink"
14
+ DisplayName = "ComfyUI-Advanced-ControlNet"
15
+ Icon = ""
ComfyUI-Advanced-ControlNet/requirements.txt ADDED
File without changes
ComfyUI-BrushNet/BIG_IMAGE.md ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ ![example workflow](example/BrushNet_cut_for_inpaint.png?raw=true)
2
+
3
+ [workflow](example/BrushNet_cut_for_inpaint.json)
4
+
5
+ When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
6
+ I created a node for such workflow, see example.
ComfyUI-BrushNet/CN.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ControlNet Canny Edge
2
+
3
+ Let's take the pestered cake and try to inpaint it again. Now I would like to use a sleeping cat for it:
4
+
5
+ ![sleeping cat](example/sleeping_cat.png?raw=true)
6
+
7
+ I use Canny Edge node from [comfyui_controlnet_aux](https://github.com/Fannovel16/comfyui_controlnet_aux). Don't forget to resize canny edge mask to 512 pixels:
8
+
9
+ ![sleeping cat inpaint](example/sleeping_cat_inpaint1.png?raw=true)
10
+
11
+ Let's look at the result:
12
+
13
+ ![sleeping cat inpaint](example/sleeping_cat_inpaint2.png?raw=true)
14
+
15
+ The first problem I see here is some kind of object behind the cat. Such objects appear since the inpainting mask strictly aligns with the removed object, the cake in our case. To remove such artifact we should expand our mask a little:
16
+
17
+ ![sleeping cat inpaint](example/sleeping_cat_inpaint3.png?raw=true)
18
+
19
+ Now. what's up with cat back and tail? Let's see the inpainting mask and canny edge mask side to side:
20
+
21
+ ![masks](example/sleeping_cat_inpaint4.png?raw=true)
22
+
23
+ The inpainting works (mostly) only in masked (white) area, so we cut off cat's back. **The ControlNet mask should be inside the inpaint mask.**
24
+
25
+ To address the issue I resized the mask to 256 pixels:
26
+
27
+ ![sleeping cat inpaint](example/sleeping_cat_inpaint5.png?raw=true)
28
+
29
+ This is better but still have a room for improvement. The problem with edge mask downsampling is that edge lines tend to be broken and after some size we will got a mess:
30
+
31
+ ![sleeping cat inpaint](example/sleeping_cat_inpaint6.png?raw=true)
32
+
33
+ Look at the edge mask, at this resolution it is so broken:
34
+
35
+ ![masks](example/sleeping_cat_mask.png?raw=true)
36
+
37
+
38
+
39
+
ComfyUI-BrushNet/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
ComfyUI-BrushNet/PARAMS.md ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Start At and End At parameters usage
2
+
3
+ ### start_at
4
+
5
+ Let's start with a ELLA outpaint [workflow](example/BrushNet_with_ELLA.json) and switch off Blend Inpaint node:
6
+
7
+ ![example workflow](example/params1.png?raw=true)
8
+
9
+ For this example I use "wargaming shop showcase" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps. This is the result:
10
+
11
+ ![goblin in the shop](example/params2.png?raw=true)
12
+
13
+ The `start_at` BrushNet node parameter allows us to delay BrushNet inference for some steps, so the base model will do all the job. Let's see what the result will be without BrushNet. For this I set up `start_at` parameter to 20 - it should be more then `steps` in KSampler node:
14
+
15
+ ![the shop](example/params3.png?raw=true)
16
+
17
+ So, if we apply BrushNet from the beginning (`start_at` equals 0), the resulting scene will be heavily influenced by BrushNet image. The more we increase this parameter, the more scene will be based on prompt. Let's compare:
18
+
19
+ | `start_at` = 1 | `start_at` = 2 | `start_at` = 3 |
20
+ |:--------------:|:--------------:|:--------------:|
21
+ | ![p1](example/params4.png?raw=true) | ![p2](example/params5.png?raw=true) | ![p3](example/params6.png?raw=true) |
22
+ | `start_at` = 4 | `start_at` = 5 | `start_at` = 6 |
23
+ | ![p1](example/params7.png?raw=true) | ![p2](example/params8.png?raw=true) | ![p3](example/params9.png?raw=true) |
24
+ | `start_at` = 7 | `start_at` = 8 | `start_at` = 9 |
25
+ | ![p1](example/params10.png?raw=true) | ![p2](example/params11.png?raw=true) | ![p3](example/params12.png?raw=true) |
26
+
27
+ Look how the floor is aligned with toy's base - at some step it looses consistency. The results will depend on type of sampler and number of KSampler steps, of course.
28
+
29
+ ### end_at
30
+
31
+ The `end_at` parameter switches off BrushNet at the last steps. If you use deterministic sampler it will only influences details on last steps, but stochastic samplers can change the whole scene. For a description of samplers see, for example, Matteo Spinelli's [video on ComfyUI basics](https://youtu.be/_C7kR2TFIX0?t=516).
32
+
33
+ Here I use basic BrushNet inpaint [example](example/BrushNet_basic.json), with "intricate teapot" prompt, `dpmpp_2m` deterministic sampler and `karras` scheduler with 15 steps:
34
+
35
+ ![example workflow](example/params13.png?raw=true)
36
+
37
+ There are almost no changes when we set 'end_at' paramter to 10, but starting from it:
38
+
39
+ | `end_at` = 10 | `end_at` = 9 | `end_at` = 8 |
40
+ |:--------------:|:--------------:|:--------------:|
41
+ | ![p1](example/params14.png?raw=true) | ![p2](example/params15.png?raw=true) | ![p3](example/params16.png?raw=true) |
42
+ | `end_at` = 7 | `end_at` = 6 | `end_at` = 5 |
43
+ | ![p1](example/params17.png?raw=true) | ![p2](example/params18.png?raw=true) | ![p3](example/params19.png?raw=true) |
44
+ | `end_at` = 4 | `end_at` = 3 | `end_at` = 2 |
45
+ | ![p1](example/params20.png?raw=true) | ![p2](example/params21.png?raw=true) | ![p3](example/params22.png?raw=true) |
46
+
47
+ You can see how the scene was completely redrawn.
ComfyUI-BrushNet/RAUNET.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ During investigation of compatibility issues with [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main) and [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion) nodes I stumbled upon some quite hard problems. There are `FreeU` nodes in ComfyUI, but no such for HiDiffusion, so I decided to implement RAUNet on base of my BrushNet implementation. **blepping**, I am sorry. :)
2
+
3
+ ### RAUNet
4
+
5
+ What is RAUNet? I know many of you saw and generate images with a lot of limbs, fingers and faces all morphed together.
6
+
7
+ The authors of HiDiffusion invent simple, yet efficient trick to alleviate this problem. Here is an example:
8
+
9
+ ![example workflow](example/RAUNet1.png?raw=true)
10
+
11
+ [workflow](example/RAUNet_basic.json)
12
+
13
+ The left picture is created using ZavyChromaXL checkpoint on 2048x2048 canvas. The right one uses RAUNet.
14
+
15
+ In my experience the node is helpful but quite sensitive to its parameters. And there is no universal solution - you should adjust them for every new image you generate. It also lowers model's imagination, you usually get only what you described in the prompt. Look at the example: in first you have a forest in the background, but RAUNet deleted all except fox which is described in the prompt.
16
+
17
+ From the [paper](https://arxiv.org/abs/2311.17528): Diffusion models denoise from structures to details. RAU-Net introduces additional downsampling and upsampling operations, leading to a certain degree of information loss. In the early stages of denoising, RAU-Net can generate reasonable structures with minimal impact from information loss. However, in the later stages of denoising when generating fine details, the information loss in RAU-Net results in the loss of image details and a degradation in quality.
18
+
19
+ ### Parameters
20
+
21
+ There are two independent parts in this node: DU (Downsample/Upsample) and XA (CrossAttention). The four parameters are the start and end steps for applying these parts.
22
+
23
+ The Downsample/Upsample part lowers models degrees of freedom. If you apply it a lot (for more steps) the resulting images will have a lot of symmetries.
24
+
25
+ The CrossAttension part lowers number of objects which model tracks in image.
26
+
27
+ Usually you apply DU and after several steps apply XA, sometimes you will need only XA, you should try it yourself.
28
+
29
+ ### Compatibility
30
+
31
+ It is compatible with BrushNet and most other nodes.
32
+
33
+ This is ControlNet example. The lower image is pure model, the upper is after using RAUNet. You can see small fox and two tails in lower image.
34
+
35
+ ![example workflow](example/RAUNet2.png?raw=true)
36
+
37
+ [workflow](example/RAUNet_with_CN.json)
38
+
39
+ The node can be implemented for any model. Right now it can be applied to SD15 and SDXL models.
ComfyUI-BrushNet/README.md ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## ComfyUI-BrushNet
2
+
3
+ These are custom nodes for ComfyUI native implementation of
4
+
5
+ - Brushnet: ["BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"](https://arxiv.org/abs/2403.06976)
6
+ - PowerPaint: [A Task is Worth One Word: Learning with Task Prompts for High-Quality Versatile Image Inpainting](https://arxiv.org/abs/2312.03594)
7
+ - HiDiffusion: [HiDiffusion: Unlocking Higher-Resolution Creativity and Efficiency in Pretrained Diffusion Models](https://arxiv.org/abs/2311.17528)
8
+
9
+ My contribution is limited to the ComfyUI adaptation, and all credit goes to the authors of the papers.
10
+
11
+ ## Updates
12
+
13
+ May 16, 2024. Internal rework to improve compatibility with other nodes. [RAUNet](RAUNET.md) is implemented.
14
+
15
+ May 12, 2024. CutForInpaint node, see [example](BIG_IMAGE.md).
16
+
17
+ May 11, 2024. Image batch is implemented. You can even add BrushNet to AnimateDiff vid2vid workflow, but they don't work together - they are different models and both try to patch UNet. Added some more examples.
18
+
19
+ May 6, 2024. PowerPaint v2 model is implemented. After update your workflow probably will not work. Don't panic! Check `end_at` parameter of BrushNode, if it equals 1, change it to some big number. Read about parameters in Usage section below.
20
+
21
+ May 2, 2024. BrushNet SDXL is live. It needs positive and negative conditioning though, so workflow changes a little, see example.
22
+
23
+ Apr 28, 2024. Another rework, sorry for inconvenience. But now BrushNet is native to ComfyUI. Famous cubiq's [IPAdapter Plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) is now working with BrushNet! I hope... :) Please, report any bugs you found.
24
+
25
+ Apr 18, 2024. Complete rework, no more custom `diffusers` library. It is possible to use LoRA models.
26
+
27
+ Apr 11, 2024. Initial commit.
28
+
29
+ ## Plans
30
+
31
+ - [x] BrushNet SDXL
32
+ - [x] PowerPaint v2
33
+ - [x] Image batch
34
+
35
+ ## Installation
36
+
37
+ Clone the repo into the `custom_nodes` directory and install the requirements:
38
+
39
+ ```
40
+ git clone https://github.com/nullquant/ComfyUI-BrushNet.git
41
+ pip install -r requirements.txt
42
+ ```
43
+
44
+ Checkpoints of BrushNet can be downloaded from [here](https://drive.google.com/drive/folders/1fqmS1CEOvXCxNWFrsSYd_jHYXxrydh1n?usp=drive_link).
45
+
46
+ The checkpoint in `segmentation_mask_brushnet_ckpt` provides checkpoints trained on BrushData, which has segmentation prior (mask are with the same shape of objects). The `random_mask_brushnet_ckpt` provides a more general ckpt for random mask shape.
47
+
48
+ `segmentation_mask_brushnet_ckpt` and `random_mask_brushnet_ckpt` contains BrushNet for SD 1.5 models while
49
+ `segmentation_mask_brushnet_ckpt_sdxl_v0` and `random_mask_brushnet_ckpt_sdxl_v0` for SDXL.
50
+
51
+ You should place `diffusion_pytorch_model.safetensors` files to your `models/inpaint` folder. You can also specify `inpaint` folder in your `extra_model_paths.yaml`.
52
+
53
+ For PowerPaint you should download three files. Both `diffusion_pytorch_model.safetensors` and `pytorch_model.bin` from [here](https://huggingface.co/JunhaoZhuang/PowerPaint-v2-1/tree/main/PowerPaint_Brushnet) should be placed in your `models/inpaint` folder.
54
+
55
+ Also you need SD1.5 text encoder model `model.fp16.safetensors` from [here](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/text_encoder). It should be placed in your `models/clip` folder.
56
+
57
+ This is a structure of my `models/inpaint` folder:
58
+
59
+ ![inpaint folder](example/inpaint_folder.png?raw=true)
60
+
61
+ Yours can be different.
62
+
63
+ ## Usage
64
+
65
+ Below is an example for the intended workflow. The [workflow](example/BrushNet_basic.json) for the example can be found inside the 'example' directory.
66
+
67
+ ![example workflow](example/BrushNet_basic.png?raw=true)
68
+
69
+ <details>
70
+ <summary>SDXL</summary>
71
+
72
+ ![example workflow](example/BrushNet_SDXL_basic.png?raw=true)
73
+
74
+ [workflow](example/BrushNet_SDXL_basic.json)
75
+
76
+ </details>
77
+
78
+ <details>
79
+ <summary>IPAdapter plus</summary>
80
+
81
+ ![example workflow](example/BrushNet_with_IPA.png?raw=true)
82
+
83
+ [workflow](example/BrushNet_with_IPA.json)
84
+
85
+ </details>
86
+
87
+ <details>
88
+ <summary>LoRA</summary>
89
+
90
+ ![example workflow](example/BrushNet_with_LoRA.png?raw=true)
91
+
92
+ [workflow](example/BrushNet_with_LoRA.json)
93
+
94
+ </details>
95
+
96
+ <details>
97
+ <summary>Blending inpaint</summary>
98
+
99
+ ![example workflow](example/BrushNet_inpaint.png?raw=true)
100
+
101
+ Sometimes inference and VAE broke image, so you need to blend inpaint image with the original: [workflow](example/BrushNet_inpaint.json). You can see blurred and broken text after inpainting in the first image and how I suppose to repair it.
102
+
103
+ </details>
104
+
105
+ <details>
106
+ <summary>ControlNet</summary>
107
+
108
+ ![example workflow](example/BrushNet_with_CN.png?raw=true)
109
+
110
+ [workflow](example/BrushNet_with_CN.json)
111
+
112
+ [ControlNet canny edge](CN.md)
113
+
114
+ </details>
115
+
116
+ <details>
117
+ <summary>ELLA outpaint</summary>
118
+
119
+ ![example workflow](example/BrushNet_with_ELLA.png?raw=true)
120
+
121
+ [workflow](example/BrushNet_with_ELLA.json)
122
+
123
+ </details>
124
+
125
+ <details>
126
+ <summary>Upscale</summary>
127
+
128
+ ![example workflow](example/BrushNet_SDXL_upscale.png?raw=true)
129
+
130
+ [workflow](example/BrushNet_SDXL_upscale.json)
131
+
132
+ To upscale you should use base model, not BrushNet. The same is true for conditioning. Latent upscaling between BrushNet and KSampler will not work or will give you wierd results. These limitations are due to structure of BrushNet and its influence on UNet calculations.
133
+
134
+ </details>
135
+
136
+ <details>
137
+ <summary>Image batch</summary>
138
+
139
+ ![example workflow](example/BrushNet_image_batch.png?raw=true)
140
+
141
+ [workflow](example/BrushNet_image_batch.json)
142
+
143
+ If you have OOM problems, you can use Evolved Sampling from [AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved):
144
+
145
+ ![example workflow](example/BrushNet_image_big_batch.png?raw=true)
146
+
147
+ [workflow](example/BrushNet_image_big_batch.json)
148
+
149
+ In Context Options set context_length to number of images which can be loaded into VRAM. Images will be processed in chunks of this size.
150
+
151
+ </details>
152
+
153
+
154
+ <details>
155
+ <summary>Big image inpaint</summary>
156
+
157
+ ![example workflow](example/BrushNet_cut_for_inpaint.png?raw=true)
158
+
159
+ [workflow](example/BrushNet_cut_for_inpaint.json)
160
+
161
+ When you work with big image and your inpaint mask is small it is better to cut part of the image, work with it and then blend it back.
162
+ I created a node for such workflow, see example.
163
+
164
+ </details>
165
+
166
+
167
+ <details>
168
+ <summary>PowerPaint outpaint</summary>
169
+
170
+ ![example workflow](example/PowerPaint_outpaint.png?raw=true)
171
+
172
+ [workflow](example/PowerPaint_outpaint.json)
173
+
174
+ </details>
175
+
176
+ <details>
177
+ <summary>PowerPaint object removal</summary>
178
+
179
+ ![example workflow](example/PowerPaint_object_removal.png?raw=true)
180
+
181
+ [workflow](example/PowerPaint_object_removal.json)
182
+
183
+ It is often hard to completely remove the object, especially if it is at the front:
184
+
185
+ ![object removal example](example/object_removal_fail.png?raw=true)
186
+
187
+ You should try to add object description to negative prompt and describe empty scene, like here:
188
+
189
+ ![object removal example](example/object_removal.png?raw=true)
190
+
191
+ </details>
192
+
193
+ ### Parameters
194
+
195
+ #### Brushnet Loader
196
+
197
+ - `dtype`, defaults to `torch.float16`. The torch.dtype of BrushNet. If you have old GPU or NVIDIA 16 series card try to switch to `torch.float32`.
198
+
199
+ #### Brushnet
200
+
201
+ - `scale`, defaults to 1.0: The "strength" of BrushNet. The outputs of the BrushNet are multiplied by `scale` before they are added to the residual in the original unet.
202
+ - `start_at`, defaults to 0: step at which the BrushNet starts applying.
203
+ - `end_at`, defaults to 10000: step at which the BrushNet stops applying.
204
+
205
+ [Here](PARAMS.md) are examples of use these two last parameters.
206
+
207
+ #### PowerPaint
208
+
209
+ - `CLIP`: PowerPaint CLIP that should be passed from PowerPaintCLIPLoader node.
210
+ - `fitting`: PowerPaint fitting degree.
211
+ - `function`: PowerPaint function, see its [page](https://github.com/open-mmlab/PowerPaint) for details.
212
+
213
+ When using certain network functions, the authors of PowerPaint recommend adding phrases to the prompt:
214
+
215
+ - object removal: `empty scene blur`
216
+ - context aware: `empty scene`
217
+ - outpainting: `empty scene`
218
+
219
+ Many of ComfyUI users use custom text generation nodes, CLIP nodes and a lot of other conditioning. I don't want to break all of these nodes, so I didn't add prompt updating and instead rely on users. Also my own experiments show that these additions to prompt are not strictly necessary.
220
+
221
+ The latent image can be from BrushNet node or not, but it should be the same size as original image (divided by 8 in latent space).
222
+
223
+ The both conditioning `positive` and `negative` in BrushNet and PowerPaint nodes are used for calculation inside, but then simply copied to output.
224
+
225
+ Be advised, not all workflows and nodes will work with BrushNet due to its structure. Also put model changes before BrushNet nodes, not after. If you need model to work with image after BrushNet inference use base one (see Upscale example below).
226
+
227
+ #### RAUNet
228
+
229
+ - `du_start`, defaults to 0: step at which the Downsample/Upsample resize starts applying.
230
+ - `du_end`, defaults to 4: step at which the Downsample/Upsample resize stops applying.
231
+ - `xa_start`, defaults to 4: step at which the CrossAttention resize starts applying.
232
+ - `xa_end`, defaults to 10: step at which the CrossAttention resize stops applying.
233
+
234
+ For an examples and explanation, please look [here](RAUNET.md).
235
+
236
+ ## Limitations
237
+
238
+ BrushNet has some limitations (from the [paper](https://arxiv.org/abs/2403.06976)):
239
+
240
+ - The quality and content generated by the model are heavily dependent on the chosen base model.
241
+ The results can exhibit incoherence if, for example, the given image is a natural image while the base model primarily focuses on anime.
242
+ - Even with BrushNet, we still observe poor generation results in cases where the given mask has an unusually shaped
243
+ or irregular form, or when the given text does not align well with the masked image.
244
+
245
+ ## Notes
246
+
247
+ Unfortunately, due to the nature of BrushNet code some nodes are not compatible with these, since we are trying to patch the same ComfyUI's functions.
248
+
249
+ List of known uncompartible nodes.
250
+
251
+ - [WASasquatch's FreeU_Advanced](https://github.com/WASasquatch/FreeU_Advanced/tree/main)
252
+ - [blepping's jank HiDiffusion](https://github.com/blepping/comfyui_jankhidiffusion)
253
+
254
+ ## Credits
255
+
256
+ The code is based on
257
+
258
+ - [BrushNet](https://github.com/TencentARC/BrushNet)
259
+ - [PowerPaint](https://github.com/zhuang2002/PowerPaint)
260
+ - [HiDiffusion](https://github.com/megvii-research/HiDiffusion)
261
+ - [diffusers](https://github.com/huggingface/diffusers)
ComfyUI-BrushNet/__init__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .brushnet_nodes import BrushNetLoader, BrushNet, BlendInpaint, PowerPaintCLIPLoader, PowerPaint, CutForInpaint
2
+ from .raunet_nodes import RAUNet
3
+
4
+ """
5
+ @author: nullquant
6
+ @title: BrushNet
7
+ @nickname: BrushName nodes
8
+ @description: These are custom nodes for ComfyUI native implementation of BrushNet, PowerPaint and RAUNet models
9
+ """
10
+
11
+ # A dictionary that contains all nodes you want to export with their names
12
+ # NOTE: names should be globally unique
13
+ NODE_CLASS_MAPPINGS = {
14
+ "BrushNetLoader": BrushNetLoader,
15
+ "BrushNet": BrushNet,
16
+ "BlendInpaint": BlendInpaint,
17
+ "PowerPaintCLIPLoader": PowerPaintCLIPLoader,
18
+ "PowerPaint": PowerPaint,
19
+ "CutForInpaint": CutForInpaint,
20
+ "RAUNet": RAUNet,
21
+ }
22
+
23
+ # A dictionary that contains the friendly/humanly readable titles for the nodes
24
+ NODE_DISPLAY_NAME_MAPPINGS = {
25
+ "BrushNetLoader": "BrushNet Loader",
26
+ "BrushNet": "BrushNet",
27
+ "BlendInpaint": "Blend Inpaint",
28
+ "PowerPaintCLIPLoader": "PowerPaint CLIP Loader",
29
+ "PowerPaint": "PowerPaint",
30
+ "CutForInpaint": "Cut For Inpaint",
31
+ "RAUNet": "RAUNet",
32
+ }
ComfyUI-BrushNet/brushnet/brushnet.json ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.0.dev0",
4
+ "_name_or_path": "runs/logs/brushnet_randommask/checkpoint-100000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "block_out_channels": [
11
+ 320,
12
+ 640,
13
+ 1280,
14
+ 1280
15
+ ],
16
+ "brushnet_conditioning_channel_order": "rgb",
17
+ "class_embed_type": null,
18
+ "conditioning_channels": 5,
19
+ "conditioning_embedding_out_channels": [
20
+ 16,
21
+ 32,
22
+ 96,
23
+ 256
24
+ ],
25
+ "cross_attention_dim": 768,
26
+ "down_block_types": [
27
+ "DownBlock2D",
28
+ "DownBlock2D",
29
+ "DownBlock2D",
30
+ "DownBlock2D"
31
+ ],
32
+ "downsample_padding": 1,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "global_pool_conditions": false,
38
+ "in_channels": 4,
39
+ "layers_per_block": 2,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "MidBlock2D",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "projection_class_embeddings_input_dim": null,
48
+ "resnet_time_scale_shift": "default",
49
+ "transformer_layers_per_block": 1,
50
+ "up_block_types": [
51
+ "UpBlock2D",
52
+ "UpBlock2D",
53
+ "UpBlock2D",
54
+ "UpBlock2D"
55
+ ],
56
+ "upcast_attention": false,
57
+ "use_linear_projection": false
58
+ }
ComfyUI-BrushNet/brushnet/brushnet.py ADDED
@@ -0,0 +1,948 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.models.attention_processor import (
11
+ ADDED_KV_ATTENTION_PROCESSORS,
12
+ CROSS_ATTENTION_PROCESSORS,
13
+ AttentionProcessor,
14
+ AttnAddedKVProcessor,
15
+ AttnProcessor,
16
+ )
17
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
18
+ from diffusers.models.modeling_utils import ModelMixin
19
+
20
+ from .unet_2d_blocks import (
21
+ CrossAttnDownBlock2D,
22
+ DownBlock2D,
23
+ UNetMidBlock2D,
24
+ UNetMidBlock2DCrossAttn,
25
+ get_down_block,
26
+ get_mid_block,
27
+ get_up_block,
28
+ MidBlock2D
29
+ )
30
+
31
+ from .unet_2d_condition import UNet2DConditionModel
32
+
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ @dataclass
38
+ class BrushNetOutput(BaseOutput):
39
+ """
40
+ The output of [`BrushNetModel`].
41
+
42
+ Args:
43
+ up_block_res_samples (`tuple[torch.Tensor]`):
44
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
45
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
46
+ used to condition the original UNet's upsampling activations.
47
+ down_block_res_samples (`tuple[torch.Tensor]`):
48
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
49
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
50
+ used to condition the original UNet's downsampling activations.
51
+ mid_down_block_re_sample (`torch.Tensor`):
52
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
53
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
54
+ Output can be used to condition the original UNet's middle block activation.
55
+ """
56
+
57
+ up_block_res_samples: Tuple[torch.Tensor]
58
+ down_block_res_samples: Tuple[torch.Tensor]
59
+ mid_block_res_sample: torch.Tensor
60
+
61
+
62
+ class BrushNetModel(ModelMixin, ConfigMixin):
63
+ """
64
+ A BrushNet model.
65
+
66
+ Args:
67
+ in_channels (`int`, defaults to 4):
68
+ The number of channels in the input sample.
69
+ flip_sin_to_cos (`bool`, defaults to `True`):
70
+ Whether to flip the sin to cos in the time embedding.
71
+ freq_shift (`int`, defaults to 0):
72
+ The frequency shift to apply to the time embedding.
73
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
74
+ The tuple of downsample blocks to use.
75
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
76
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
77
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
78
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
79
+ The tuple of upsample blocks to use.
80
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
81
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
82
+ The tuple of output channels for each block.
83
+ layers_per_block (`int`, defaults to 2):
84
+ The number of layers per block.
85
+ downsample_padding (`int`, defaults to 1):
86
+ The padding to use for the downsampling convolution.
87
+ mid_block_scale_factor (`float`, defaults to 1):
88
+ The scale factor to use for the mid block.
89
+ act_fn (`str`, defaults to "silu"):
90
+ The activation function to use.
91
+ norm_num_groups (`int`, *optional*, defaults to 32):
92
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
93
+ in post-processing.
94
+ norm_eps (`float`, defaults to 1e-5):
95
+ The epsilon to use for the normalization.
96
+ cross_attention_dim (`int`, defaults to 1280):
97
+ The dimension of the cross attention features.
98
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
99
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
100
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
101
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
102
+ encoder_hid_dim (`int`, *optional*, defaults to None):
103
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
104
+ dimension to `cross_attention_dim`.
105
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
106
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
107
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
108
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
109
+ The dimension of the attention heads.
110
+ use_linear_projection (`bool`, defaults to `False`):
111
+ class_embed_type (`str`, *optional*, defaults to `None`):
112
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
113
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
114
+ addition_embed_type (`str`, *optional*, defaults to `None`):
115
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
116
+ "text". "text" will use the `TextTimeEmbedding` layer.
117
+ num_class_embeds (`int`, *optional*, defaults to 0):
118
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
119
+ class conditioning with `class_embed_type` equal to `None`.
120
+ upcast_attention (`bool`, defaults to `False`):
121
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
122
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
123
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
124
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
125
+ `class_embed_type="projection"`.
126
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
127
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
128
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
129
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
130
+ global_pool_conditions (`bool`, defaults to `False`):
131
+ TODO(Patrick) - unused parameter.
132
+ addition_embed_type_num_heads (`int`, defaults to 64):
133
+ The number of heads to use for the `TextTimeEmbedding` layer.
134
+ """
135
+
136
+ _supports_gradient_checkpointing = True
137
+
138
+ @register_to_config
139
+ def __init__(
140
+ self,
141
+ in_channels: int = 4,
142
+ conditioning_channels: int = 5,
143
+ flip_sin_to_cos: bool = True,
144
+ freq_shift: int = 0,
145
+ down_block_types: Tuple[str, ...] = (
146
+ "DownBlock2D",
147
+ "DownBlock2D",
148
+ "DownBlock2D",
149
+ "DownBlock2D",
150
+ ),
151
+ mid_block_type: Optional[str] = "UNetMidBlock2D",
152
+ up_block_types: Tuple[str, ...] = (
153
+ "UpBlock2D",
154
+ "UpBlock2D",
155
+ "UpBlock2D",
156
+ "UpBlock2D",
157
+ ),
158
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
159
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
160
+ layers_per_block: int = 2,
161
+ downsample_padding: int = 1,
162
+ mid_block_scale_factor: float = 1,
163
+ act_fn: str = "silu",
164
+ norm_num_groups: Optional[int] = 32,
165
+ norm_eps: float = 1e-5,
166
+ cross_attention_dim: int = 1280,
167
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
168
+ encoder_hid_dim: Optional[int] = None,
169
+ encoder_hid_dim_type: Optional[str] = None,
170
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
171
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
172
+ use_linear_projection: bool = False,
173
+ class_embed_type: Optional[str] = None,
174
+ addition_embed_type: Optional[str] = None,
175
+ addition_time_embed_dim: Optional[int] = None,
176
+ num_class_embeds: Optional[int] = None,
177
+ upcast_attention: bool = False,
178
+ resnet_time_scale_shift: str = "default",
179
+ projection_class_embeddings_input_dim: Optional[int] = None,
180
+ brushnet_conditioning_channel_order: str = "rgb",
181
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
182
+ global_pool_conditions: bool = False,
183
+ addition_embed_type_num_heads: int = 64,
184
+ ):
185
+ super().__init__()
186
+
187
+ # If `num_attention_heads` is not defined (which is the case for most models)
188
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
189
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
190
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
191
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
192
+ # which is why we correct for the naming here.
193
+ num_attention_heads = num_attention_heads or attention_head_dim
194
+
195
+ # Check inputs
196
+ if len(down_block_types) != len(up_block_types):
197
+ raise ValueError(
198
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
199
+ )
200
+
201
+ if len(block_out_channels) != len(down_block_types):
202
+ raise ValueError(
203
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
204
+ )
205
+
206
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
207
+ raise ValueError(
208
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
209
+ )
210
+
211
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
212
+ raise ValueError(
213
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
214
+ )
215
+
216
+ if isinstance(transformer_layers_per_block, int):
217
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
218
+
219
+ # input
220
+ conv_in_kernel = 3
221
+ conv_in_padding = (conv_in_kernel - 1) // 2
222
+ self.conv_in_condition = nn.Conv2d(
223
+ in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
224
+ )
225
+
226
+ # time
227
+ time_embed_dim = block_out_channels[0] * 4
228
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
229
+ timestep_input_dim = block_out_channels[0]
230
+ self.time_embedding = TimestepEmbedding(
231
+ timestep_input_dim,
232
+ time_embed_dim,
233
+ act_fn=act_fn,
234
+ )
235
+
236
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
237
+ encoder_hid_dim_type = "text_proj"
238
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
239
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
240
+
241
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
242
+ raise ValueError(
243
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
244
+ )
245
+
246
+ if encoder_hid_dim_type == "text_proj":
247
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
248
+ elif encoder_hid_dim_type == "text_image_proj":
249
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
250
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
251
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
252
+ self.encoder_hid_proj = TextImageProjection(
253
+ text_embed_dim=encoder_hid_dim,
254
+ image_embed_dim=cross_attention_dim,
255
+ cross_attention_dim=cross_attention_dim,
256
+ )
257
+
258
+ elif encoder_hid_dim_type is not None:
259
+ raise ValueError(
260
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
261
+ )
262
+ else:
263
+ self.encoder_hid_proj = None
264
+
265
+ # class embedding
266
+ if class_embed_type is None and num_class_embeds is not None:
267
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
268
+ elif class_embed_type == "timestep":
269
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
270
+ elif class_embed_type == "identity":
271
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
272
+ elif class_embed_type == "projection":
273
+ if projection_class_embeddings_input_dim is None:
274
+ raise ValueError(
275
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
276
+ )
277
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
278
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
279
+ # 2. it projects from an arbitrary input dimension.
280
+ #
281
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
282
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
283
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
284
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
285
+ else:
286
+ self.class_embedding = None
287
+
288
+ if addition_embed_type == "text":
289
+ if encoder_hid_dim is not None:
290
+ text_time_embedding_from_dim = encoder_hid_dim
291
+ else:
292
+ text_time_embedding_from_dim = cross_attention_dim
293
+
294
+ self.add_embedding = TextTimeEmbedding(
295
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
296
+ )
297
+ elif addition_embed_type == "text_image":
298
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
299
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
300
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
301
+ self.add_embedding = TextImageTimeEmbedding(
302
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
303
+ )
304
+ elif addition_embed_type == "text_time":
305
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
306
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
307
+
308
+ elif addition_embed_type is not None:
309
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
310
+
311
+ self.down_blocks = nn.ModuleList([])
312
+ self.brushnet_down_blocks = nn.ModuleList([])
313
+
314
+ if isinstance(only_cross_attention, bool):
315
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
316
+
317
+ if isinstance(attention_head_dim, int):
318
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
319
+
320
+ if isinstance(num_attention_heads, int):
321
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
322
+
323
+ # down
324
+ output_channel = block_out_channels[0]
325
+
326
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
327
+ brushnet_block = zero_module(brushnet_block)
328
+ self.brushnet_down_blocks.append(brushnet_block)
329
+
330
+ for i, down_block_type in enumerate(down_block_types):
331
+ input_channel = output_channel
332
+ output_channel = block_out_channels[i]
333
+ is_final_block = i == len(block_out_channels) - 1
334
+
335
+ down_block = get_down_block(
336
+ down_block_type,
337
+ num_layers=layers_per_block,
338
+ transformer_layers_per_block=transformer_layers_per_block[i],
339
+ in_channels=input_channel,
340
+ out_channels=output_channel,
341
+ temb_channels=time_embed_dim,
342
+ add_downsample=not is_final_block,
343
+ resnet_eps=norm_eps,
344
+ resnet_act_fn=act_fn,
345
+ resnet_groups=norm_num_groups,
346
+ cross_attention_dim=cross_attention_dim,
347
+ num_attention_heads=num_attention_heads[i],
348
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
349
+ downsample_padding=downsample_padding,
350
+ use_linear_projection=use_linear_projection,
351
+ only_cross_attention=only_cross_attention[i],
352
+ upcast_attention=upcast_attention,
353
+ resnet_time_scale_shift=resnet_time_scale_shift,
354
+ )
355
+ self.down_blocks.append(down_block)
356
+
357
+ for _ in range(layers_per_block):
358
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
359
+ brushnet_block = zero_module(brushnet_block)
360
+ self.brushnet_down_blocks.append(brushnet_block)
361
+
362
+ if not is_final_block:
363
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
364
+ brushnet_block = zero_module(brushnet_block)
365
+ self.brushnet_down_blocks.append(brushnet_block)
366
+
367
+ # mid
368
+ mid_block_channel = block_out_channels[-1]
369
+
370
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
371
+ brushnet_block = zero_module(brushnet_block)
372
+ self.brushnet_mid_block = brushnet_block
373
+
374
+ self.mid_block = get_mid_block(
375
+ mid_block_type,
376
+ transformer_layers_per_block=transformer_layers_per_block[-1],
377
+ in_channels=mid_block_channel,
378
+ temb_channels=time_embed_dim,
379
+ resnet_eps=norm_eps,
380
+ resnet_act_fn=act_fn,
381
+ output_scale_factor=mid_block_scale_factor,
382
+ resnet_time_scale_shift=resnet_time_scale_shift,
383
+ cross_attention_dim=cross_attention_dim,
384
+ num_attention_heads=num_attention_heads[-1],
385
+ resnet_groups=norm_num_groups,
386
+ use_linear_projection=use_linear_projection,
387
+ upcast_attention=upcast_attention,
388
+ )
389
+
390
+ # count how many layers upsample the images
391
+ self.num_upsamplers = 0
392
+
393
+ # up
394
+ reversed_block_out_channels = list(reversed(block_out_channels))
395
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
396
+ reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
397
+ only_cross_attention = list(reversed(only_cross_attention))
398
+
399
+ output_channel = reversed_block_out_channels[0]
400
+
401
+ self.up_blocks = nn.ModuleList([])
402
+ self.brushnet_up_blocks = nn.ModuleList([])
403
+
404
+ for i, up_block_type in enumerate(up_block_types):
405
+ is_final_block = i == len(block_out_channels) - 1
406
+
407
+ prev_output_channel = output_channel
408
+ output_channel = reversed_block_out_channels[i]
409
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
410
+
411
+ # add upsample block for all BUT final layer
412
+ if not is_final_block:
413
+ add_upsample = True
414
+ self.num_upsamplers += 1
415
+ else:
416
+ add_upsample = False
417
+
418
+ up_block = get_up_block(
419
+ up_block_type,
420
+ num_layers=layers_per_block+1,
421
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
422
+ in_channels=input_channel,
423
+ out_channels=output_channel,
424
+ prev_output_channel=prev_output_channel,
425
+ temb_channels=time_embed_dim,
426
+ add_upsample=add_upsample,
427
+ resnet_eps=norm_eps,
428
+ resnet_act_fn=act_fn,
429
+ resolution_idx=i,
430
+ resnet_groups=norm_num_groups,
431
+ cross_attention_dim=cross_attention_dim,
432
+ num_attention_heads=reversed_num_attention_heads[i],
433
+ use_linear_projection=use_linear_projection,
434
+ only_cross_attention=only_cross_attention[i],
435
+ upcast_attention=upcast_attention,
436
+ resnet_time_scale_shift=resnet_time_scale_shift,
437
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
438
+ )
439
+ self.up_blocks.append(up_block)
440
+ prev_output_channel = output_channel
441
+
442
+ for _ in range(layers_per_block+1):
443
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
444
+ brushnet_block = zero_module(brushnet_block)
445
+ self.brushnet_up_blocks.append(brushnet_block)
446
+
447
+ if not is_final_block:
448
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
449
+ brushnet_block = zero_module(brushnet_block)
450
+ self.brushnet_up_blocks.append(brushnet_block)
451
+
452
+
453
+ @classmethod
454
+ def from_unet(
455
+ cls,
456
+ unet: UNet2DConditionModel,
457
+ brushnet_conditioning_channel_order: str = "rgb",
458
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
459
+ load_weights_from_unet: bool = True,
460
+ conditioning_channels: int = 5,
461
+ ):
462
+ r"""
463
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
464
+
465
+ Parameters:
466
+ unet (`UNet2DConditionModel`):
467
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
468
+ where applicable.
469
+ """
470
+ transformer_layers_per_block = (
471
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
472
+ )
473
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
474
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
475
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
476
+ addition_time_embed_dim = (
477
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
478
+ )
479
+
480
+ brushnet = cls(
481
+ in_channels=unet.config.in_channels,
482
+ conditioning_channels=conditioning_channels,
483
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
484
+ freq_shift=unet.config.freq_shift,
485
+ down_block_types=["DownBlock2D" for block_name in unet.config.down_block_types],
486
+ mid_block_type='MidBlock2D',
487
+ up_block_types=["UpBlock2D" for block_name in unet.config.down_block_types],
488
+ only_cross_attention=unet.config.only_cross_attention,
489
+ block_out_channels=unet.config.block_out_channels,
490
+ layers_per_block=unet.config.layers_per_block,
491
+ downsample_padding=unet.config.downsample_padding,
492
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
493
+ act_fn=unet.config.act_fn,
494
+ norm_num_groups=unet.config.norm_num_groups,
495
+ norm_eps=unet.config.norm_eps,
496
+ cross_attention_dim=unet.config.cross_attention_dim,
497
+ transformer_layers_per_block=transformer_layers_per_block,
498
+ encoder_hid_dim=encoder_hid_dim,
499
+ encoder_hid_dim_type=encoder_hid_dim_type,
500
+ attention_head_dim=unet.config.attention_head_dim,
501
+ num_attention_heads=unet.config.num_attention_heads,
502
+ use_linear_projection=unet.config.use_linear_projection,
503
+ class_embed_type=unet.config.class_embed_type,
504
+ addition_embed_type=addition_embed_type,
505
+ addition_time_embed_dim=addition_time_embed_dim,
506
+ num_class_embeds=unet.config.num_class_embeds,
507
+ upcast_attention=unet.config.upcast_attention,
508
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
509
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
510
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
511
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
512
+ )
513
+
514
+ if load_weights_from_unet:
515
+ conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight)
516
+ conv_in_condition_weight[:,:4,...]=unet.conv_in.weight
517
+ conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight
518
+ brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight)
519
+ brushnet.conv_in_condition.bias=unet.conv_in.bias
520
+
521
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
522
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
523
+
524
+ if brushnet.class_embedding:
525
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
526
+
527
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False)
528
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False)
529
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False)
530
+
531
+ return brushnet
532
+
533
+ @property
534
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
535
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
536
+ r"""
537
+ Returns:
538
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
539
+ indexed by its weight name.
540
+ """
541
+ # set recursively
542
+ processors = {}
543
+
544
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
545
+ if hasattr(module, "get_processor"):
546
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
547
+
548
+ for sub_name, child in module.named_children():
549
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
550
+
551
+ return processors
552
+
553
+ for name, module in self.named_children():
554
+ fn_recursive_add_processors(name, module, processors)
555
+
556
+ return processors
557
+
558
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
559
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
560
+ r"""
561
+ Sets the attention processor to use to compute attention.
562
+
563
+ Parameters:
564
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
565
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
566
+ for **all** `Attention` layers.
567
+
568
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
569
+ processor. This is strongly recommended when setting trainable attention processors.
570
+
571
+ """
572
+ count = len(self.attn_processors.keys())
573
+
574
+ if isinstance(processor, dict) and len(processor) != count:
575
+ raise ValueError(
576
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
577
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
578
+ )
579
+
580
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
581
+ if hasattr(module, "set_processor"):
582
+ if not isinstance(processor, dict):
583
+ module.set_processor(processor)
584
+ else:
585
+ module.set_processor(processor.pop(f"{name}.processor"))
586
+
587
+ for sub_name, child in module.named_children():
588
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
589
+
590
+ for name, module in self.named_children():
591
+ fn_recursive_attn_processor(name, module, processor)
592
+
593
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
594
+ def set_default_attn_processor(self):
595
+ """
596
+ Disables custom attention processors and sets the default attention implementation.
597
+ """
598
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
599
+ processor = AttnAddedKVProcessor()
600
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
601
+ processor = AttnProcessor()
602
+ else:
603
+ raise ValueError(
604
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
605
+ )
606
+
607
+ self.set_attn_processor(processor)
608
+
609
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
610
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
611
+ r"""
612
+ Enable sliced attention computation.
613
+
614
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
615
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
616
+
617
+ Args:
618
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
619
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
620
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
621
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
622
+ must be a multiple of `slice_size`.
623
+ """
624
+ sliceable_head_dims = []
625
+
626
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
627
+ if hasattr(module, "set_attention_slice"):
628
+ sliceable_head_dims.append(module.sliceable_head_dim)
629
+
630
+ for child in module.children():
631
+ fn_recursive_retrieve_sliceable_dims(child)
632
+
633
+ # retrieve number of attention layers
634
+ for module in self.children():
635
+ fn_recursive_retrieve_sliceable_dims(module)
636
+
637
+ num_sliceable_layers = len(sliceable_head_dims)
638
+
639
+ if slice_size == "auto":
640
+ # half the attention head size is usually a good trade-off between
641
+ # speed and memory
642
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
643
+ elif slice_size == "max":
644
+ # make smallest slice possible
645
+ slice_size = num_sliceable_layers * [1]
646
+
647
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
648
+
649
+ if len(slice_size) != len(sliceable_head_dims):
650
+ raise ValueError(
651
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
652
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
653
+ )
654
+
655
+ for i in range(len(slice_size)):
656
+ size = slice_size[i]
657
+ dim = sliceable_head_dims[i]
658
+ if size is not None and size > dim:
659
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
660
+
661
+ # Recursively walk through all the children.
662
+ # Any children which exposes the set_attention_slice method
663
+ # gets the message
664
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
665
+ if hasattr(module, "set_attention_slice"):
666
+ module.set_attention_slice(slice_size.pop())
667
+
668
+ for child in module.children():
669
+ fn_recursive_set_attention_slice(child, slice_size)
670
+
671
+ reversed_slice_size = list(reversed(slice_size))
672
+ for module in self.children():
673
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
674
+
675
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
676
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
677
+ module.gradient_checkpointing = value
678
+
679
+ def forward(
680
+ self,
681
+ sample: torch.FloatTensor,
682
+ encoder_hidden_states: torch.Tensor,
683
+ brushnet_cond: torch.FloatTensor,
684
+ timestep = None,
685
+ time_emb = None,
686
+ conditioning_scale: float = 1.0,
687
+ class_labels: Optional[torch.Tensor] = None,
688
+ timestep_cond: Optional[torch.Tensor] = None,
689
+ attention_mask: Optional[torch.Tensor] = None,
690
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
691
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
692
+ guess_mode: bool = False,
693
+ return_dict: bool = True,
694
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
695
+ """
696
+ The [`BrushNetModel`] forward method.
697
+
698
+ Args:
699
+ sample (`torch.FloatTensor`):
700
+ The noisy input tensor.
701
+ timestep (`Union[torch.Tensor, float, int]`):
702
+ The number of timesteps to denoise an input.
703
+ encoder_hidden_states (`torch.Tensor`):
704
+ The encoder hidden states.
705
+ brushnet_cond (`torch.FloatTensor`):
706
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
707
+ conditioning_scale (`float`, defaults to `1.0`):
708
+ The scale factor for BrushNet outputs.
709
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
710
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
711
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
712
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
713
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
714
+ embeddings.
715
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
716
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
717
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
718
+ negative values to the attention scores corresponding to "discard" tokens.
719
+ added_cond_kwargs (`dict`):
720
+ Additional conditions for the Stable Diffusion XL UNet.
721
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
722
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
723
+ guess_mode (`bool`, defaults to `False`):
724
+ In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
725
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
726
+ return_dict (`bool`, defaults to `True`):
727
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
728
+
729
+ Returns:
730
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
731
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
732
+ returned where the first element is the sample tensor.
733
+ """
734
+
735
+ # check channel order
736
+ channel_order = self.config.brushnet_conditioning_channel_order
737
+
738
+ if channel_order == "rgb":
739
+ # in rgb order by default
740
+ ...
741
+ elif channel_order == "bgr":
742
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
743
+ else:
744
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
745
+
746
+ # prepare attention_mask
747
+ if attention_mask is not None:
748
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
749
+ attention_mask = attention_mask.unsqueeze(1)
750
+
751
+ if timestep is None and time_emb is None:
752
+ raise ValueError(f"`timestep` and `emb` are both None")
753
+
754
+ #print("BN: sample.device", sample.device)
755
+ #print("BN: TE.device", self.time_embedding.linear_1.weight.device)
756
+
757
+ if timestep is not None:
758
+ # 1. time
759
+ timesteps = timestep
760
+ if not torch.is_tensor(timesteps):
761
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
762
+ # This would be a good case for the `match` statement (Python 3.10+)
763
+ is_mps = sample.device.type == "mps"
764
+ if isinstance(timestep, float):
765
+ dtype = torch.float32 if is_mps else torch.float64
766
+ else:
767
+ dtype = torch.int32 if is_mps else torch.int64
768
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
769
+ elif len(timesteps.shape) == 0:
770
+ timesteps = timesteps[None].to(sample.device)
771
+
772
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
773
+ timesteps = timesteps.expand(sample.shape[0])
774
+
775
+ t_emb = self.time_proj(timesteps)
776
+
777
+ # timesteps does not contain any weights and will always return f32 tensors
778
+ # but time_embedding might actually be running in fp16. so we need to cast here.
779
+ # there might be better ways to encapsulate this.
780
+ t_emb = t_emb.to(dtype=sample.dtype)
781
+
782
+ #print("t_emb.device =",t_emb.device)
783
+
784
+ emb = self.time_embedding(t_emb, timestep_cond)
785
+ aug_emb = None
786
+
787
+ #print('emb.shape', emb.shape)
788
+
789
+ if self.class_embedding is not None:
790
+ if class_labels is None:
791
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
792
+
793
+ if self.config.class_embed_type == "timestep":
794
+ class_labels = self.time_proj(class_labels)
795
+
796
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
797
+ emb = emb + class_emb
798
+
799
+ if self.config.addition_embed_type is not None:
800
+ if self.config.addition_embed_type == "text":
801
+ aug_emb = self.add_embedding(encoder_hidden_states)
802
+
803
+ elif self.config.addition_embed_type == "text_time":
804
+ if "text_embeds" not in added_cond_kwargs:
805
+ raise ValueError(
806
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
807
+ )
808
+ text_embeds = added_cond_kwargs.get("text_embeds")
809
+ if "time_ids" not in added_cond_kwargs:
810
+ raise ValueError(
811
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
812
+ )
813
+ time_ids = added_cond_kwargs.get("time_ids")
814
+ time_embeds = self.add_time_proj(time_ids.flatten())
815
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
816
+
817
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
818
+ add_embeds = add_embeds.to(emb.dtype)
819
+ aug_emb = self.add_embedding(add_embeds)
820
+
821
+ #print('text_embeds', text_embeds.shape, 'time_ids', time_ids.shape, 'time_embeds', time_embeds.shape, 'add__embeds', add_embeds.shape, 'aug_emb', aug_emb.shape)
822
+
823
+ emb = emb + aug_emb if aug_emb is not None else emb
824
+ else:
825
+ emb = time_emb
826
+
827
+ # 2. pre-process
828
+
829
+ brushnet_cond=torch.concat([sample,brushnet_cond],1)
830
+ sample = self.conv_in_condition(brushnet_cond)
831
+
832
+ # 3. down
833
+ down_block_res_samples = (sample,)
834
+ for downsample_block in self.down_blocks:
835
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
836
+ sample, res_samples = downsample_block(
837
+ hidden_states=sample,
838
+ temb=emb,
839
+ encoder_hidden_states=encoder_hidden_states,
840
+ attention_mask=attention_mask,
841
+ cross_attention_kwargs=cross_attention_kwargs,
842
+ )
843
+ else:
844
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
845
+
846
+ down_block_res_samples += res_samples
847
+
848
+ # 4. PaintingNet down blocks
849
+ brushnet_down_block_res_samples = ()
850
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
851
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
852
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
853
+
854
+
855
+ # 5. mid
856
+ if self.mid_block is not None:
857
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
858
+ sample = self.mid_block(
859
+ sample,
860
+ emb,
861
+ encoder_hidden_states=encoder_hidden_states,
862
+ attention_mask=attention_mask,
863
+ cross_attention_kwargs=cross_attention_kwargs,
864
+ )
865
+ else:
866
+ sample = self.mid_block(sample, emb)
867
+
868
+ # 6. BrushNet mid blocks
869
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
870
+
871
+ # 7. up
872
+ up_block_res_samples = ()
873
+ for i, upsample_block in enumerate(self.up_blocks):
874
+ is_final_block = i == len(self.up_blocks) - 1
875
+
876
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
877
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
878
+
879
+ # if we have not reached the final block and need to forward the
880
+ # upsample size, we do it here
881
+ if not is_final_block:
882
+ upsample_size = down_block_res_samples[-1].shape[2:]
883
+
884
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
885
+ sample, up_res_samples = upsample_block(
886
+ hidden_states=sample,
887
+ temb=emb,
888
+ res_hidden_states_tuple=res_samples,
889
+ encoder_hidden_states=encoder_hidden_states,
890
+ cross_attention_kwargs=cross_attention_kwargs,
891
+ upsample_size=upsample_size,
892
+ attention_mask=attention_mask,
893
+ return_res_samples=True
894
+ )
895
+ else:
896
+ sample, up_res_samples = upsample_block(
897
+ hidden_states=sample,
898
+ temb=emb,
899
+ res_hidden_states_tuple=res_samples,
900
+ upsample_size=upsample_size,
901
+ return_res_samples=True
902
+ )
903
+
904
+ up_block_res_samples += up_res_samples
905
+
906
+ # 8. BrushNet up blocks
907
+ brushnet_up_block_res_samples = ()
908
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
909
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
910
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
911
+
912
+ # 6. scaling
913
+ if guess_mode and not self.config.global_pool_conditions:
914
+ scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
915
+ scales = scales * conditioning_scale
916
+
917
+ brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
918
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
919
+ brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
920
+ else:
921
+ brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
922
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
923
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
924
+
925
+
926
+ if self.config.global_pool_conditions:
927
+ brushnet_down_block_res_samples = [
928
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
929
+ ]
930
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
931
+ brushnet_up_block_res_samples = [
932
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
933
+ ]
934
+
935
+ if not return_dict:
936
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
937
+
938
+ return BrushNetOutput(
939
+ down_block_res_samples=brushnet_down_block_res_samples,
940
+ mid_block_res_sample=brushnet_mid_block_res_sample,
941
+ up_block_res_samples=brushnet_up_block_res_samples
942
+ )
943
+
944
+
945
+ def zero_module(module):
946
+ for p in module.parameters():
947
+ nn.init.zeros_(p)
948
+ return module
ComfyUI-BrushNet/brushnet/brushnet_ca.py ADDED
@@ -0,0 +1,956 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.utils import BaseOutput, logging
9
+ from diffusers.models.attention_processor import (
10
+ ADDED_KV_ATTENTION_PROCESSORS,
11
+ CROSS_ATTENTION_PROCESSORS,
12
+ AttentionProcessor,
13
+ AttnAddedKVProcessor,
14
+ AttnProcessor,
15
+ )
16
+ from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+
19
+ from .unet_2d_blocks import (
20
+ CrossAttnDownBlock2D,
21
+ DownBlock2D,
22
+ UNetMidBlock2D,
23
+ UNetMidBlock2DCrossAttn,
24
+ get_down_block,
25
+ get_mid_block,
26
+ get_up_block,
27
+ MidBlock2D
28
+ )
29
+
30
+ from .unet_2d_condition import UNet2DConditionModel
31
+
32
+
33
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
34
+
35
+
36
+ @dataclass
37
+ class BrushNetOutput(BaseOutput):
38
+ """
39
+ The output of [`BrushNetModel`].
40
+
41
+ Args:
42
+ up_block_res_samples (`tuple[torch.Tensor]`):
43
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
44
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
45
+ used to condition the original UNet's upsampling activations.
46
+ down_block_res_samples (`tuple[torch.Tensor]`):
47
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
48
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
49
+ used to condition the original UNet's downsampling activations.
50
+ mid_down_block_re_sample (`torch.Tensor`):
51
+ The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
52
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
53
+ Output can be used to condition the original UNet's middle block activation.
54
+ """
55
+
56
+ up_block_res_samples: Tuple[torch.Tensor]
57
+ down_block_res_samples: Tuple[torch.Tensor]
58
+ mid_block_res_sample: torch.Tensor
59
+
60
+
61
+ class BrushNetModel(ModelMixin, ConfigMixin):
62
+ """
63
+ A BrushNet model.
64
+
65
+ Args:
66
+ in_channels (`int`, defaults to 4):
67
+ The number of channels in the input sample.
68
+ flip_sin_to_cos (`bool`, defaults to `True`):
69
+ Whether to flip the sin to cos in the time embedding.
70
+ freq_shift (`int`, defaults to 0):
71
+ The frequency shift to apply to the time embedding.
72
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
73
+ The tuple of downsample blocks to use.
74
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
75
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
76
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
77
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
78
+ The tuple of upsample blocks to use.
79
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
80
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
81
+ The tuple of output channels for each block.
82
+ layers_per_block (`int`, defaults to 2):
83
+ The number of layers per block.
84
+ downsample_padding (`int`, defaults to 1):
85
+ The padding to use for the downsampling convolution.
86
+ mid_block_scale_factor (`float`, defaults to 1):
87
+ The scale factor to use for the mid block.
88
+ act_fn (`str`, defaults to "silu"):
89
+ The activation function to use.
90
+ norm_num_groups (`int`, *optional*, defaults to 32):
91
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
92
+ in post-processing.
93
+ norm_eps (`float`, defaults to 1e-5):
94
+ The epsilon to use for the normalization.
95
+ cross_attention_dim (`int`, defaults to 1280):
96
+ The dimension of the cross attention features.
97
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
98
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
99
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
100
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
101
+ encoder_hid_dim (`int`, *optional*, defaults to None):
102
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
103
+ dimension to `cross_attention_dim`.
104
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
105
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
106
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
107
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
108
+ The dimension of the attention heads.
109
+ use_linear_projection (`bool`, defaults to `False`):
110
+ class_embed_type (`str`, *optional*, defaults to `None`):
111
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
112
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
113
+ addition_embed_type (`str`, *optional*, defaults to `None`):
114
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
115
+ "text". "text" will use the `TextTimeEmbedding` layer.
116
+ num_class_embeds (`int`, *optional*, defaults to 0):
117
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
118
+ class conditioning with `class_embed_type` equal to `None`.
119
+ upcast_attention (`bool`, defaults to `False`):
120
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
121
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
122
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
123
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
124
+ `class_embed_type="projection"`.
125
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
126
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
127
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
128
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
129
+ global_pool_conditions (`bool`, defaults to `False`):
130
+ TODO(Patrick) - unused parameter.
131
+ addition_embed_type_num_heads (`int`, defaults to 64):
132
+ The number of heads to use for the `TextTimeEmbedding` layer.
133
+ """
134
+
135
+ _supports_gradient_checkpointing = True
136
+
137
+ @register_to_config
138
+ def __init__(
139
+ self,
140
+ in_channels: int = 4,
141
+ conditioning_channels: int = 5,
142
+ flip_sin_to_cos: bool = True,
143
+ freq_shift: int = 0,
144
+ down_block_types: Tuple[str, ...] = (
145
+ "CrossAttnDownBlock2D",
146
+ "CrossAttnDownBlock2D",
147
+ "CrossAttnDownBlock2D",
148
+ "DownBlock2D",
149
+ ),
150
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
151
+ up_block_types: Tuple[str, ...] = (
152
+ "UpBlock2D",
153
+ "CrossAttnUpBlock2D",
154
+ "CrossAttnUpBlock2D",
155
+ "CrossAttnUpBlock2D",
156
+ ),
157
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
158
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
159
+ layers_per_block: int = 2,
160
+ downsample_padding: int = 1,
161
+ mid_block_scale_factor: float = 1,
162
+ act_fn: str = "silu",
163
+ norm_num_groups: Optional[int] = 32,
164
+ norm_eps: float = 1e-5,
165
+ cross_attention_dim: int = 1280,
166
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
167
+ encoder_hid_dim: Optional[int] = None,
168
+ encoder_hid_dim_type: Optional[str] = None,
169
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
170
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
171
+ use_linear_projection: bool = False,
172
+ class_embed_type: Optional[str] = None,
173
+ addition_embed_type: Optional[str] = None,
174
+ addition_time_embed_dim: Optional[int] = None,
175
+ num_class_embeds: Optional[int] = None,
176
+ upcast_attention: bool = False,
177
+ resnet_time_scale_shift: str = "default",
178
+ projection_class_embeddings_input_dim: Optional[int] = None,
179
+ brushnet_conditioning_channel_order: str = "rgb",
180
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
181
+ global_pool_conditions: bool = False,
182
+ addition_embed_type_num_heads: int = 64,
183
+ ):
184
+ super().__init__()
185
+
186
+ # If `num_attention_heads` is not defined (which is the case for most models)
187
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
188
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
189
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
190
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
191
+ # which is why we correct for the naming here.
192
+ num_attention_heads = num_attention_heads or attention_head_dim
193
+
194
+ # Check inputs
195
+ if len(down_block_types) != len(up_block_types):
196
+ raise ValueError(
197
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
198
+ )
199
+
200
+ if len(block_out_channels) != len(down_block_types):
201
+ raise ValueError(
202
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
203
+ )
204
+
205
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
206
+ raise ValueError(
207
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
208
+ )
209
+
210
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
211
+ raise ValueError(
212
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
213
+ )
214
+
215
+ if isinstance(transformer_layers_per_block, int):
216
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
217
+
218
+ # input
219
+ conv_in_kernel = 3
220
+ conv_in_padding = (conv_in_kernel - 1) // 2
221
+ self.conv_in_condition = nn.Conv2d(
222
+ in_channels + conditioning_channels,
223
+ block_out_channels[0],
224
+ kernel_size=conv_in_kernel,
225
+ padding=conv_in_padding,
226
+ )
227
+
228
+ # time
229
+ time_embed_dim = block_out_channels[0] * 4
230
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
231
+ timestep_input_dim = block_out_channels[0]
232
+ self.time_embedding = TimestepEmbedding(
233
+ timestep_input_dim,
234
+ time_embed_dim,
235
+ act_fn=act_fn,
236
+ )
237
+
238
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
239
+ encoder_hid_dim_type = "text_proj"
240
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
241
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
242
+
243
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
244
+ raise ValueError(
245
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
246
+ )
247
+
248
+ if encoder_hid_dim_type == "text_proj":
249
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
250
+ elif encoder_hid_dim_type == "text_image_proj":
251
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
252
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
253
+ # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
254
+ self.encoder_hid_proj = TextImageProjection(
255
+ text_embed_dim=encoder_hid_dim,
256
+ image_embed_dim=cross_attention_dim,
257
+ cross_attention_dim=cross_attention_dim,
258
+ )
259
+
260
+ elif encoder_hid_dim_type is not None:
261
+ raise ValueError(
262
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
263
+ )
264
+ else:
265
+ self.encoder_hid_proj = None
266
+
267
+ # class embedding
268
+ if class_embed_type is None and num_class_embeds is not None:
269
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
270
+ elif class_embed_type == "timestep":
271
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
272
+ elif class_embed_type == "identity":
273
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
274
+ elif class_embed_type == "projection":
275
+ if projection_class_embeddings_input_dim is None:
276
+ raise ValueError(
277
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
278
+ )
279
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
280
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
281
+ # 2. it projects from an arbitrary input dimension.
282
+ #
283
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
284
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
285
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
286
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
287
+ else:
288
+ self.class_embedding = None
289
+
290
+ if addition_embed_type == "text":
291
+ if encoder_hid_dim is not None:
292
+ text_time_embedding_from_dim = encoder_hid_dim
293
+ else:
294
+ text_time_embedding_from_dim = cross_attention_dim
295
+
296
+ self.add_embedding = TextTimeEmbedding(
297
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
298
+ )
299
+ elif addition_embed_type == "text_image":
300
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
301
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
302
+ # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
303
+ self.add_embedding = TextImageTimeEmbedding(
304
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
305
+ )
306
+ elif addition_embed_type == "text_time":
307
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
308
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
309
+
310
+ elif addition_embed_type is not None:
311
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
312
+
313
+ self.down_blocks = nn.ModuleList([])
314
+ self.brushnet_down_blocks = nn.ModuleList([])
315
+
316
+ if isinstance(only_cross_attention, bool):
317
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
318
+
319
+ if isinstance(attention_head_dim, int):
320
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
321
+
322
+ if isinstance(num_attention_heads, int):
323
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
324
+
325
+ # down
326
+ output_channel = block_out_channels[0]
327
+
328
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
329
+ brushnet_block = zero_module(brushnet_block)
330
+ self.brushnet_down_blocks.append(brushnet_block)
331
+
332
+ for i, down_block_type in enumerate(down_block_types):
333
+ input_channel = output_channel
334
+ output_channel = block_out_channels[i]
335
+ is_final_block = i == len(block_out_channels) - 1
336
+
337
+ down_block = get_down_block(
338
+ down_block_type,
339
+ num_layers=layers_per_block,
340
+ transformer_layers_per_block=transformer_layers_per_block[i],
341
+ in_channels=input_channel,
342
+ out_channels=output_channel,
343
+ temb_channels=time_embed_dim,
344
+ add_downsample=not is_final_block,
345
+ resnet_eps=norm_eps,
346
+ resnet_act_fn=act_fn,
347
+ resnet_groups=norm_num_groups,
348
+ cross_attention_dim=cross_attention_dim,
349
+ num_attention_heads=num_attention_heads[i],
350
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
351
+ downsample_padding=downsample_padding,
352
+ use_linear_projection=use_linear_projection,
353
+ only_cross_attention=only_cross_attention[i],
354
+ upcast_attention=upcast_attention,
355
+ resnet_time_scale_shift=resnet_time_scale_shift,
356
+ )
357
+ self.down_blocks.append(down_block)
358
+
359
+ for _ in range(layers_per_block):
360
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
361
+ brushnet_block = zero_module(brushnet_block)
362
+ self.brushnet_down_blocks.append(brushnet_block)
363
+
364
+ if not is_final_block:
365
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
366
+ brushnet_block = zero_module(brushnet_block)
367
+ self.brushnet_down_blocks.append(brushnet_block)
368
+
369
+ # mid
370
+ mid_block_channel = block_out_channels[-1]
371
+
372
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
373
+ brushnet_block = zero_module(brushnet_block)
374
+ self.brushnet_mid_block = brushnet_block
375
+
376
+ self.mid_block = get_mid_block(
377
+ mid_block_type,
378
+ transformer_layers_per_block=transformer_layers_per_block[-1],
379
+ in_channels=mid_block_channel,
380
+ temb_channels=time_embed_dim,
381
+ resnet_eps=norm_eps,
382
+ resnet_act_fn=act_fn,
383
+ output_scale_factor=mid_block_scale_factor,
384
+ resnet_time_scale_shift=resnet_time_scale_shift,
385
+ cross_attention_dim=cross_attention_dim,
386
+ num_attention_heads=num_attention_heads[-1],
387
+ resnet_groups=norm_num_groups,
388
+ use_linear_projection=use_linear_projection,
389
+ upcast_attention=upcast_attention,
390
+ )
391
+
392
+ # count how many layers upsample the images
393
+ self.num_upsamplers = 0
394
+
395
+ # up
396
+ reversed_block_out_channels = list(reversed(block_out_channels))
397
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
398
+ reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
399
+ only_cross_attention = list(reversed(only_cross_attention))
400
+
401
+ output_channel = reversed_block_out_channels[0]
402
+
403
+ self.up_blocks = nn.ModuleList([])
404
+ self.brushnet_up_blocks = nn.ModuleList([])
405
+
406
+ for i, up_block_type in enumerate(up_block_types):
407
+ is_final_block = i == len(block_out_channels) - 1
408
+
409
+ prev_output_channel = output_channel
410
+ output_channel = reversed_block_out_channels[i]
411
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
412
+
413
+ # add upsample block for all BUT final layer
414
+ if not is_final_block:
415
+ add_upsample = True
416
+ self.num_upsamplers += 1
417
+ else:
418
+ add_upsample = False
419
+
420
+ up_block = get_up_block(
421
+ up_block_type,
422
+ num_layers=layers_per_block + 1,
423
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
424
+ in_channels=input_channel,
425
+ out_channels=output_channel,
426
+ prev_output_channel=prev_output_channel,
427
+ temb_channels=time_embed_dim,
428
+ add_upsample=add_upsample,
429
+ resnet_eps=norm_eps,
430
+ resnet_act_fn=act_fn,
431
+ resolution_idx=i,
432
+ resnet_groups=norm_num_groups,
433
+ cross_attention_dim=cross_attention_dim,
434
+ num_attention_heads=reversed_num_attention_heads[i],
435
+ use_linear_projection=use_linear_projection,
436
+ only_cross_attention=only_cross_attention[i],
437
+ upcast_attention=upcast_attention,
438
+ resnet_time_scale_shift=resnet_time_scale_shift,
439
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
440
+ )
441
+ self.up_blocks.append(up_block)
442
+ prev_output_channel = output_channel
443
+
444
+ for _ in range(layers_per_block + 1):
445
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
446
+ brushnet_block = zero_module(brushnet_block)
447
+ self.brushnet_up_blocks.append(brushnet_block)
448
+
449
+ if not is_final_block:
450
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
451
+ brushnet_block = zero_module(brushnet_block)
452
+ self.brushnet_up_blocks.append(brushnet_block)
453
+
454
+ @classmethod
455
+ def from_unet(
456
+ cls,
457
+ unet: UNet2DConditionModel,
458
+ brushnet_conditioning_channel_order: str = "rgb",
459
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
460
+ load_weights_from_unet: bool = True,
461
+ conditioning_channels: int = 5,
462
+ ):
463
+ r"""
464
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
465
+
466
+ Parameters:
467
+ unet (`UNet2DConditionModel`):
468
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
469
+ where applicable.
470
+ """
471
+ transformer_layers_per_block = (
472
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
473
+ )
474
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
475
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
476
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
477
+ addition_time_embed_dim = (
478
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
479
+ )
480
+
481
+ brushnet = cls(
482
+ in_channels=unet.config.in_channels,
483
+ conditioning_channels=conditioning_channels,
484
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
485
+ freq_shift=unet.config.freq_shift,
486
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
487
+ down_block_types=[
488
+ "CrossAttnDownBlock2D",
489
+ "CrossAttnDownBlock2D",
490
+ "CrossAttnDownBlock2D",
491
+ "DownBlock2D",
492
+ ],
493
+ # mid_block_type='MidBlock2D',
494
+ mid_block_type="UNetMidBlock2DCrossAttn",
495
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
496
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
497
+ only_cross_attention=unet.config.only_cross_attention,
498
+ block_out_channels=unet.config.block_out_channels,
499
+ layers_per_block=unet.config.layers_per_block,
500
+ downsample_padding=unet.config.downsample_padding,
501
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
502
+ act_fn=unet.config.act_fn,
503
+ norm_num_groups=unet.config.norm_num_groups,
504
+ norm_eps=unet.config.norm_eps,
505
+ cross_attention_dim=unet.config.cross_attention_dim,
506
+ transformer_layers_per_block=transformer_layers_per_block,
507
+ encoder_hid_dim=encoder_hid_dim,
508
+ encoder_hid_dim_type=encoder_hid_dim_type,
509
+ attention_head_dim=unet.config.attention_head_dim,
510
+ num_attention_heads=unet.config.num_attention_heads,
511
+ use_linear_projection=unet.config.use_linear_projection,
512
+ class_embed_type=unet.config.class_embed_type,
513
+ addition_embed_type=addition_embed_type,
514
+ addition_time_embed_dim=addition_time_embed_dim,
515
+ num_class_embeds=unet.config.num_class_embeds,
516
+ upcast_attention=unet.config.upcast_attention,
517
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
518
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
519
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
520
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
521
+ )
522
+
523
+ if load_weights_from_unet:
524
+ conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
525
+ conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
526
+ conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
527
+ brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
528
+ brushnet.conv_in_condition.bias = unet.conv_in.bias
529
+
530
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
531
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
532
+
533
+ if brushnet.class_embedding:
534
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
535
+
536
+ brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
537
+ brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
538
+ brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
539
+
540
+ return brushnet.to(unet.dtype)
541
+
542
+ @property
543
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
544
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
545
+ r"""
546
+ Returns:
547
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
548
+ indexed by its weight name.
549
+ """
550
+ # set recursively
551
+ processors = {}
552
+
553
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
554
+ if hasattr(module, "get_processor"):
555
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
556
+
557
+ for sub_name, child in module.named_children():
558
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
559
+
560
+ return processors
561
+
562
+ for name, module in self.named_children():
563
+ fn_recursive_add_processors(name, module, processors)
564
+
565
+ return processors
566
+
567
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
568
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
569
+ r"""
570
+ Sets the attention processor to use to compute attention.
571
+
572
+ Parameters:
573
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
574
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
575
+ for **all** `Attention` layers.
576
+
577
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
578
+ processor. This is strongly recommended when setting trainable attention processors.
579
+
580
+ """
581
+ count = len(self.attn_processors.keys())
582
+
583
+ if isinstance(processor, dict) and len(processor) != count:
584
+ raise ValueError(
585
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
586
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
587
+ )
588
+
589
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
590
+ if hasattr(module, "set_processor"):
591
+ if not isinstance(processor, dict):
592
+ module.set_processor(processor)
593
+ else:
594
+ module.set_processor(processor.pop(f"{name}.processor"))
595
+
596
+ for sub_name, child in module.named_children():
597
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
598
+
599
+ for name, module in self.named_children():
600
+ fn_recursive_attn_processor(name, module, processor)
601
+
602
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
603
+ def set_default_attn_processor(self):
604
+ """
605
+ Disables custom attention processors and sets the default attention implementation.
606
+ """
607
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
608
+ processor = AttnAddedKVProcessor()
609
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
610
+ processor = AttnProcessor()
611
+ else:
612
+ raise ValueError(
613
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
614
+ )
615
+
616
+ self.set_attn_processor(processor)
617
+
618
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
619
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
620
+ r"""
621
+ Enable sliced attention computation.
622
+
623
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
624
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
625
+
626
+ Args:
627
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
628
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
629
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
630
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
631
+ must be a multiple of `slice_size`.
632
+ """
633
+ sliceable_head_dims = []
634
+
635
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
636
+ if hasattr(module, "set_attention_slice"):
637
+ sliceable_head_dims.append(module.sliceable_head_dim)
638
+
639
+ for child in module.children():
640
+ fn_recursive_retrieve_sliceable_dims(child)
641
+
642
+ # retrieve number of attention layers
643
+ for module in self.children():
644
+ fn_recursive_retrieve_sliceable_dims(module)
645
+
646
+ num_sliceable_layers = len(sliceable_head_dims)
647
+
648
+ if slice_size == "auto":
649
+ # half the attention head size is usually a good trade-off between
650
+ # speed and memory
651
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
652
+ elif slice_size == "max":
653
+ # make smallest slice possible
654
+ slice_size = num_sliceable_layers * [1]
655
+
656
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
657
+
658
+ if len(slice_size) != len(sliceable_head_dims):
659
+ raise ValueError(
660
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
661
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
662
+ )
663
+
664
+ for i in range(len(slice_size)):
665
+ size = slice_size[i]
666
+ dim = sliceable_head_dims[i]
667
+ if size is not None and size > dim:
668
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
669
+
670
+ # Recursively walk through all the children.
671
+ # Any children which exposes the set_attention_slice method
672
+ # gets the message
673
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
674
+ if hasattr(module, "set_attention_slice"):
675
+ module.set_attention_slice(slice_size.pop())
676
+
677
+ for child in module.children():
678
+ fn_recursive_set_attention_slice(child, slice_size)
679
+
680
+ reversed_slice_size = list(reversed(slice_size))
681
+ for module in self.children():
682
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
683
+
684
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
685
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
686
+ module.gradient_checkpointing = value
687
+
688
+ def forward(
689
+ self,
690
+ sample: torch.FloatTensor,
691
+ timestep: Union[torch.Tensor, float, int],
692
+ encoder_hidden_states: torch.Tensor,
693
+ brushnet_cond: torch.FloatTensor,
694
+ conditioning_scale: float = 1.0,
695
+ class_labels: Optional[torch.Tensor] = None,
696
+ timestep_cond: Optional[torch.Tensor] = None,
697
+ attention_mask: Optional[torch.Tensor] = None,
698
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
699
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
700
+ guess_mode: bool = False,
701
+ return_dict: bool = True,
702
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
703
+ """
704
+ The [`BrushNetModel`] forward method.
705
+
706
+ Args:
707
+ sample (`torch.FloatTensor`):
708
+ The noisy input tensor.
709
+ timestep (`Union[torch.Tensor, float, int]`):
710
+ The number of timesteps to denoise an input.
711
+ encoder_hidden_states (`torch.Tensor`):
712
+ The encoder hidden states.
713
+ brushnet_cond (`torch.FloatTensor`):
714
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
715
+ conditioning_scale (`float`, defaults to `1.0`):
716
+ The scale factor for BrushNet outputs.
717
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
718
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
719
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
720
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
721
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
722
+ embeddings.
723
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
724
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
725
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
726
+ negative values to the attention scores corresponding to "discard" tokens.
727
+ added_cond_kwargs (`dict`):
728
+ Additional conditions for the Stable Diffusion XL UNet.
729
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
730
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
731
+ guess_mode (`bool`, defaults to `False`):
732
+ In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
733
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
734
+ return_dict (`bool`, defaults to `True`):
735
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
736
+
737
+ Returns:
738
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
739
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
740
+ returned where the first element is the sample tensor.
741
+ """
742
+ # check channel order
743
+ channel_order = self.config.brushnet_conditioning_channel_order
744
+
745
+ if channel_order == "rgb":
746
+ # in rgb order by default
747
+ ...
748
+ elif channel_order == "bgr":
749
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
750
+ else:
751
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
752
+
753
+ # prepare attention_mask
754
+ if attention_mask is not None:
755
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
756
+ attention_mask = attention_mask.unsqueeze(1)
757
+
758
+ # 1. time
759
+ timesteps = timestep
760
+ if not torch.is_tensor(timesteps):
761
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
762
+ # This would be a good case for the `match` statement (Python 3.10+)
763
+ is_mps = sample.device.type == "mps"
764
+ if isinstance(timestep, float):
765
+ dtype = torch.float32 if is_mps else torch.float64
766
+ else:
767
+ dtype = torch.int32 if is_mps else torch.int64
768
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
769
+ elif len(timesteps.shape) == 0:
770
+ timesteps = timesteps[None].to(sample.device)
771
+
772
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
773
+ timesteps = timesteps.expand(sample.shape[0])
774
+
775
+ t_emb = self.time_proj(timesteps)
776
+
777
+ # timesteps does not contain any weights and will always return f32 tensors
778
+ # but time_embedding might actually be running in fp16. so we need to cast here.
779
+ # there might be better ways to encapsulate this.
780
+ t_emb = t_emb.to(dtype=sample.dtype)
781
+
782
+ emb = self.time_embedding(t_emb, timestep_cond)
783
+ aug_emb = None
784
+
785
+ if self.class_embedding is not None:
786
+ if class_labels is None:
787
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
788
+
789
+ if self.config.class_embed_type == "timestep":
790
+ class_labels = self.time_proj(class_labels)
791
+
792
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
793
+ emb = emb + class_emb
794
+
795
+ if self.config.addition_embed_type is not None:
796
+ if self.config.addition_embed_type == "text":
797
+ aug_emb = self.add_embedding(encoder_hidden_states)
798
+
799
+ elif self.config.addition_embed_type == "text_time":
800
+ if "text_embeds" not in added_cond_kwargs:
801
+ raise ValueError(
802
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
803
+ )
804
+ text_embeds = added_cond_kwargs.get("text_embeds")
805
+ if "time_ids" not in added_cond_kwargs:
806
+ raise ValueError(
807
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
808
+ )
809
+ time_ids = added_cond_kwargs.get("time_ids")
810
+ time_embeds = self.add_time_proj(time_ids.flatten())
811
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
812
+
813
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
814
+ add_embeds = add_embeds.to(emb.dtype)
815
+ aug_emb = self.add_embedding(add_embeds)
816
+
817
+ emb = emb + aug_emb if aug_emb is not None else emb
818
+
819
+ # 2. pre-process
820
+ brushnet_cond = torch.concat([sample, brushnet_cond], 1)
821
+ sample = self.conv_in_condition(brushnet_cond)
822
+
823
+ # 3. down
824
+ down_block_res_samples = (sample,)
825
+ for downsample_block in self.down_blocks:
826
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
827
+ sample, res_samples = downsample_block(
828
+ hidden_states=sample,
829
+ temb=emb,
830
+ encoder_hidden_states=encoder_hidden_states,
831
+ attention_mask=attention_mask,
832
+ cross_attention_kwargs=cross_attention_kwargs,
833
+ )
834
+ else:
835
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
836
+
837
+ down_block_res_samples += res_samples
838
+
839
+ # 4. PaintingNet down blocks
840
+ brushnet_down_block_res_samples = ()
841
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
842
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
843
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
844
+
845
+ # 5. mid
846
+ if self.mid_block is not None:
847
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
848
+ sample = self.mid_block(
849
+ sample,
850
+ emb,
851
+ encoder_hidden_states=encoder_hidden_states,
852
+ attention_mask=attention_mask,
853
+ cross_attention_kwargs=cross_attention_kwargs,
854
+ )
855
+ else:
856
+ sample = self.mid_block(sample, emb)
857
+
858
+ # 6. BrushNet mid blocks
859
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
860
+
861
+ # 7. up
862
+ up_block_res_samples = ()
863
+ for i, upsample_block in enumerate(self.up_blocks):
864
+ is_final_block = i == len(self.up_blocks) - 1
865
+
866
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
867
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
868
+
869
+ # if we have not reached the final block and need to forward the
870
+ # upsample size, we do it here
871
+ if not is_final_block:
872
+ upsample_size = down_block_res_samples[-1].shape[2:]
873
+
874
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
875
+ sample, up_res_samples = upsample_block(
876
+ hidden_states=sample,
877
+ temb=emb,
878
+ res_hidden_states_tuple=res_samples,
879
+ encoder_hidden_states=encoder_hidden_states,
880
+ cross_attention_kwargs=cross_attention_kwargs,
881
+ upsample_size=upsample_size,
882
+ attention_mask=attention_mask,
883
+ return_res_samples=True,
884
+ )
885
+ else:
886
+ sample, up_res_samples = upsample_block(
887
+ hidden_states=sample,
888
+ temb=emb,
889
+ res_hidden_states_tuple=res_samples,
890
+ upsample_size=upsample_size,
891
+ return_res_samples=True,
892
+ )
893
+
894
+ up_block_res_samples += up_res_samples
895
+
896
+ # 8. BrushNet up blocks
897
+ brushnet_up_block_res_samples = ()
898
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
899
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
900
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
901
+
902
+ # 6. scaling
903
+ if guess_mode and not self.config.global_pool_conditions:
904
+ scales = torch.logspace(
905
+ -1,
906
+ 0,
907
+ len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
908
+ device=sample.device,
909
+ ) # 0.1 to 1.0
910
+ scales = scales * conditioning_scale
911
+
912
+ brushnet_down_block_res_samples = [
913
+ sample * scale
914
+ for sample, scale in zip(
915
+ brushnet_down_block_res_samples, scales[: len(brushnet_down_block_res_samples)]
916
+ )
917
+ ]
918
+ brushnet_mid_block_res_sample = (
919
+ brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
920
+ )
921
+ brushnet_up_block_res_samples = [
922
+ sample * scale
923
+ for sample, scale in zip(
924
+ brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples) + 1 :]
925
+ )
926
+ ]
927
+ else:
928
+ brushnet_down_block_res_samples = [
929
+ sample * conditioning_scale for sample in brushnet_down_block_res_samples
930
+ ]
931
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
932
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
933
+
934
+ if self.config.global_pool_conditions:
935
+ brushnet_down_block_res_samples = [
936
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
937
+ ]
938
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
939
+ brushnet_up_block_res_samples = [
940
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
941
+ ]
942
+
943
+ if not return_dict:
944
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
945
+
946
+ return BrushNetOutput(
947
+ down_block_res_samples=brushnet_down_block_res_samples,
948
+ mid_block_res_sample=brushnet_mid_block_res_sample,
949
+ up_block_res_samples=brushnet_up_block_res_samples,
950
+ )
951
+
952
+
953
+ def zero_module(module):
954
+ for p in module.parameters():
955
+ nn.init.zeros_(p)
956
+ return module
ComfyUI-BrushNet/brushnet/brushnet_xl.json ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.0.dev0",
4
+ "_name_or_path": "runs/logs/brushnetsdxl_randommask/checkpoint-80000",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": "text_time",
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": 256,
9
+ "attention_head_dim": [
10
+ 5,
11
+ 10,
12
+ 20
13
+ ],
14
+ "block_out_channels": [
15
+ 320,
16
+ 640,
17
+ 1280
18
+ ],
19
+ "brushnet_conditioning_channel_order": "rgb",
20
+ "class_embed_type": null,
21
+ "conditioning_channels": 5,
22
+ "conditioning_embedding_out_channels": [
23
+ 16,
24
+ 32,
25
+ 96,
26
+ 256
27
+ ],
28
+ "cross_attention_dim": 2048,
29
+ "down_block_types": [
30
+ "DownBlock2D",
31
+ "DownBlock2D",
32
+ "DownBlock2D"
33
+ ],
34
+ "downsample_padding": 1,
35
+ "encoder_hid_dim": null,
36
+ "encoder_hid_dim_type": null,
37
+ "flip_sin_to_cos": true,
38
+ "freq_shift": 0,
39
+ "global_pool_conditions": false,
40
+ "in_channels": 4,
41
+ "layers_per_block": 2,
42
+ "mid_block_scale_factor": 1,
43
+ "mid_block_type": "MidBlock2D",
44
+ "norm_eps": 1e-05,
45
+ "norm_num_groups": 32,
46
+ "num_attention_heads": null,
47
+ "num_class_embeds": null,
48
+ "only_cross_attention": false,
49
+ "projection_class_embeddings_input_dim": 2816,
50
+ "resnet_time_scale_shift": "default",
51
+ "transformer_layers_per_block": [
52
+ 1,
53
+ 2,
54
+ 10
55
+ ],
56
+ "up_block_types": [
57
+ "UpBlock2D",
58
+ "UpBlock2D",
59
+ "UpBlock2D"
60
+ ],
61
+ "upcast_attention": null,
62
+ "use_linear_projection": true
63
+ }
ComfyUI-BrushNet/brushnet/powerpaint.json ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "BrushNetModel",
3
+ "_diffusers_version": "0.27.2",
4
+ "act_fn": "silu",
5
+ "addition_embed_type": null,
6
+ "addition_embed_type_num_heads": 64,
7
+ "addition_time_embed_dim": null,
8
+ "attention_head_dim": 8,
9
+ "block_out_channels": [
10
+ 320,
11
+ 640,
12
+ 1280,
13
+ 1280
14
+ ],
15
+ "brushnet_conditioning_channel_order": "rgb",
16
+ "class_embed_type": null,
17
+ "conditioning_channels": 5,
18
+ "conditioning_embedding_out_channels": [
19
+ 16,
20
+ 32,
21
+ 96,
22
+ 256
23
+ ],
24
+ "cross_attention_dim": 768,
25
+ "down_block_types": [
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "CrossAttnDownBlock2D",
29
+ "DownBlock2D"
30
+ ],
31
+ "downsample_padding": 1,
32
+ "encoder_hid_dim": null,
33
+ "encoder_hid_dim_type": null,
34
+ "flip_sin_to_cos": true,
35
+ "freq_shift": 0,
36
+ "global_pool_conditions": false,
37
+ "in_channels": 4,
38
+ "layers_per_block": 2,
39
+ "mid_block_scale_factor": 1,
40
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
41
+ "norm_eps": 1e-05,
42
+ "norm_num_groups": 32,
43
+ "num_attention_heads": null,
44
+ "num_class_embeds": null,
45
+ "only_cross_attention": false,
46
+ "projection_class_embeddings_input_dim": null,
47
+ "resnet_time_scale_shift": "default",
48
+ "transformer_layers_per_block": 1,
49
+ "up_block_types": [
50
+ "UpBlock2D",
51
+ "CrossAttnUpBlock2D",
52
+ "CrossAttnUpBlock2D",
53
+ "CrossAttnUpBlock2D"
54
+ ],
55
+ "upcast_attention": false,
56
+ "use_linear_projection": false
57
+ }
ComfyUI-BrushNet/brushnet/powerpaint_utils.py ADDED
@@ -0,0 +1,496 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import random
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import CLIPTokenizer
7
+ from typing import Any, List, Optional, Union
8
+
9
+ class TokenizerWrapper:
10
+ """Tokenizer wrapper for CLIPTokenizer. Only support CLIPTokenizer
11
+ currently. This wrapper is modified from https://github.com/huggingface/dif
12
+ fusers/blob/e51f19aee82c8dd874b715a09dbc521d88835d68/src/diffusers/loaders.
13
+ py#L358 # noqa.
14
+
15
+ Args:
16
+ from_pretrained (Union[str, os.PathLike], optional): The *model id*
17
+ of a pretrained model or a path to a *directory* containing
18
+ model weights and config. Defaults to None.
19
+ from_config (Union[str, os.PathLike], optional): The *model id*
20
+ of a pretrained model or a path to a *directory* containing
21
+ model weights and config. Defaults to None.
22
+
23
+ *args, **kwargs: If `from_pretrained` is passed, *args and **kwargs
24
+ will be passed to `from_pretrained` function. Otherwise, *args
25
+ and **kwargs will be used to initialize the model by
26
+ `self._module_cls(*args, **kwargs)`.
27
+ """
28
+
29
+ def __init__(self, tokenizer: CLIPTokenizer):
30
+ self.wrapped = tokenizer
31
+ self.token_map = {}
32
+
33
+ def __getattr__(self, name: str) -> Any:
34
+ if name in self.__dict__:
35
+ return getattr(self, name)
36
+ #if name == "wrapped":
37
+ # return getattr(self, 'wrapped')#super().__getattr__("wrapped")
38
+
39
+ try:
40
+ return getattr(self.wrapped, name)
41
+ except AttributeError:
42
+ raise AttributeError(
43
+ "'name' cannot be found in both "
44
+ f"'{self.__class__.__name__}' and "
45
+ f"'{self.__class__.__name__}.tokenizer'."
46
+ )
47
+
48
+ def try_adding_tokens(self, tokens: Union[str, List[str]], *args, **kwargs):
49
+ """Attempt to add tokens to the tokenizer.
50
+
51
+ Args:
52
+ tokens (Union[str, List[str]]): The tokens to be added.
53
+ """
54
+ num_added_tokens = self.wrapped.add_tokens(tokens, *args, **kwargs)
55
+ assert num_added_tokens != 0, (
56
+ f"The tokenizer already contains the token {tokens}. Please pass "
57
+ "a different `placeholder_token` that is not already in the "
58
+ "tokenizer."
59
+ )
60
+
61
+ def get_token_info(self, token: str) -> dict:
62
+ """Get the information of a token, including its start and end index in
63
+ the current tokenizer.
64
+
65
+ Args:
66
+ token (str): The token to be queried.
67
+
68
+ Returns:
69
+ dict: The information of the token, including its start and end
70
+ index in current tokenizer.
71
+ """
72
+ token_ids = self.__call__(token).input_ids
73
+ start, end = token_ids[1], token_ids[-2] + 1
74
+ return {"name": token, "start": start, "end": end}
75
+
76
+ def add_placeholder_token(self, placeholder_token: str, *args, num_vec_per_token: int = 1, **kwargs):
77
+ """Add placeholder tokens to the tokenizer.
78
+
79
+ Args:
80
+ placeholder_token (str): The placeholder token to be added.
81
+ num_vec_per_token (int, optional): The number of vectors of
82
+ the added placeholder token.
83
+ *args, **kwargs: The arguments for `self.wrapped.add_tokens`.
84
+ """
85
+ output = []
86
+ if num_vec_per_token == 1:
87
+ self.try_adding_tokens(placeholder_token, *args, **kwargs)
88
+ output.append(placeholder_token)
89
+ else:
90
+ output = []
91
+ for i in range(num_vec_per_token):
92
+ ith_token = placeholder_token + f"_{i}"
93
+ self.try_adding_tokens(ith_token, *args, **kwargs)
94
+ output.append(ith_token)
95
+
96
+ for token in self.token_map:
97
+ if token in placeholder_token:
98
+ raise ValueError(
99
+ f"The tokenizer already has placeholder token {token} "
100
+ f"that can get confused with {placeholder_token} "
101
+ "keep placeholder tokens independent"
102
+ )
103
+ self.token_map[placeholder_token] = output
104
+
105
+ def replace_placeholder_tokens_in_text(
106
+ self, text: Union[str, List[str]], vector_shuffle: bool = False, prop_tokens_to_load: float = 1.0
107
+ ) -> Union[str, List[str]]:
108
+ """Replace the keywords in text with placeholder tokens. This function
109
+ will be called in `self.__call__` and `self.encode`.
110
+
111
+ Args:
112
+ text (Union[str, List[str]]): The text to be processed.
113
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
114
+ Defaults to False.
115
+ prop_tokens_to_load (float, optional): The proportion of tokens to
116
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0.
117
+
118
+ Returns:
119
+ Union[str, List[str]]: The processed text.
120
+ """
121
+ if isinstance(text, list):
122
+ output = []
123
+ for i in range(len(text)):
124
+ output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
125
+ return output
126
+
127
+ for placeholder_token in self.token_map:
128
+ if placeholder_token in text:
129
+ tokens = self.token_map[placeholder_token]
130
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
131
+ if vector_shuffle:
132
+ tokens = copy.copy(tokens)
133
+ random.shuffle(tokens)
134
+ text = text.replace(placeholder_token, " ".join(tokens))
135
+ return text
136
+
137
+ def replace_text_with_placeholder_tokens(self, text: Union[str, List[str]]) -> Union[str, List[str]]:
138
+ """Replace the placeholder tokens in text with the original keywords.
139
+ This function will be called in `self.decode`.
140
+
141
+ Args:
142
+ text (Union[str, List[str]]): The text to be processed.
143
+
144
+ Returns:
145
+ Union[str, List[str]]: The processed text.
146
+ """
147
+ if isinstance(text, list):
148
+ output = []
149
+ for i in range(len(text)):
150
+ output.append(self.replace_text_with_placeholder_tokens(text[i]))
151
+ return output
152
+
153
+ for placeholder_token, tokens in self.token_map.items():
154
+ merged_tokens = " ".join(tokens)
155
+ if merged_tokens in text:
156
+ text = text.replace(merged_tokens, placeholder_token)
157
+ return text
158
+
159
+ def __call__(
160
+ self,
161
+ text: Union[str, List[str]],
162
+ *args,
163
+ vector_shuffle: bool = False,
164
+ prop_tokens_to_load: float = 1.0,
165
+ **kwargs,
166
+ ):
167
+ """The call function of the wrapper.
168
+
169
+ Args:
170
+ text (Union[str, List[str]]): The text to be tokenized.
171
+ vector_shuffle (bool, optional): Whether to shuffle the vectors.
172
+ Defaults to False.
173
+ prop_tokens_to_load (float, optional): The proportion of tokens to
174
+ be loaded. If 1.0, all tokens will be loaded. Defaults to 1.0
175
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
176
+ """
177
+ replaced_text = self.replace_placeholder_tokens_in_text(
178
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
179
+ )
180
+
181
+ return self.wrapped.__call__(replaced_text, *args, **kwargs)
182
+
183
+ def encode(self, text: Union[str, List[str]], *args, **kwargs):
184
+ """Encode the passed text to token index.
185
+
186
+ Args:
187
+ text (Union[str, List[str]]): The text to be encode.
188
+ *args, **kwargs: The arguments for `self.wrapped.__call__`.
189
+ """
190
+ replaced_text = self.replace_placeholder_tokens_in_text(text)
191
+ return self.wrapped(replaced_text, *args, **kwargs)
192
+
193
+ def decode(self, token_ids, return_raw: bool = False, *args, **kwargs) -> Union[str, List[str]]:
194
+ """Decode the token index to text.
195
+
196
+ Args:
197
+ token_ids: The token index to be decoded.
198
+ return_raw: Whether keep the placeholder token in the text.
199
+ Defaults to False.
200
+ *args, **kwargs: The arguments for `self.wrapped.decode`.
201
+
202
+ Returns:
203
+ Union[str, List[str]]: The decoded text.
204
+ """
205
+ text = self.wrapped.decode(token_ids, *args, **kwargs)
206
+ if return_raw:
207
+ return text
208
+ replaced_text = self.replace_text_with_placeholder_tokens(text)
209
+ return replaced_text
210
+
211
+ def __repr__(self):
212
+ """The representation of the wrapper."""
213
+ s = super().__repr__()
214
+ prefix = f"Wrapped Module Class: {self._module_cls}\n"
215
+ prefix += f"Wrapped Module Name: {self._module_name}\n"
216
+ if self._from_pretrained:
217
+ prefix += f"From Pretrained: {self._from_pretrained}\n"
218
+ s = prefix + s
219
+ return s
220
+
221
+
222
+ class EmbeddingLayerWithFixes(nn.Module):
223
+ """The revised embedding layer to support external embeddings. This design
224
+ of this class is inspired by https://github.com/AUTOMATIC1111/stable-
225
+ diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
226
+ jack.py#L224 # noqa.
227
+
228
+ Args:
229
+ wrapped (nn.Emebdding): The embedding layer to be wrapped.
230
+ external_embeddings (Union[dict, List[dict]], optional): The external
231
+ embeddings added to this layer. Defaults to None.
232
+ """
233
+
234
+ def __init__(self, wrapped: nn.Embedding, external_embeddings: Optional[Union[dict, List[dict]]] = None):
235
+ super().__init__()
236
+ self.wrapped = wrapped
237
+ self.num_embeddings = wrapped.weight.shape[0]
238
+
239
+ self.external_embeddings = []
240
+ if external_embeddings:
241
+ self.add_embeddings(external_embeddings)
242
+
243
+ self.trainable_embeddings = nn.ParameterDict()
244
+
245
+ @property
246
+ def weight(self):
247
+ """Get the weight of wrapped embedding layer."""
248
+ return self.wrapped.weight
249
+
250
+ def check_duplicate_names(self, embeddings: List[dict]):
251
+ """Check whether duplicate names exist in list of 'external
252
+ embeddings'.
253
+
254
+ Args:
255
+ embeddings (List[dict]): A list of embedding to be check.
256
+ """
257
+ names = [emb["name"] for emb in embeddings]
258
+ assert len(names) == len(set(names)), (
259
+ "Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
260
+ )
261
+
262
+ def check_ids_overlap(self, embeddings):
263
+ """Check whether overlap exist in token ids of 'external_embeddings'.
264
+
265
+ Args:
266
+ embeddings (List[dict]): A list of embedding to be check.
267
+ """
268
+ ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
269
+ ids_range.sort() # sort by 'start'
270
+ # check if 'end' has overlapping
271
+ for idx in range(len(ids_range) - 1):
272
+ name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
273
+ assert ids_range[idx][1] <= ids_range[idx + 1][0], (
274
+ f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
275
+ )
276
+
277
+ def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
278
+ """Add external embeddings to this layer.
279
+
280
+ Use case:
281
+
282
+ >>> 1. Add token to tokenizer and get the token id.
283
+ >>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
284
+ >>> # 'how much' in kiswahili
285
+ >>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
286
+ >>>
287
+ >>> 2. Add external embeddings to the model.
288
+ >>> new_embedding = {
289
+ >>> 'name': 'ngapi', # 'how much' in kiswahili
290
+ >>> 'embedding': torch.ones(1, 15) * 4,
291
+ >>> 'start': tokenizer.get_token_info('kwaheri')['start'],
292
+ >>> 'end': tokenizer.get_token_info('kwaheri')['end'],
293
+ >>> 'trainable': False # if True, will registry as a parameter
294
+ >>> }
295
+ >>> embedding_layer = nn.Embedding(10, 15)
296
+ >>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
297
+ >>> embedding_layer_wrapper.add_embeddings(new_embedding)
298
+ >>>
299
+ >>> 3. Forward tokenizer and embedding layer!
300
+ >>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
301
+ >>> input_ids = tokenizer(
302
+ >>> input_text, padding='max_length', truncation=True,
303
+ >>> return_tensors='pt')['input_ids']
304
+ >>> out_feat = embedding_layer_wrapper(input_ids)
305
+ >>>
306
+ >>> 4. Let's validate the result!
307
+ >>> assert (out_feat[0, 3: 7] == 2.3).all()
308
+ >>> assert (out_feat[2, 5: 9] == 2.3).all()
309
+
310
+ Args:
311
+ embeddings (Union[dict, list[dict]]): The external embeddings to
312
+ be added. Each dict must contain the following 4 fields: 'name'
313
+ (the name of this embedding), 'embedding' (the embedding
314
+ tensor), 'start' (the start token id of this embedding), 'end'
315
+ (the end token id of this embedding). For example:
316
+ `{name: NAME, start: START, end: END, embedding: torch.Tensor}`
317
+ """
318
+ if isinstance(embeddings, dict):
319
+ embeddings = [embeddings]
320
+
321
+ self.external_embeddings += embeddings
322
+ self.check_duplicate_names(self.external_embeddings)
323
+ self.check_ids_overlap(self.external_embeddings)
324
+
325
+ # set for trainable
326
+ added_trainable_emb_info = []
327
+ for embedding in embeddings:
328
+ trainable = embedding.get("trainable", False)
329
+ if trainable:
330
+ name = embedding["name"]
331
+ embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
332
+ self.trainable_embeddings[name] = embedding["embedding"]
333
+ added_trainable_emb_info.append(name)
334
+
335
+ added_emb_info = [emb["name"] for emb in embeddings]
336
+ added_emb_info = ", ".join(added_emb_info)
337
+ print(f"Successfully add external embeddings: {added_emb_info}.", "current")
338
+
339
+ if added_trainable_emb_info:
340
+ added_trainable_emb_info = ", ".join(added_trainable_emb_info)
341
+ print("Successfully add trainable external embeddings: " f"{added_trainable_emb_info}", "current")
342
+
343
+ def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
344
+ """Replace external input ids to 0.
345
+
346
+ Args:
347
+ input_ids (torch.Tensor): The input ids to be replaced.
348
+
349
+ Returns:
350
+ torch.Tensor: The replaced input ids.
351
+ """
352
+ input_ids_fwd = input_ids.clone()
353
+ input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
354
+ return input_ids_fwd
355
+
356
+ def replace_embeddings(
357
+ self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
358
+ ) -> torch.Tensor:
359
+ """Replace external embedding to the embedding layer. Noted that, in
360
+ this function we use `torch.cat` to avoid inplace modification.
361
+
362
+ Args:
363
+ input_ids (torch.Tensor): The original token ids. Shape like
364
+ [LENGTH, ].
365
+ embedding (torch.Tensor): The embedding of token ids after
366
+ `replace_input_ids` function.
367
+ external_embedding (dict): The external embedding to be replaced.
368
+
369
+ Returns:
370
+ torch.Tensor: The replaced embedding.
371
+ """
372
+ new_embedding = []
373
+
374
+ name = external_embedding["name"]
375
+ start = external_embedding["start"]
376
+ end = external_embedding["end"]
377
+ target_ids_to_replace = [i for i in range(start, end)]
378
+ ext_emb = external_embedding["embedding"]
379
+
380
+ # do not need to replace
381
+ if not (input_ids == start).any():
382
+ return embedding
383
+
384
+ # start replace
385
+ s_idx, e_idx = 0, 0
386
+ while e_idx < len(input_ids):
387
+ if input_ids[e_idx] == start:
388
+ if e_idx != 0:
389
+ # add embedding do not need to replace
390
+ new_embedding.append(embedding[s_idx:e_idx])
391
+
392
+ # check if the next embedding need to replace is valid
393
+ actually_ids_to_replace = [int(i) for i in input_ids[e_idx : e_idx + end - start]]
394
+ assert actually_ids_to_replace == target_ids_to_replace, (
395
+ f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
396
+ f"Expect '{target_ids_to_replace}' for embedding "
397
+ f"'{name}' but found '{actually_ids_to_replace}'."
398
+ )
399
+
400
+ new_embedding.append(ext_emb)
401
+
402
+ s_idx = e_idx + end - start
403
+ e_idx = s_idx + 1
404
+ else:
405
+ e_idx += 1
406
+
407
+ if e_idx == len(input_ids):
408
+ new_embedding.append(embedding[s_idx:e_idx])
409
+
410
+ return torch.cat(new_embedding, dim=0)
411
+
412
+ def forward(self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None):
413
+ """The forward function.
414
+
415
+ Args:
416
+ input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
417
+ [LENGTH, ].
418
+ external_embeddings (Optional[List[dict]]): The external
419
+ embeddings. If not passed, only `self.external_embeddings`
420
+ will be used. Defaults to None.
421
+
422
+ input_ids: shape like [bz, LENGTH] or [LENGTH].
423
+ """
424
+ assert input_ids.ndim in [1, 2]
425
+ if input_ids.ndim == 1:
426
+ input_ids = input_ids.unsqueeze(0)
427
+
428
+ if external_embeddings is None and not self.external_embeddings:
429
+ return self.wrapped(input_ids)
430
+
431
+ input_ids_fwd = self.replace_input_ids(input_ids)
432
+ inputs_embeds = self.wrapped(input_ids_fwd)
433
+
434
+ vecs = []
435
+
436
+ if external_embeddings is None:
437
+ external_embeddings = []
438
+ elif isinstance(external_embeddings, dict):
439
+ external_embeddings = [external_embeddings]
440
+ embeddings = self.external_embeddings + external_embeddings
441
+
442
+ for input_id, embedding in zip(input_ids, inputs_embeds):
443
+ new_embedding = embedding
444
+ for external_embedding in embeddings:
445
+ new_embedding = self.replace_embeddings(input_id, new_embedding, external_embedding)
446
+ vecs.append(new_embedding)
447
+
448
+ return torch.stack(vecs)
449
+
450
+
451
+
452
+ def add_tokens(
453
+ tokenizer, text_encoder, placeholder_tokens: list, initialize_tokens: list = None, num_vectors_per_token: int = 1
454
+ ):
455
+ """Add token for training.
456
+
457
+ # TODO: support add tokens as dict, then we can load pretrained tokens.
458
+ """
459
+ if initialize_tokens is not None:
460
+ assert len(initialize_tokens) == len(
461
+ placeholder_tokens
462
+ ), "placeholder_token should be the same length as initialize_token"
463
+ for ii in range(len(placeholder_tokens)):
464
+ tokenizer.add_placeholder_token(placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token)
465
+
466
+ # text_encoder.set_embedding_layer()
467
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
468
+ text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(embedding_layer)
469
+ embedding_layer = text_encoder.text_model.embeddings.token_embedding
470
+
471
+ assert embedding_layer is not None, (
472
+ "Do not support get embedding layer for current text encoder. " "Please check your configuration."
473
+ )
474
+ initialize_embedding = []
475
+ if initialize_tokens is not None:
476
+ for ii in range(len(placeholder_tokens)):
477
+ init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
478
+ temp_embedding = embedding_layer.weight[init_id]
479
+ initialize_embedding.append(temp_embedding[None, ...].repeat(num_vectors_per_token, 1))
480
+ else:
481
+ for ii in range(len(placeholder_tokens)):
482
+ init_id = tokenizer("a").input_ids[1]
483
+ temp_embedding = embedding_layer.weight[init_id]
484
+ len_emb = temp_embedding.shape[0]
485
+ init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
486
+ initialize_embedding.append(init_weight)
487
+
488
+ # initialize_embedding = torch.cat(initialize_embedding,dim=0)
489
+
490
+ token_info_all = []
491
+ for ii in range(len(placeholder_tokens)):
492
+ token_info = tokenizer.get_token_info(placeholder_tokens[ii])
493
+ token_info["embedding"] = initialize_embedding[ii]
494
+ token_info["trainable"] = True
495
+ token_info_all.append(token_info)
496
+ embedding_layer.add_embeddings(token_info_all)
ComfyUI-BrushNet/brushnet/unet_2d_blocks.py ADDED
The diff for this file is too large to render. See raw diff
 
ComfyUI-BrushNet/brushnet/unet_2d_condition.py ADDED
@@ -0,0 +1,1355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from dataclasses import dataclass
15
+ from typing import Any, Dict, List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.checkpoint
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
23
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.models.activations import get_activation
25
+ from diffusers.models.attention_processor import (
26
+ ADDED_KV_ATTENTION_PROCESSORS,
27
+ CROSS_ATTENTION_PROCESSORS,
28
+ Attention,
29
+ AttentionProcessor,
30
+ AttnAddedKVProcessor,
31
+ AttnProcessor,
32
+ )
33
+ from diffusers.models.embeddings import (
34
+ GaussianFourierProjection,
35
+ GLIGENTextBoundingboxProjection,
36
+ ImageHintTimeEmbedding,
37
+ ImageProjection,
38
+ ImageTimeEmbedding,
39
+ TextImageProjection,
40
+ TextImageTimeEmbedding,
41
+ TextTimeEmbedding,
42
+ TimestepEmbedding,
43
+ Timesteps,
44
+ )
45
+ from diffusers.models.modeling_utils import ModelMixin
46
+ from .unet_2d_blocks import (
47
+ get_down_block,
48
+ get_mid_block,
49
+ get_up_block,
50
+ )
51
+
52
+
53
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
54
+
55
+
56
+ @dataclass
57
+ class UNet2DConditionOutput(BaseOutput):
58
+ """
59
+ The output of [`UNet2DConditionModel`].
60
+
61
+ Args:
62
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
63
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
64
+ """
65
+
66
+ sample: torch.FloatTensor = None
67
+
68
+
69
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
70
+ r"""
71
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
72
+ shaped output.
73
+
74
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
75
+ for all models (such as downloading or saving).
76
+
77
+ Parameters:
78
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
79
+ Height and width of input/output sample.
80
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
81
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
82
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
83
+ flip_sin_to_cos (`bool`, *optional*, defaults to `True`):
84
+ Whether to flip the sin to cos in the time embedding.
85
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
86
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
87
+ The tuple of downsample blocks to use.
88
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
89
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
90
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
91
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
92
+ The tuple of upsample blocks to use.
93
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
94
+ Whether to include self-attention in the basic transformer blocks, see
95
+ [`~models.attention.BasicTransformerBlock`].
96
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
97
+ The tuple of output channels for each block.
98
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
99
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
100
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
101
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
102
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
103
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
104
+ If `None`, normalization and activation layers is skipped in post-processing.
105
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
106
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
107
+ The dimension of the cross attention features.
108
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
109
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
110
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
111
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
112
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
113
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
114
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
115
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
116
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
117
+ encoder_hid_dim (`int`, *optional*, defaults to None):
118
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
119
+ dimension to `cross_attention_dim`.
120
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
121
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
122
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
123
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
124
+ num_attention_heads (`int`, *optional*):
125
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
126
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
127
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
128
+ class_embed_type (`str`, *optional*, defaults to `None`):
129
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
130
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
131
+ addition_embed_type (`str`, *optional*, defaults to `None`):
132
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
133
+ "text". "text" will use the `TextTimeEmbedding` layer.
134
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
135
+ Dimension for the timestep embeddings.
136
+ num_class_embeds (`int`, *optional*, defaults to `None`):
137
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
138
+ class conditioning with `class_embed_type` equal to `None`.
139
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
140
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
141
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
142
+ An optional override for the dimension of the projected time embedding.
143
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
144
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
145
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
146
+ timestep_post_act (`str`, *optional*, defaults to `None`):
147
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
148
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
149
+ The dimension of `cond_proj` layer in the timestep embedding.
150
+ conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
151
+ conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
152
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
153
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
154
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
155
+ embeddings with the class embeddings.
156
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
157
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
158
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
159
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
160
+ otherwise.
161
+ """
162
+
163
+ _supports_gradient_checkpointing = True
164
+
165
+ @register_to_config
166
+ def __init__(
167
+ self,
168
+ sample_size: Optional[int] = None,
169
+ in_channels: int = 4,
170
+ out_channels: int = 4,
171
+ center_input_sample: bool = False,
172
+ flip_sin_to_cos: bool = True,
173
+ freq_shift: int = 0,
174
+ down_block_types: Tuple[str] = (
175
+ "CrossAttnDownBlock2D",
176
+ "CrossAttnDownBlock2D",
177
+ "CrossAttnDownBlock2D",
178
+ "DownBlock2D",
179
+ ),
180
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
181
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
182
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
183
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
184
+ layers_per_block: Union[int, Tuple[int]] = 2,
185
+ downsample_padding: int = 1,
186
+ mid_block_scale_factor: float = 1,
187
+ dropout: float = 0.0,
188
+ act_fn: str = "silu",
189
+ norm_num_groups: Optional[int] = 32,
190
+ norm_eps: float = 1e-5,
191
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
192
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
193
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
194
+ encoder_hid_dim: Optional[int] = None,
195
+ encoder_hid_dim_type: Optional[str] = None,
196
+ attention_head_dim: Union[int, Tuple[int]] = 8,
197
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
198
+ dual_cross_attention: bool = False,
199
+ use_linear_projection: bool = False,
200
+ class_embed_type: Optional[str] = None,
201
+ addition_embed_type: Optional[str] = None,
202
+ addition_time_embed_dim: Optional[int] = None,
203
+ num_class_embeds: Optional[int] = None,
204
+ upcast_attention: bool = False,
205
+ resnet_time_scale_shift: str = "default",
206
+ resnet_skip_time_act: bool = False,
207
+ resnet_out_scale_factor: float = 1.0,
208
+ time_embedding_type: str = "positional",
209
+ time_embedding_dim: Optional[int] = None,
210
+ time_embedding_act_fn: Optional[str] = None,
211
+ timestep_post_act: Optional[str] = None,
212
+ time_cond_proj_dim: Optional[int] = None,
213
+ conv_in_kernel: int = 3,
214
+ conv_out_kernel: int = 3,
215
+ projection_class_embeddings_input_dim: Optional[int] = None,
216
+ attention_type: str = "default",
217
+ class_embeddings_concat: bool = False,
218
+ mid_block_only_cross_attention: Optional[bool] = None,
219
+ cross_attention_norm: Optional[str] = None,
220
+ addition_embed_type_num_heads: int = 64,
221
+ ):
222
+ super().__init__()
223
+
224
+ self.sample_size = sample_size
225
+
226
+ if num_attention_heads is not None:
227
+ raise ValueError(
228
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
229
+ )
230
+
231
+ # If `num_attention_heads` is not defined (which is the case for most models)
232
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
233
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
234
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
235
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
236
+ # which is why we correct for the naming here.
237
+ num_attention_heads = num_attention_heads or attention_head_dim
238
+
239
+ # Check inputs
240
+ self._check_config(
241
+ down_block_types=down_block_types,
242
+ up_block_types=up_block_types,
243
+ only_cross_attention=only_cross_attention,
244
+ block_out_channels=block_out_channels,
245
+ layers_per_block=layers_per_block,
246
+ cross_attention_dim=cross_attention_dim,
247
+ transformer_layers_per_block=transformer_layers_per_block,
248
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
249
+ attention_head_dim=attention_head_dim,
250
+ num_attention_heads=num_attention_heads,
251
+ )
252
+
253
+ # input
254
+ conv_in_padding = (conv_in_kernel - 1) // 2
255
+ self.conv_in = nn.Conv2d(
256
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
257
+ )
258
+
259
+ # time
260
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
261
+ time_embedding_type,
262
+ block_out_channels=block_out_channels,
263
+ flip_sin_to_cos=flip_sin_to_cos,
264
+ freq_shift=freq_shift,
265
+ time_embedding_dim=time_embedding_dim,
266
+ )
267
+
268
+ self.time_embedding = TimestepEmbedding(
269
+ timestep_input_dim,
270
+ time_embed_dim,
271
+ act_fn=act_fn,
272
+ post_act_fn=timestep_post_act,
273
+ cond_proj_dim=time_cond_proj_dim,
274
+ )
275
+
276
+ self._set_encoder_hid_proj(
277
+ encoder_hid_dim_type,
278
+ cross_attention_dim=cross_attention_dim,
279
+ encoder_hid_dim=encoder_hid_dim,
280
+ )
281
+
282
+ # class embedding
283
+ self._set_class_embedding(
284
+ class_embed_type,
285
+ act_fn=act_fn,
286
+ num_class_embeds=num_class_embeds,
287
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
288
+ time_embed_dim=time_embed_dim,
289
+ timestep_input_dim=timestep_input_dim,
290
+ )
291
+
292
+ self._set_add_embedding(
293
+ addition_embed_type,
294
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
295
+ addition_time_embed_dim=addition_time_embed_dim,
296
+ cross_attention_dim=cross_attention_dim,
297
+ encoder_hid_dim=encoder_hid_dim,
298
+ flip_sin_to_cos=flip_sin_to_cos,
299
+ freq_shift=freq_shift,
300
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
301
+ time_embed_dim=time_embed_dim,
302
+ )
303
+
304
+ if time_embedding_act_fn is None:
305
+ self.time_embed_act = None
306
+ else:
307
+ self.time_embed_act = get_activation(time_embedding_act_fn)
308
+
309
+ self.down_blocks = nn.ModuleList([])
310
+ self.up_blocks = nn.ModuleList([])
311
+
312
+ if isinstance(only_cross_attention, bool):
313
+ if mid_block_only_cross_attention is None:
314
+ mid_block_only_cross_attention = only_cross_attention
315
+
316
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
317
+
318
+ if mid_block_only_cross_attention is None:
319
+ mid_block_only_cross_attention = False
320
+
321
+ if isinstance(num_attention_heads, int):
322
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
323
+
324
+ if isinstance(attention_head_dim, int):
325
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
326
+
327
+ if isinstance(cross_attention_dim, int):
328
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
329
+
330
+ if isinstance(layers_per_block, int):
331
+ layers_per_block = [layers_per_block] * len(down_block_types)
332
+
333
+ if isinstance(transformer_layers_per_block, int):
334
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
335
+
336
+ if class_embeddings_concat:
337
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
338
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
339
+ # regular time embeddings
340
+ blocks_time_embed_dim = time_embed_dim * 2
341
+ else:
342
+ blocks_time_embed_dim = time_embed_dim
343
+
344
+ # down
345
+ output_channel = block_out_channels[0]
346
+ for i, down_block_type in enumerate(down_block_types):
347
+ input_channel = output_channel
348
+ output_channel = block_out_channels[i]
349
+ is_final_block = i == len(block_out_channels) - 1
350
+
351
+ down_block = get_down_block(
352
+ down_block_type,
353
+ num_layers=layers_per_block[i],
354
+ transformer_layers_per_block=transformer_layers_per_block[i],
355
+ in_channels=input_channel,
356
+ out_channels=output_channel,
357
+ temb_channels=blocks_time_embed_dim,
358
+ add_downsample=not is_final_block,
359
+ resnet_eps=norm_eps,
360
+ resnet_act_fn=act_fn,
361
+ resnet_groups=norm_num_groups,
362
+ cross_attention_dim=cross_attention_dim[i],
363
+ num_attention_heads=num_attention_heads[i],
364
+ downsample_padding=downsample_padding,
365
+ dual_cross_attention=dual_cross_attention,
366
+ use_linear_projection=use_linear_projection,
367
+ only_cross_attention=only_cross_attention[i],
368
+ upcast_attention=upcast_attention,
369
+ resnet_time_scale_shift=resnet_time_scale_shift,
370
+ attention_type=attention_type,
371
+ resnet_skip_time_act=resnet_skip_time_act,
372
+ resnet_out_scale_factor=resnet_out_scale_factor,
373
+ cross_attention_norm=cross_attention_norm,
374
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
375
+ dropout=dropout,
376
+ )
377
+ self.down_blocks.append(down_block)
378
+
379
+ # mid
380
+ self.mid_block = get_mid_block(
381
+ mid_block_type,
382
+ temb_channels=blocks_time_embed_dim,
383
+ in_channels=block_out_channels[-1],
384
+ resnet_eps=norm_eps,
385
+ resnet_act_fn=act_fn,
386
+ resnet_groups=norm_num_groups,
387
+ output_scale_factor=mid_block_scale_factor,
388
+ transformer_layers_per_block=transformer_layers_per_block[-1],
389
+ num_attention_heads=num_attention_heads[-1],
390
+ cross_attention_dim=cross_attention_dim[-1],
391
+ dual_cross_attention=dual_cross_attention,
392
+ use_linear_projection=use_linear_projection,
393
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
394
+ upcast_attention=upcast_attention,
395
+ resnet_time_scale_shift=resnet_time_scale_shift,
396
+ attention_type=attention_type,
397
+ resnet_skip_time_act=resnet_skip_time_act,
398
+ cross_attention_norm=cross_attention_norm,
399
+ attention_head_dim=attention_head_dim[-1],
400
+ dropout=dropout,
401
+ )
402
+
403
+ # count how many layers upsample the images
404
+ self.num_upsamplers = 0
405
+
406
+ # up
407
+ reversed_block_out_channels = list(reversed(block_out_channels))
408
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
409
+ reversed_layers_per_block = list(reversed(layers_per_block))
410
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
411
+ reversed_transformer_layers_per_block = (
412
+ list(reversed(transformer_layers_per_block))
413
+ if reverse_transformer_layers_per_block is None
414
+ else reverse_transformer_layers_per_block
415
+ )
416
+ only_cross_attention = list(reversed(only_cross_attention))
417
+
418
+ output_channel = reversed_block_out_channels[0]
419
+ for i, up_block_type in enumerate(up_block_types):
420
+ is_final_block = i == len(block_out_channels) - 1
421
+
422
+ prev_output_channel = output_channel
423
+ output_channel = reversed_block_out_channels[i]
424
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
425
+
426
+ # add upsample block for all BUT final layer
427
+ if not is_final_block:
428
+ add_upsample = True
429
+ self.num_upsamplers += 1
430
+ else:
431
+ add_upsample = False
432
+
433
+ up_block = get_up_block(
434
+ up_block_type,
435
+ num_layers=reversed_layers_per_block[i] + 1,
436
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
437
+ in_channels=input_channel,
438
+ out_channels=output_channel,
439
+ prev_output_channel=prev_output_channel,
440
+ temb_channels=blocks_time_embed_dim,
441
+ add_upsample=add_upsample,
442
+ resnet_eps=norm_eps,
443
+ resnet_act_fn=act_fn,
444
+ resolution_idx=i,
445
+ resnet_groups=norm_num_groups,
446
+ cross_attention_dim=reversed_cross_attention_dim[i],
447
+ num_attention_heads=reversed_num_attention_heads[i],
448
+ dual_cross_attention=dual_cross_attention,
449
+ use_linear_projection=use_linear_projection,
450
+ only_cross_attention=only_cross_attention[i],
451
+ upcast_attention=upcast_attention,
452
+ resnet_time_scale_shift=resnet_time_scale_shift,
453
+ attention_type=attention_type,
454
+ resnet_skip_time_act=resnet_skip_time_act,
455
+ resnet_out_scale_factor=resnet_out_scale_factor,
456
+ cross_attention_norm=cross_attention_norm,
457
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
458
+ dropout=dropout,
459
+ )
460
+ self.up_blocks.append(up_block)
461
+ prev_output_channel = output_channel
462
+
463
+ # out
464
+ if norm_num_groups is not None:
465
+ self.conv_norm_out = nn.GroupNorm(
466
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
467
+ )
468
+
469
+ self.conv_act = get_activation(act_fn)
470
+
471
+ else:
472
+ self.conv_norm_out = None
473
+ self.conv_act = None
474
+
475
+ conv_out_padding = (conv_out_kernel - 1) // 2
476
+ self.conv_out = nn.Conv2d(
477
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
478
+ )
479
+
480
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
481
+
482
+ def _check_config(
483
+ self,
484
+ down_block_types: Tuple[str],
485
+ up_block_types: Tuple[str],
486
+ only_cross_attention: Union[bool, Tuple[bool]],
487
+ block_out_channels: Tuple[int],
488
+ layers_per_block: Union[int, Tuple[int]],
489
+ cross_attention_dim: Union[int, Tuple[int]],
490
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
491
+ reverse_transformer_layers_per_block: bool,
492
+ attention_head_dim: int,
493
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
494
+ ):
495
+ if len(down_block_types) != len(up_block_types):
496
+ raise ValueError(
497
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
498
+ )
499
+
500
+ if len(block_out_channels) != len(down_block_types):
501
+ raise ValueError(
502
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
503
+ )
504
+
505
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
506
+ raise ValueError(
507
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
508
+ )
509
+
510
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
511
+ raise ValueError(
512
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
513
+ )
514
+
515
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
516
+ raise ValueError(
517
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
518
+ )
519
+
520
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
521
+ raise ValueError(
522
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
523
+ )
524
+
525
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
526
+ raise ValueError(
527
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
528
+ )
529
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
530
+ for layer_number_per_block in transformer_layers_per_block:
531
+ if isinstance(layer_number_per_block, list):
532
+ raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.")
533
+
534
+ def _set_time_proj(
535
+ self,
536
+ time_embedding_type: str,
537
+ block_out_channels: int,
538
+ flip_sin_to_cos: bool,
539
+ freq_shift: float,
540
+ time_embedding_dim: int,
541
+ ) -> Tuple[int, int]:
542
+ if time_embedding_type == "fourier":
543
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
544
+ if time_embed_dim % 2 != 0:
545
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
546
+ self.time_proj = GaussianFourierProjection(
547
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
548
+ )
549
+ timestep_input_dim = time_embed_dim
550
+ elif time_embedding_type == "positional":
551
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
552
+
553
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
554
+ timestep_input_dim = block_out_channels[0]
555
+ else:
556
+ raise ValueError(
557
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
558
+ )
559
+
560
+ return time_embed_dim, timestep_input_dim
561
+
562
+ def _set_encoder_hid_proj(
563
+ self,
564
+ encoder_hid_dim_type: Optional[str],
565
+ cross_attention_dim: Union[int, Tuple[int]],
566
+ encoder_hid_dim: Optional[int],
567
+ ):
568
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
569
+ encoder_hid_dim_type = "text_proj"
570
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
571
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
572
+
573
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
574
+ raise ValueError(
575
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
576
+ )
577
+
578
+ if encoder_hid_dim_type == "text_proj":
579
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
580
+ elif encoder_hid_dim_type == "text_image_proj":
581
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
582
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
583
+ # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
584
+ self.encoder_hid_proj = TextImageProjection(
585
+ text_embed_dim=encoder_hid_dim,
586
+ image_embed_dim=cross_attention_dim,
587
+ cross_attention_dim=cross_attention_dim,
588
+ )
589
+ elif encoder_hid_dim_type == "image_proj":
590
+ # Kandinsky 2.2
591
+ self.encoder_hid_proj = ImageProjection(
592
+ image_embed_dim=encoder_hid_dim,
593
+ cross_attention_dim=cross_attention_dim,
594
+ )
595
+ elif encoder_hid_dim_type is not None:
596
+ raise ValueError(
597
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
598
+ )
599
+ else:
600
+ self.encoder_hid_proj = None
601
+
602
+ def _set_class_embedding(
603
+ self,
604
+ class_embed_type: Optional[str],
605
+ act_fn: str,
606
+ num_class_embeds: Optional[int],
607
+ projection_class_embeddings_input_dim: Optional[int],
608
+ time_embed_dim: int,
609
+ timestep_input_dim: int,
610
+ ):
611
+ if class_embed_type is None and num_class_embeds is not None:
612
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
613
+ elif class_embed_type == "timestep":
614
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
615
+ elif class_embed_type == "identity":
616
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
617
+ elif class_embed_type == "projection":
618
+ if projection_class_embeddings_input_dim is None:
619
+ raise ValueError(
620
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
621
+ )
622
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
623
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
624
+ # 2. it projects from an arbitrary input dimension.
625
+ #
626
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
627
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
628
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
629
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
630
+ elif class_embed_type == "simple_projection":
631
+ if projection_class_embeddings_input_dim is None:
632
+ raise ValueError(
633
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
634
+ )
635
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
636
+ else:
637
+ self.class_embedding = None
638
+
639
+ def _set_add_embedding(
640
+ self,
641
+ addition_embed_type: str,
642
+ addition_embed_type_num_heads: int,
643
+ addition_time_embed_dim: Optional[int],
644
+ flip_sin_to_cos: bool,
645
+ freq_shift: float,
646
+ cross_attention_dim: Optional[int],
647
+ encoder_hid_dim: Optional[int],
648
+ projection_class_embeddings_input_dim: Optional[int],
649
+ time_embed_dim: int,
650
+ ):
651
+ if addition_embed_type == "text":
652
+ if encoder_hid_dim is not None:
653
+ text_time_embedding_from_dim = encoder_hid_dim
654
+ else:
655
+ text_time_embedding_from_dim = cross_attention_dim
656
+
657
+ self.add_embedding = TextTimeEmbedding(
658
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
659
+ )
660
+ elif addition_embed_type == "text_image":
661
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
662
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
663
+ # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
664
+ self.add_embedding = TextImageTimeEmbedding(
665
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
666
+ )
667
+ elif addition_embed_type == "text_time":
668
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
669
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
670
+ elif addition_embed_type == "image":
671
+ # Kandinsky 2.2
672
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
673
+ elif addition_embed_type == "image_hint":
674
+ # Kandinsky 2.2 ControlNet
675
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
676
+ elif addition_embed_type is not None:
677
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
678
+
679
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
680
+ if attention_type in ["gated", "gated-text-image"]:
681
+ positive_len = 768
682
+ if isinstance(cross_attention_dim, int):
683
+ positive_len = cross_attention_dim
684
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
685
+ positive_len = cross_attention_dim[0]
686
+
687
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
688
+ self.position_net = GLIGENTextBoundingboxProjection(
689
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
690
+ )
691
+
692
+ @property
693
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
694
+ r"""
695
+ Returns:
696
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
697
+ indexed by its weight name.
698
+ """
699
+ # set recursively
700
+ processors = {}
701
+
702
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
703
+ if hasattr(module, "get_processor"):
704
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
705
+
706
+ for sub_name, child in module.named_children():
707
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
708
+
709
+ return processors
710
+
711
+ for name, module in self.named_children():
712
+ fn_recursive_add_processors(name, module, processors)
713
+
714
+ return processors
715
+
716
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
717
+ r"""
718
+ Sets the attention processor to use to compute attention.
719
+
720
+ Parameters:
721
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
722
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
723
+ for **all** `Attention` layers.
724
+
725
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
726
+ processor. This is strongly recommended when setting trainable attention processors.
727
+
728
+ """
729
+ count = len(self.attn_processors.keys())
730
+
731
+ if isinstance(processor, dict) and len(processor) != count:
732
+ raise ValueError(
733
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
734
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
735
+ )
736
+
737
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
738
+ if hasattr(module, "set_processor"):
739
+ if not isinstance(processor, dict):
740
+ module.set_processor(processor)
741
+ else:
742
+ module.set_processor(processor.pop(f"{name}.processor"))
743
+
744
+ for sub_name, child in module.named_children():
745
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
746
+
747
+ for name, module in self.named_children():
748
+ fn_recursive_attn_processor(name, module, processor)
749
+
750
+ def set_default_attn_processor(self):
751
+ """
752
+ Disables custom attention processors and sets the default attention implementation.
753
+ """
754
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
755
+ processor = AttnAddedKVProcessor()
756
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
757
+ processor = AttnProcessor()
758
+ else:
759
+ raise ValueError(
760
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
761
+ )
762
+
763
+ self.set_attn_processor(processor)
764
+
765
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
766
+ r"""
767
+ Enable sliced attention computation.
768
+
769
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
770
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
771
+
772
+ Args:
773
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
774
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
775
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
776
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
777
+ must be a multiple of `slice_size`.
778
+ """
779
+ sliceable_head_dims = []
780
+
781
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
782
+ if hasattr(module, "set_attention_slice"):
783
+ sliceable_head_dims.append(module.sliceable_head_dim)
784
+
785
+ for child in module.children():
786
+ fn_recursive_retrieve_sliceable_dims(child)
787
+
788
+ # retrieve number of attention layers
789
+ for module in self.children():
790
+ fn_recursive_retrieve_sliceable_dims(module)
791
+
792
+ num_sliceable_layers = len(sliceable_head_dims)
793
+
794
+ if slice_size == "auto":
795
+ # half the attention head size is usually a good trade-off between
796
+ # speed and memory
797
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
798
+ elif slice_size == "max":
799
+ # make smallest slice possible
800
+ slice_size = num_sliceable_layers * [1]
801
+
802
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
820
+ if hasattr(module, "set_attention_slice"):
821
+ module.set_attention_slice(slice_size.pop())
822
+
823
+ for child in module.children():
824
+ fn_recursive_set_attention_slice(child, slice_size)
825
+
826
+ reversed_slice_size = list(reversed(slice_size))
827
+ for module in self.children():
828
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
829
+
830
+ def _set_gradient_checkpointing(self, module, value=False):
831
+ if hasattr(module, "gradient_checkpointing"):
832
+ module.gradient_checkpointing = value
833
+
834
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
835
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
836
+
837
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
838
+
839
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
840
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
841
+
842
+ Args:
843
+ s1 (`float`):
844
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
845
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
846
+ s2 (`float`):
847
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
848
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
849
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
850
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
851
+ """
852
+ for i, upsample_block in enumerate(self.up_blocks):
853
+ setattr(upsample_block, "s1", s1)
854
+ setattr(upsample_block, "s2", s2)
855
+ setattr(upsample_block, "b1", b1)
856
+ setattr(upsample_block, "b2", b2)
857
+
858
+ def disable_freeu(self):
859
+ """Disables the FreeU mechanism."""
860
+ freeu_keys = {"s1", "s2", "b1", "b2"}
861
+ for i, upsample_block in enumerate(self.up_blocks):
862
+ for k in freeu_keys:
863
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
864
+ setattr(upsample_block, k, None)
865
+
866
+ def fuse_qkv_projections(self):
867
+ """
868
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
869
+ are fused. For cross-attention modules, key and value projection matrices are fused.
870
+
871
+ <Tip warning={true}>
872
+
873
+ This API is 🧪 experimental.
874
+
875
+ </Tip>
876
+ """
877
+ self.original_attn_processors = None
878
+
879
+ for _, attn_processor in self.attn_processors.items():
880
+ if "Added" in str(attn_processor.__class__.__name__):
881
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
882
+
883
+ self.original_attn_processors = self.attn_processors
884
+
885
+ for module in self.modules():
886
+ if isinstance(module, Attention):
887
+ module.fuse_projections(fuse=True)
888
+
889
+ def unfuse_qkv_projections(self):
890
+ """Disables the fused QKV projection if enabled.
891
+
892
+ <Tip warning={true}>
893
+
894
+ This API is 🧪 experimental.
895
+
896
+ </Tip>
897
+
898
+ """
899
+ if self.original_attn_processors is not None:
900
+ self.set_attn_processor(self.original_attn_processors)
901
+
902
+ def unload_lora(self):
903
+ """Unloads LoRA weights."""
904
+ deprecate(
905
+ "unload_lora",
906
+ "0.28.0",
907
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
908
+ )
909
+ for module in self.modules():
910
+ if hasattr(module, "set_lora_layer"):
911
+ module.set_lora_layer(None)
912
+
913
+ def get_time_embed(
914
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
915
+ ) -> Optional[torch.Tensor]:
916
+ timesteps = timestep
917
+ if not torch.is_tensor(timesteps):
918
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
919
+ # This would be a good case for the `match` statement (Python 3.10+)
920
+ is_mps = sample.device.type == "mps"
921
+ if isinstance(timestep, float):
922
+ dtype = torch.float32 if is_mps else torch.float64
923
+ else:
924
+ dtype = torch.int32 if is_mps else torch.int64
925
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
926
+ elif len(timesteps.shape) == 0:
927
+ timesteps = timesteps[None].to(sample.device)
928
+
929
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
930
+ timesteps = timesteps.expand(sample.shape[0])
931
+
932
+ t_emb = self.time_proj(timesteps)
933
+ # `Timesteps` does not contain any weights and will always return f32 tensors
934
+ # but time_embedding might actually be running in fp16. so we need to cast here.
935
+ # there might be better ways to encapsulate this.
936
+ t_emb = t_emb.to(dtype=sample.dtype)
937
+ return t_emb
938
+
939
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
940
+ class_emb = None
941
+ if self.class_embedding is not None:
942
+ if class_labels is None:
943
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
944
+
945
+ if self.config.class_embed_type == "timestep":
946
+ class_labels = self.time_proj(class_labels)
947
+
948
+ # `Timesteps` does not contain any weights and will always return f32 tensors
949
+ # there might be better ways to encapsulate this.
950
+ class_labels = class_labels.to(dtype=sample.dtype)
951
+
952
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
953
+ return class_emb
954
+
955
+ def get_aug_embed(
956
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
957
+ ) -> Optional[torch.Tensor]:
958
+ aug_emb = None
959
+ if self.config.addition_embed_type == "text":
960
+ aug_emb = self.add_embedding(encoder_hidden_states)
961
+ elif self.config.addition_embed_type == "text_image":
962
+ # Kandinsky 2.1 - style
963
+ if "image_embeds" not in added_cond_kwargs:
964
+ raise ValueError(
965
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
966
+ )
967
+
968
+ image_embs = added_cond_kwargs.get("image_embeds")
969
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
970
+ aug_emb = self.add_embedding(text_embs, image_embs)
971
+ elif self.config.addition_embed_type == "text_time":
972
+ # SDXL - style
973
+ if "text_embeds" not in added_cond_kwargs:
974
+ raise ValueError(
975
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
976
+ )
977
+ text_embeds = added_cond_kwargs.get("text_embeds")
978
+ if "time_ids" not in added_cond_kwargs:
979
+ raise ValueError(
980
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
981
+ )
982
+ time_ids = added_cond_kwargs.get("time_ids")
983
+ time_embeds = self.add_time_proj(time_ids.flatten())
984
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
985
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
986
+ add_embeds = add_embeds.to(emb.dtype)
987
+ aug_emb = self.add_embedding(add_embeds)
988
+ elif self.config.addition_embed_type == "image":
989
+ # Kandinsky 2.2 - style
990
+ if "image_embeds" not in added_cond_kwargs:
991
+ raise ValueError(
992
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
993
+ )
994
+ image_embs = added_cond_kwargs.get("image_embeds")
995
+ aug_emb = self.add_embedding(image_embs)
996
+ elif self.config.addition_embed_type == "image_hint":
997
+ # Kandinsky 2.2 - style
998
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
999
+ raise ValueError(
1000
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1001
+ )
1002
+ image_embs = added_cond_kwargs.get("image_embeds")
1003
+ hint = added_cond_kwargs.get("hint")
1004
+ aug_emb = self.add_embedding(image_embs, hint)
1005
+ return aug_emb
1006
+
1007
+ def process_encoder_hidden_states(
1008
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
1009
+ ) -> torch.Tensor:
1010
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
1011
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1012
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
1013
+ # Kandinsky 2.1 - style
1014
+ if "image_embeds" not in added_cond_kwargs:
1015
+ raise ValueError(
1016
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1017
+ )
1018
+
1019
+ image_embeds = added_cond_kwargs.get("image_embeds")
1020
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
1021
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
1022
+ # Kandinsky 2.2 - style
1023
+ if "image_embeds" not in added_cond_kwargs:
1024
+ raise ValueError(
1025
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1026
+ )
1027
+ image_embeds = added_cond_kwargs.get("image_embeds")
1028
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1029
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
1030
+ if "image_embeds" not in added_cond_kwargs:
1031
+ raise ValueError(
1032
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1033
+ )
1034
+ image_embeds = added_cond_kwargs.get("image_embeds")
1035
+ image_embeds = self.encoder_hid_proj(image_embeds)
1036
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
1037
+ return encoder_hidden_states
1038
+
1039
+ def forward(
1040
+ self,
1041
+ sample: torch.FloatTensor,
1042
+ timestep: Union[torch.Tensor, float, int],
1043
+ encoder_hidden_states: torch.Tensor,
1044
+ class_labels: Optional[torch.Tensor] = None,
1045
+ timestep_cond: Optional[torch.Tensor] = None,
1046
+ attention_mask: Optional[torch.Tensor] = None,
1047
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1048
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
1049
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1050
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1051
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1052
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1053
+ return_dict: bool = True,
1054
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1055
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
1056
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
1057
+ ) -> Union[UNet2DConditionOutput, Tuple]:
1058
+ r"""
1059
+ The [`UNet2DConditionModel`] forward method.
1060
+
1061
+ Args:
1062
+ sample (`torch.FloatTensor`):
1063
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
1064
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
1065
+ encoder_hidden_states (`torch.FloatTensor`):
1066
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
1067
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
1068
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
1069
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
1070
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
1071
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
1072
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
1073
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
1074
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
1075
+ negative values to the attention scores corresponding to "discard" tokens.
1076
+ cross_attention_kwargs (`dict`, *optional*):
1077
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1078
+ `self.processor` in
1079
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1080
+ added_cond_kwargs: (`dict`, *optional*):
1081
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
1082
+ are passed along to the UNet blocks.
1083
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
1084
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
1085
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
1086
+ A tensor that if specified is added to the residual of the middle unet block.
1087
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
1088
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
1089
+ encoder_attention_mask (`torch.Tensor`):
1090
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
1091
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
1092
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
1093
+ return_dict (`bool`, *optional*, defaults to `True`):
1094
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
1095
+ tuple.
1096
+
1097
+ Returns:
1098
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
1099
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
1100
+ otherwise a `tuple` is returned where the first element is the sample tensor.
1101
+ """
1102
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
1103
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
1104
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
1105
+ # on the fly if necessary.
1106
+ default_overall_up_factor = 2**self.num_upsamplers
1107
+
1108
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1109
+ forward_upsample_size = False
1110
+ upsample_size = None
1111
+
1112
+ for dim in sample.shape[-2:]:
1113
+ if dim % default_overall_up_factor != 0:
1114
+ # Forward upsample size to force interpolation output size.
1115
+ forward_upsample_size = True
1116
+ break
1117
+
1118
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
1119
+ # expects mask of shape:
1120
+ # [batch, key_tokens]
1121
+ # adds singleton query_tokens dimension:
1122
+ # [batch, 1, key_tokens]
1123
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
1124
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
1125
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
1126
+ if attention_mask is not None:
1127
+ # assume that mask is expressed as:
1128
+ # (1 = keep, 0 = discard)
1129
+ # convert mask into a bias that can be added to attention scores:
1130
+ # (keep = +0, discard = -10000.0)
1131
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1132
+ attention_mask = attention_mask.unsqueeze(1)
1133
+
1134
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
1135
+ if encoder_attention_mask is not None:
1136
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
1137
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1138
+
1139
+ # 0. center input if necessary
1140
+ if self.config.center_input_sample:
1141
+ sample = 2 * sample - 1.0
1142
+
1143
+ # 1. time
1144
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
1145
+ emb = self.time_embedding(t_emb, timestep_cond)
1146
+ aug_emb = None
1147
+
1148
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
1149
+ if class_emb is not None:
1150
+ if self.config.class_embeddings_concat:
1151
+ emb = torch.cat([emb, class_emb], dim=-1)
1152
+ else:
1153
+ emb = emb + class_emb
1154
+
1155
+ aug_emb = self.get_aug_embed(
1156
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1157
+ )
1158
+ if self.config.addition_embed_type == "image_hint":
1159
+ aug_emb, hint = aug_emb
1160
+ sample = torch.cat([sample, hint], dim=1)
1161
+
1162
+ emb = emb + aug_emb if aug_emb is not None else emb
1163
+
1164
+ if self.time_embed_act is not None:
1165
+ emb = self.time_embed_act(emb)
1166
+
1167
+ encoder_hidden_states = self.process_encoder_hidden_states(
1168
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1169
+ )
1170
+
1171
+ # 2. pre-process
1172
+ sample = self.conv_in(sample)
1173
+
1174
+ # 2.5 GLIGEN position net
1175
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
1176
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1177
+ gligen_args = cross_attention_kwargs.pop("gligen")
1178
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
1179
+
1180
+ # 3. down
1181
+ # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
1182
+ # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
1183
+ if cross_attention_kwargs is not None:
1184
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1185
+ lora_scale = cross_attention_kwargs.pop("scale", 1.0)
1186
+ else:
1187
+ lora_scale = 1.0
1188
+
1189
+ if USE_PEFT_BACKEND:
1190
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1191
+ scale_lora_layers(self, lora_scale)
1192
+
1193
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1194
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1195
+ is_adapter = down_intrablock_additional_residuals is not None
1196
+ # maintain backward compatibility for legacy usage, where
1197
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1198
+ # but can only use one or the other
1199
+ is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
1200
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
1201
+ deprecate(
1202
+ "T2I should not use down_block_additional_residuals",
1203
+ "1.3.0",
1204
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1205
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1206
+ for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
1207
+ standard_warn=False,
1208
+ )
1209
+ down_intrablock_additional_residuals = down_block_additional_residuals
1210
+ is_adapter = True
1211
+
1212
+ down_block_res_samples = (sample,)
1213
+
1214
+ if is_brushnet:
1215
+ sample = sample + down_block_add_samples.pop(0)
1216
+
1217
+ for downsample_block in self.down_blocks:
1218
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1219
+ # For t2i-adapter CrossAttnDownBlock2D
1220
+ additional_residuals = {}
1221
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1222
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
1223
+
1224
+ i = len(down_block_add_samples)
1225
+
1226
+ if is_brushnet and len(down_block_add_samples)>0:
1227
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1228
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
1229
+
1230
+ sample, res_samples = downsample_block(
1231
+ hidden_states=sample,
1232
+ temb=emb,
1233
+ encoder_hidden_states=encoder_hidden_states,
1234
+ attention_mask=attention_mask,
1235
+ cross_attention_kwargs=cross_attention_kwargs,
1236
+ encoder_attention_mask=encoder_attention_mask,
1237
+ **additional_residuals,
1238
+ )
1239
+ else:
1240
+ additional_residuals = {}
1241
+
1242
+ i = len(down_block_add_samples)
1243
+
1244
+ if is_brushnet and len(down_block_add_samples)>0:
1245
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
1246
+ for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))]
1247
+
1248
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, **additional_residuals)
1249
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1250
+ sample += down_intrablock_additional_residuals.pop(0)
1251
+
1252
+ down_block_res_samples += res_samples
1253
+
1254
+ if is_controlnet:
1255
+ new_down_block_res_samples = ()
1256
+
1257
+ for down_block_res_sample, down_block_additional_residual in zip(
1258
+ down_block_res_samples, down_block_additional_residuals
1259
+ ):
1260
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1261
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
1262
+
1263
+ down_block_res_samples = new_down_block_res_samples
1264
+
1265
+ # 4. mid
1266
+ if self.mid_block is not None:
1267
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
1268
+ sample = self.mid_block(
1269
+ sample,
1270
+ emb,
1271
+ encoder_hidden_states=encoder_hidden_states,
1272
+ attention_mask=attention_mask,
1273
+ cross_attention_kwargs=cross_attention_kwargs,
1274
+ encoder_attention_mask=encoder_attention_mask,
1275
+ )
1276
+ else:
1277
+ sample = self.mid_block(sample, emb)
1278
+
1279
+ # To support T2I-Adapter-XL
1280
+ if (
1281
+ is_adapter
1282
+ and len(down_intrablock_additional_residuals) > 0
1283
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1284
+ ):
1285
+ sample += down_intrablock_additional_residuals.pop(0)
1286
+
1287
+ if is_controlnet:
1288
+ sample = sample + mid_block_additional_residual
1289
+
1290
+ if is_brushnet:
1291
+ sample = sample + mid_block_add_sample
1292
+
1293
+ # 5. up
1294
+ for i, upsample_block in enumerate(self.up_blocks):
1295
+ is_final_block = i == len(self.up_blocks) - 1
1296
+
1297
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1298
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1299
+
1300
+ # if we have not reached the final block and need to forward the
1301
+ # upsample size, we do it here
1302
+ if not is_final_block and forward_upsample_size:
1303
+ upsample_size = down_block_res_samples[-1].shape[2:]
1304
+
1305
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1306
+ additional_residuals = {}
1307
+
1308
+ i = len(up_block_add_samples)
1309
+
1310
+ if is_brushnet and len(up_block_add_samples)>0:
1311
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1312
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
1313
+
1314
+ sample = upsample_block(
1315
+ hidden_states=sample,
1316
+ temb=emb,
1317
+ res_hidden_states_tuple=res_samples,
1318
+ encoder_hidden_states=encoder_hidden_states,
1319
+ cross_attention_kwargs=cross_attention_kwargs,
1320
+ upsample_size=upsample_size,
1321
+ attention_mask=attention_mask,
1322
+ encoder_attention_mask=encoder_attention_mask,
1323
+ **additional_residuals,
1324
+ )
1325
+ else:
1326
+ additional_residuals = {}
1327
+
1328
+ i = len(up_block_add_samples)
1329
+
1330
+ if is_brushnet and len(up_block_add_samples)>0:
1331
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
1332
+ for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))]
1333
+
1334
+ sample = upsample_block(
1335
+ hidden_states=sample,
1336
+ temb=emb,
1337
+ res_hidden_states_tuple=res_samples,
1338
+ upsample_size=upsample_size,
1339
+ **additional_residuals,
1340
+ )
1341
+
1342
+ # 6. post-process
1343
+ if self.conv_norm_out:
1344
+ sample = self.conv_norm_out(sample)
1345
+ sample = self.conv_act(sample)
1346
+ sample = self.conv_out(sample)
1347
+
1348
+ if USE_PEFT_BACKEND:
1349
+ # remove `lora_scale` from each PEFT layer
1350
+ unscale_lora_layers(self, lora_scale)
1351
+
1352
+ if not return_dict:
1353
+ return (sample,)
1354
+
1355
+ return UNet2DConditionOutput(sample=sample)
ComfyUI-BrushNet/brushnet_nodes.py ADDED
@@ -0,0 +1,1080 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import types
3
+ from typing import Tuple
4
+
5
+ import torch
6
+ import torchvision.transforms as T
7
+ import torch.nn.functional as F
8
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
9
+
10
+ #import sys
11
+ #from sys import platform
12
+ # Get the parent directory of 'comfy' and add it to the Python path
13
+ #comfy_parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
14
+ #sys.path.append(comfy_parent_dir)
15
+
16
+ import comfy
17
+ import folder_paths
18
+
19
+ from .model_patch import add_model_patch_option, patch_model_function_wrapper
20
+
21
+ from .brushnet.brushnet import BrushNetModel
22
+ from .brushnet.brushnet_ca import BrushNetModel as PowerPaintModel
23
+
24
+ from .brushnet.powerpaint_utils import TokenizerWrapper, add_tokens
25
+
26
+ current_directory = os.path.dirname(os.path.abspath(__file__))
27
+ brushnet_config_file = os.path.join(current_directory, 'brushnet', 'brushnet.json')
28
+ brushnet_xl_config_file = os.path.join(current_directory, 'brushnet', 'brushnet_xl.json')
29
+ powerpaint_config_file = os.path.join(current_directory,'brushnet', 'powerpaint.json')
30
+
31
+ sd15_scaling_factor = 0.18215
32
+ sdxl_scaling_factor = 0.13025
33
+
34
+ ModelsToUnload = [comfy.sd1_clip.SD1ClipModel,
35
+ comfy.ldm.models.autoencoder.AutoencoderKL
36
+ ]
37
+
38
+
39
+ class BrushNetLoader:
40
+
41
+ @classmethod
42
+ def INPUT_TYPES(s):
43
+ files, inpaint_path = get_files_with_extension('inpaint')
44
+ s.inpaint_path = inpaint_path
45
+ return {"required":
46
+ {
47
+ "brushnet": (files, ),
48
+ "dtype": (['float16', 'bfloat16', 'float32', 'float64'], ),
49
+ },
50
+ }
51
+
52
+ CATEGORY = "inpaint"
53
+ RETURN_TYPES = ("BRMODEL",)
54
+ RETURN_NAMES = ("brushnet",)
55
+
56
+ FUNCTION = "brushnet_loading"
57
+
58
+ def brushnet_loading(self, brushnet, dtype):
59
+ brushnet_file = os.path.join(self.inpaint_path, brushnet)
60
+ is_SDXL = False
61
+ is_PP = False
62
+ sd = comfy.utils.load_torch_file(brushnet_file)
63
+ brushnet_down_block, brushnet_mid_block, brushnet_up_block, keys = brushnet_blocks(sd)
64
+ del sd
65
+ if brushnet_down_block == 24 and brushnet_mid_block == 2 and brushnet_up_block == 30:
66
+ is_SDXL = False
67
+ if keys == 322:
68
+ is_PP = False
69
+ print('BrushNet model type: SD1.5')
70
+ else:
71
+ is_PP = True
72
+ print('PowerPaint model type: SD1.5')
73
+ elif brushnet_down_block == 18 and brushnet_mid_block == 2 and brushnet_up_block == 22:
74
+ print('BrushNet model type: Loading SDXL')
75
+ is_SDXL = True
76
+ is_PP = False
77
+ else:
78
+ raise Exception("Unknown BrushNet model")
79
+
80
+ with init_empty_weights():
81
+ if is_SDXL:
82
+ brushnet_config = BrushNetModel.load_config(brushnet_xl_config_file)
83
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
84
+ elif is_PP:
85
+ brushnet_config = PowerPaintModel.load_config(powerpaint_config_file)
86
+ brushnet_model = PowerPaintModel.from_config(brushnet_config)
87
+ else:
88
+ brushnet_config = BrushNetModel.load_config(brushnet_config_file)
89
+ brushnet_model = BrushNetModel.from_config(brushnet_config)
90
+
91
+ if is_PP:
92
+ print("PowerPaint model file:", brushnet_file)
93
+ else:
94
+ print("BrushNet model file:", brushnet_file)
95
+
96
+ if dtype == 'float16':
97
+ torch_dtype = torch.float16
98
+ elif dtype == 'bfloat16':
99
+ torch_dtype = torch.bfloat16
100
+ elif dtype == 'float32':
101
+ torch_dtype = torch.float32
102
+ else:
103
+ torch_dtype = torch.float64
104
+
105
+ brushnet_model = load_checkpoint_and_dispatch(
106
+ brushnet_model,
107
+ brushnet_file,
108
+ device_map="sequential",
109
+ max_memory=None,
110
+ offload_folder=None,
111
+ offload_state_dict=False,
112
+ dtype=torch_dtype,
113
+ force_hooks=False,
114
+ )
115
+
116
+ if is_PP:
117
+ print("PowerPaint model is loaded")
118
+ elif is_SDXL:
119
+ print("BrushNet SDXL model is loaded")
120
+ else:
121
+ print("BrushNet SD1.5 model is loaded")
122
+
123
+ return ({"brushnet": brushnet_model, "SDXL": is_SDXL, "PP": is_PP, "dtype": torch_dtype}, )
124
+
125
+
126
+ class PowerPaintCLIPLoader:
127
+
128
+ @classmethod
129
+ def INPUT_TYPES(s):
130
+ inpaint_files, inpaint_path = get_files_with_extension('inpaint', ['bin'])
131
+ s.inpaint_path = inpaint_path
132
+ clip_files, clip_path = get_files_with_extension('clip')
133
+ s.clip_path = clip_path
134
+ return {"required":
135
+ {
136
+ "base": (clip_files, ),
137
+ "powerpaint": (inpaint_files, ),
138
+ },
139
+ }
140
+
141
+ CATEGORY = "inpaint"
142
+ RETURN_TYPES = ("CLIP",)
143
+ RETURN_NAMES = ("clip",)
144
+
145
+ FUNCTION = "ppclip_loading"
146
+
147
+ def ppclip_loading(self, base, powerpaint):
148
+ base_CLIP_file = os.path.join(self.clip_path, base)
149
+ pp_CLIP_file = os.path.join(self.inpaint_path, powerpaint)
150
+
151
+ pp_clip = comfy.sd.load_clip(ckpt_paths=[base_CLIP_file])
152
+
153
+ print('PowerPaint base CLIP file: ', base_CLIP_file)
154
+
155
+ pp_tokenizer = TokenizerWrapper(pp_clip.tokenizer.clip_l.tokenizer)
156
+ pp_text_encoder = pp_clip.patcher.model.clip_l.transformer
157
+
158
+ add_tokens(
159
+ tokenizer = pp_tokenizer,
160
+ text_encoder = pp_text_encoder,
161
+ placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"],
162
+ initialize_tokens = ["a", "a", "a"],
163
+ num_vectors_per_token = 10,
164
+ )
165
+
166
+ pp_text_encoder.load_state_dict(comfy.utils.load_torch_file(pp_CLIP_file), strict=False)
167
+
168
+ print('PowerPaint CLIP file: ', pp_CLIP_file)
169
+
170
+ pp_clip.tokenizer.clip_l.tokenizer = pp_tokenizer
171
+ pp_clip.patcher.model.clip_l.transformer = pp_text_encoder
172
+
173
+ return (pp_clip,)
174
+
175
+
176
+ class PowerPaint:
177
+
178
+ @classmethod
179
+ def INPUT_TYPES(s):
180
+ return {"required":
181
+ {
182
+ "model": ("MODEL",),
183
+ "vae": ("VAE", ),
184
+ "image": ("IMAGE",),
185
+ "mask": ("MASK",),
186
+ "powerpaint": ("BRMODEL", ),
187
+ "clip": ("CLIP", ),
188
+ "positive": ("CONDITIONING", ),
189
+ "negative": ("CONDITIONING", ),
190
+ "fitting" : ("FLOAT", {"default": 1.0, "min": 0.3, "max": 1.0}),
191
+ "function": (['text guided', 'shape guided', 'object removal', 'context aware', 'image outpainting'], ),
192
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
193
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
194
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
195
+ },
196
+ }
197
+
198
+ CATEGORY = "inpaint"
199
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
200
+ RETURN_NAMES = ("model","positive","negative","latent",)
201
+
202
+ FUNCTION = "model_update"
203
+
204
+ def model_update(self, model, vae, image, mask, powerpaint, clip, positive, negative, fitting, function, scale, start_at, end_at):
205
+
206
+ is_SDXL, is_PP = check_compatibilty(model, powerpaint)
207
+ if not is_PP:
208
+ raise Exception("BrushNet model was loaded, please use BrushNet node")
209
+
210
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
211
+ model = model.clone()
212
+
213
+ # prepare image and mask
214
+ # no batches for original image and mask
215
+ masked_image, mask = prepare_image(image, mask)
216
+
217
+ batch = masked_image.shape[0]
218
+ #width = masked_image.shape[2]
219
+ #height = masked_image.shape[1]
220
+
221
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
222
+ scaling_factor = model.model.model_config.latent_format.scale_factor
223
+ else:
224
+ scaling_factor = sd15_scaling_factor
225
+
226
+ torch_dtype = powerpaint['dtype']
227
+
228
+ # prepare conditioning latents
229
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
230
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
231
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
232
+
233
+ # prepare embeddings
234
+
235
+ if function == "object removal":
236
+ promptA = "P_ctxt"
237
+ promptB = "P_ctxt"
238
+ negative_promptA = "P_obj"
239
+ negative_promptB = "P_obj"
240
+ print('You should add to positive prompt: "empty scene blur"')
241
+ #positive = positive + " empty scene blur"
242
+ elif function == "context aware":
243
+ promptA = "P_ctxt"
244
+ promptB = "P_ctxt"
245
+ negative_promptA = ""
246
+ negative_promptB = ""
247
+ #positive = positive + " empty scene"
248
+ print('You should add to positive prompt: "empty scene"')
249
+ elif function == "shape guided":
250
+ promptA = "P_shape"
251
+ promptB = "P_ctxt"
252
+ negative_promptA = "P_shape"
253
+ negative_promptB = "P_ctxt"
254
+ elif function == "image outpainting":
255
+ promptA = "P_ctxt"
256
+ promptB = "P_ctxt"
257
+ negative_promptA = "P_obj"
258
+ negative_promptB = "P_obj"
259
+ #positive = positive + " empty scene"
260
+ print('You should add to positive prompt: "empty scene"')
261
+ else:
262
+ promptA = "P_obj"
263
+ promptB = "P_obj"
264
+ negative_promptA = "P_obj"
265
+ negative_promptB = "P_obj"
266
+
267
+ tokens = clip.tokenize(promptA)
268
+ prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
269
+
270
+ tokens = clip.tokenize(negative_promptA)
271
+ negative_prompt_embedsA = clip.encode_from_tokens(tokens, return_pooled=False)
272
+
273
+ tokens = clip.tokenize(promptB)
274
+ prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
275
+
276
+ tokens = clip.tokenize(negative_promptB)
277
+ negative_prompt_embedsB = clip.encode_from_tokens(tokens, return_pooled=False)
278
+
279
+ prompt_embeds_pp = (prompt_embedsA * fitting + (1.0 - fitting) * prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
280
+ negative_prompt_embeds_pp = (negative_prompt_embedsA * fitting + (1.0 - fitting) * negative_prompt_embedsB).to(dtype=torch_dtype).to(powerpaint['brushnet'].device)
281
+
282
+ # unload vae and CLIPs
283
+ del vae
284
+ del clip
285
+ for loaded_model in comfy.model_management.current_loaded_models:
286
+ if type(loaded_model.model.model) in ModelsToUnload:
287
+ comfy.model_management.current_loaded_models.remove(loaded_model)
288
+ loaded_model.model_unload()
289
+ del loaded_model
290
+
291
+ # apply patch to model
292
+
293
+ brushnet_conditioning_scale = scale
294
+ control_guidance_start = start_at
295
+ control_guidance_end = end_at
296
+
297
+ add_brushnet_patch(model,
298
+ powerpaint['brushnet'],
299
+ torch_dtype,
300
+ conditioning_latents,
301
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
302
+ negative_prompt_embeds_pp, prompt_embeds_pp,
303
+ None, None, None)
304
+
305
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=powerpaint['brushnet'].device)
306
+
307
+ return (model, positive, negative, {"samples":latent},)
308
+
309
+
310
+ class BrushNet:
311
+
312
+ @classmethod
313
+ def INPUT_TYPES(s):
314
+ return {"required":
315
+ {
316
+ "model": ("MODEL",),
317
+ "vae": ("VAE", ),
318
+ "image": ("IMAGE",),
319
+ "mask": ("MASK",),
320
+ "brushnet": ("BRMODEL", ),
321
+ "positive": ("CONDITIONING", ),
322
+ "negative": ("CONDITIONING", ),
323
+ "scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
324
+ "start_at": ("INT", {"default": 0, "min": 0, "max": 10000}),
325
+ "end_at": ("INT", {"default": 10000, "min": 0, "max": 10000}),
326
+ },
327
+ }
328
+
329
+ CATEGORY = "inpaint"
330
+ RETURN_TYPES = ("MODEL","CONDITIONING","CONDITIONING","LATENT",)
331
+ RETURN_NAMES = ("model","positive","negative","latent",)
332
+
333
+ FUNCTION = "model_update"
334
+
335
+ def model_update(self, model, vae, image, mask, brushnet, positive, negative, scale, start_at, end_at):
336
+
337
+ is_SDXL, is_PP = check_compatibilty(model, brushnet)
338
+
339
+ if is_PP:
340
+ raise Exception("PowerPaint model was loaded, please use PowerPaint node")
341
+
342
+ # Make a copy of the model so that we're not patching it everywhere in the workflow.
343
+ model = model.clone()
344
+
345
+ # prepare image and mask
346
+ # no batches for original image and mask
347
+ masked_image, mask = prepare_image(image, mask)
348
+
349
+ batch = masked_image.shape[0]
350
+ width = masked_image.shape[2]
351
+ height = masked_image.shape[1]
352
+
353
+ if hasattr(model.model.model_config, 'latent_format') and hasattr(model.model.model_config.latent_format, 'scale_factor'):
354
+ scaling_factor = model.model.model_config.latent_format.scale_factor
355
+ elif is_SDXL:
356
+ scaling_factor = sdxl_scaling_factor
357
+ else:
358
+ scaling_factor = sd15_scaling_factor
359
+
360
+ torch_dtype = brushnet['dtype']
361
+
362
+ # prepare conditioning latents
363
+ conditioning_latents = get_image_latents(masked_image, mask, vae, scaling_factor)
364
+ conditioning_latents[0] = conditioning_latents[0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
365
+ conditioning_latents[1] = conditioning_latents[1].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
366
+
367
+ # unload vae
368
+ del vae
369
+ for loaded_model in comfy.model_management.current_loaded_models:
370
+ if type(loaded_model.model.model) in ModelsToUnload:
371
+ comfy.model_management.current_loaded_models.remove(loaded_model)
372
+ loaded_model.model_unload()
373
+ del loaded_model
374
+
375
+ # prepare embeddings
376
+
377
+ prompt_embeds = positive[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
378
+ negative_prompt_embeds = negative[0][0].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
379
+
380
+ max_tokens = max(prompt_embeds.shape[1], negative_prompt_embeds.shape[1])
381
+ if prompt_embeds.shape[1] < max_tokens:
382
+ multiplier = max_tokens // 77 - prompt_embeds.shape[1] // 77
383
+ prompt_embeds = torch.concat([prompt_embeds] + [prompt_embeds[:,-77:,:]] * multiplier, dim=1)
384
+ print('BrushNet: negative prompt more than 75 tokens:', negative_prompt_embeds.shape, 'multiplying prompt_embeds')
385
+ if negative_prompt_embeds.shape[1] < max_tokens:
386
+ multiplier = max_tokens // 77 - negative_prompt_embeds.shape[1] // 77
387
+ negative_prompt_embeds = torch.concat([negative_prompt_embeds] + [negative_prompt_embeds[:,-77:,:]] * multiplier, dim=1)
388
+ print('BrushNet: positive prompt more than 75 tokens:', prompt_embeds.shape, 'multiplying negative_prompt_embeds')
389
+
390
+ if len(positive[0]) > 1 and 'pooled_output' in positive[0][1] and positive[0][1]['pooled_output'] is not None:
391
+ pooled_prompt_embeds = positive[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
392
+ else:
393
+ print('BrushNet: positive conditioning has not pooled_output')
394
+ if is_SDXL:
395
+ print('BrushNet will not produce correct results')
396
+ pooled_prompt_embeds = torch.empty([2, 1280], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
397
+
398
+ if len(negative[0]) > 1 and 'pooled_output' in negative[0][1] and negative[0][1]['pooled_output'] is not None:
399
+ negative_pooled_prompt_embeds = negative[0][1]['pooled_output'].to(dtype=torch_dtype).to(brushnet['brushnet'].device)
400
+ else:
401
+ print('BrushNet: negative conditioning has not pooled_output')
402
+ if is_SDXL:
403
+ print('BrushNet will not produce correct results')
404
+ negative_pooled_prompt_embeds = torch.empty([1, pooled_prompt_embeds.shape[1]], device=brushnet['brushnet'].device).to(dtype=torch_dtype)
405
+
406
+ time_ids = torch.FloatTensor([[height, width, 0., 0., height, width]]).to(dtype=torch_dtype).to(brushnet['brushnet'].device)
407
+
408
+ if not is_SDXL:
409
+ pooled_prompt_embeds = None
410
+ negative_pooled_prompt_embeds = None
411
+ time_ids = None
412
+
413
+ # apply patch to model
414
+
415
+ brushnet_conditioning_scale = scale
416
+ control_guidance_start = start_at
417
+ control_guidance_end = end_at
418
+
419
+ add_brushnet_patch(model,
420
+ brushnet['brushnet'],
421
+ torch_dtype,
422
+ conditioning_latents,
423
+ (brushnet_conditioning_scale, control_guidance_start, control_guidance_end),
424
+ prompt_embeds, negative_prompt_embeds,
425
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
426
+
427
+ latent = torch.zeros([batch, 4, conditioning_latents[0].shape[2], conditioning_latents[0].shape[3]], device=brushnet['brushnet'].device)
428
+
429
+ return (model, positive, negative, {"samples":latent},)
430
+
431
+
432
+ class BlendInpaint:
433
+
434
+ @classmethod
435
+ def INPUT_TYPES(s):
436
+ return {"required":
437
+ {
438
+ "inpaint": ("IMAGE",),
439
+ "original": ("IMAGE",),
440
+ "mask": ("MASK",),
441
+ "kernel": ("INT", {"default": 10, "min": 1, "max": 1000}),
442
+ "sigma": ("FLOAT", {"default": 10.0, "min": 0.01, "max": 1000}),
443
+ },
444
+ "optional":
445
+ {
446
+ "origin": ("VECTOR",),
447
+ },
448
+ }
449
+
450
+ CATEGORY = "inpaint"
451
+ RETURN_TYPES = ("IMAGE","MASK",)
452
+ RETURN_NAMES = ("image","MASK",)
453
+
454
+ FUNCTION = "blend_inpaint"
455
+
456
+ def blend_inpaint(self, inpaint: torch.Tensor, original: torch.Tensor, mask, kernel: int, sigma:int, origin=None) -> Tuple[torch.Tensor]:
457
+
458
+ original, mask = check_image_mask(original, mask, 'Blend Inpaint')
459
+
460
+ if len(inpaint.shape) < 4:
461
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
462
+ inpaint = inpaint[None,:,:,:]
463
+
464
+ if inpaint.shape[0] < original.shape[0]:
465
+ print("Blend Inpaint gets batch of original images (%d) but only (%d) inpaint images" % (original.shape[0], inpaint.shape[0]))
466
+ original= original[:inpaint.shape[0],:,:]
467
+ mask = mask[:inpaint.shape[0],:,:]
468
+
469
+ if inpaint.shape[0] > original.shape[0]:
470
+ # batch over inpaint
471
+ count = 0
472
+ original_list = []
473
+ mask_list = []
474
+ origin_list = []
475
+ while (count < inpaint.shape[0]):
476
+ for i in range(original.shape[0]):
477
+ original_list.append(original[i][None,:,:,:])
478
+ mask_list.append(mask[i][None,:,:])
479
+ if origin is not None:
480
+ origin_list.append(origin[i][None,:])
481
+ count += 1
482
+ if count >= inpaint.shape[0]:
483
+ break
484
+ original = torch.concat(original_list, dim=0)
485
+ mask = torch.concat(mask_list, dim=0)
486
+ if origin is not None:
487
+ origin = torch.concat(origin_list, dim=0)
488
+
489
+ if kernel % 2 == 0:
490
+ kernel += 1
491
+ transform = T.GaussianBlur(kernel_size=(kernel, kernel), sigma=(sigma, sigma))
492
+
493
+ ret = []
494
+ blurred = []
495
+ for i in range(inpaint.shape[0]):
496
+ if origin is None:
497
+ blurred_mask = transform(mask[i][None,None,:,:]).to(original.device).to(original.dtype)
498
+ blurred.append(blurred_mask[0])
499
+
500
+ result = torch.nn.functional.interpolate(
501
+ inpaint[i][None,:,:,:].permute(0, 3, 1, 2),
502
+ size=(
503
+ original[i].shape[0],
504
+ original[i].shape[1],
505
+ )
506
+ ).permute(0, 2, 3, 1).to(original.device).to(original.dtype)
507
+ else:
508
+ # got mask from CutForInpaint
509
+ height, width, _ = original[i].shape
510
+ x0 = origin[i][0].item()
511
+ y0 = origin[i][1].item()
512
+
513
+ if mask[i].shape[0] < height or mask[i].shape[1] < width:
514
+ padded_mask = F.pad(input=mask[i], pad=(x0, width-x0-mask[i].shape[1],
515
+ y0, height-y0-mask[i].shape[0]), mode='constant', value=0)
516
+ else:
517
+ padded_mask = mask[i]
518
+ blurred_mask = transform(padded_mask[None,None,:,:]).to(original.device).to(original.dtype)
519
+ blurred.append(blurred_mask[0][0])
520
+
521
+ result = F.pad(input=inpaint[i], pad=(0, 0, x0, width-x0-inpaint[i].shape[1],
522
+ y0, height-y0-inpaint[i].shape[0]), mode='constant', value=0)
523
+ result = result[None,:,:,:].to(original.device).to(original.dtype)
524
+
525
+ ret.append(original[i] * (1.0 - blurred_mask[0][0][:,:,None]) + result[0] * blurred_mask[0][0][:,:,None])
526
+
527
+ return (torch.stack(ret), torch.stack(blurred), )
528
+
529
+
530
+ class CutForInpaint:
531
+
532
+ @classmethod
533
+ def INPUT_TYPES(s):
534
+ return {"required":
535
+ {
536
+ "image": ("IMAGE",),
537
+ "mask": ("MASK",),
538
+ "width": ("INT", {"default": 512, "min": 64, "max": 2048}),
539
+ "height": ("INT", {"default": 512, "min": 64, "max": 2048}),
540
+ },
541
+ }
542
+
543
+ CATEGORY = "inpaint"
544
+ RETURN_TYPES = ("IMAGE","MASK","VECTOR",)
545
+ RETURN_NAMES = ("image","mask","origin",)
546
+
547
+ FUNCTION = "cut_for_inpaint"
548
+
549
+ def cut_for_inpaint(self, image: torch.Tensor, mask: torch.Tensor, width: int, height: int):
550
+
551
+ image, mask = check_image_mask(image, mask, 'BrushNet')
552
+
553
+ ret = []
554
+ msk = []
555
+ org = []
556
+ for i in range(image.shape[0]):
557
+ x0, y0, w, h = cut_with_mask(mask[i], width, height)
558
+ ret.append((image[i][y0:y0+h,x0:x0+w,:]))
559
+ msk.append((mask[i][y0:y0+h,x0:x0+w]))
560
+ org.append(torch.IntTensor([x0,y0]))
561
+
562
+ return (torch.stack(ret), torch.stack(msk), torch.stack(org), )
563
+
564
+
565
+ #### Utility function
566
+
567
+ def get_files_with_extension(folder_name, extension=['safetensors']):
568
+
569
+ try:
570
+ inpaint_path = folder_paths.get_folder_paths(folder_name)[0]
571
+ except:
572
+ inpaint_path = os.path.join(folder_paths.models_dir, folder_name)
573
+
574
+ if not os.path.isdir(inpaint_path):
575
+ inpaint_path = os.path.join(folder_paths.base_path, inpaint_path)
576
+ if not os.path.isdir(inpaint_path):
577
+ return ([], '')
578
+ #raise Exception("Can't find", folder_name, " path")
579
+
580
+ while not inpaint_path[-1].isalpha():
581
+ inpaint_path = inpaint_path[:-1]
582
+
583
+ abs_list = []
584
+ for x in os.walk(inpaint_path):
585
+ for name in x[2]:
586
+ for ext in extension:
587
+ if ext in name:
588
+ abs_list.append(os.path.join(x[0], name))
589
+
590
+ abs_list = sorted(list(set(abs_list)))
591
+
592
+ names = []
593
+ for x in abs_list:
594
+ remain = x
595
+ y = ''
596
+ while remain != inpaint_path:
597
+ remain, folder = os.path.split(remain)
598
+ if len(y) > 0:
599
+ y = os.path.join(folder, y)
600
+ else:
601
+ y = folder
602
+ names.append(y)
603
+ return names, inpaint_path
604
+
605
+
606
+ def brushnet_blocks(sd):
607
+ brushnet_down_block = 0
608
+ brushnet_mid_block = 0
609
+ brushnet_up_block = 0
610
+ for key in sd:
611
+ if 'brushnet_down_block' in key:
612
+ brushnet_down_block += 1
613
+ if 'brushnet_mid_block' in key:
614
+ brushnet_mid_block += 1
615
+ if 'brushnet_up_block' in key:
616
+ brushnet_up_block += 1
617
+ return (brushnet_down_block, brushnet_mid_block, brushnet_up_block, len(sd))
618
+
619
+
620
+ # Check models compatibility
621
+ def check_compatibilty(model, brushnet):
622
+ is_SDXL = False
623
+ is_PP = False
624
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
625
+ print('Base model type: SD1.5')
626
+ is_SDXL = False
627
+ if brushnet["SDXL"]:
628
+ raise Exception("Base model is SD15, but BrushNet is SDXL type")
629
+ if brushnet["PP"]:
630
+ is_PP = True
631
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
632
+ print('Base model type: SDXL')
633
+ is_SDXL = True
634
+ if not brushnet["SDXL"]:
635
+ raise Exception("Base model is SDXL, but BrushNet is SD15 type")
636
+ else:
637
+ print('Base model type: ', type(model.model.model_config))
638
+ raise Exception("Unsupported model type: " + str(type(model.model.model_config)))
639
+
640
+ return (is_SDXL, is_PP)
641
+
642
+
643
+ def check_image_mask(image, mask, name):
644
+ if len(image.shape) < 4:
645
+ # image tensor shape should be [B, H, W, C], but batch somehow is missing
646
+ image = image[None,:,:,:]
647
+
648
+ if len(mask.shape) > 3:
649
+ # mask tensor shape should be [B, H, W] but we get [B, H, W, C], image may be?
650
+ # take first mask, red channel
651
+ mask = (mask[:,:,:,0])[:,:,:]
652
+ elif len(mask.shape) < 3:
653
+ # mask tensor shape should be [B, H, W] but batch somehow is missing
654
+ mask = mask[None,:,:]
655
+
656
+ if image.shape[0] > mask.shape[0]:
657
+ print(name, "gets batch of images (%d) but only %d masks" % (image.shape[0], mask.shape[0]))
658
+ if mask.shape[0] == 1:
659
+ print(name, "will copy the mask to fill batch")
660
+ mask = torch.cat([mask] * image.shape[0], dim=0)
661
+ else:
662
+ print(name, "will add empty masks to fill batch")
663
+ empty_mask = torch.zeros([image.shape[0] - mask.shape[0], mask.shape[1], mask.shape[2]])
664
+ mask = torch.cat([mask, empty_mask], dim=0)
665
+ elif image.shape[0] < mask.shape[0]:
666
+ print(name, "gets batch of images (%d) but too many (%d) masks" % (image.shape[0], mask.shape[0]))
667
+ mask = mask[:image.shape[0],:,:]
668
+
669
+ return (image, mask)
670
+
671
+ # Prepare image and mask
672
+ def prepare_image(image, mask):
673
+
674
+ image, mask = check_image_mask(image, mask, 'BrushNet')
675
+
676
+ print("BrushNet image.shape =", image.shape, "mask.shape =", mask.shape)
677
+
678
+ if mask.shape[2] != image.shape[2] or mask.shape[1] != image.shape[1]:
679
+ raise Exception("Image and mask should be the same size")
680
+
681
+ # As a suggestion of inferno46n2 (https://github.com/nullquant/ComfyUI-BrushNet/issues/64)
682
+ mask = mask.round()
683
+
684
+ masked_image = image * (1.0 - mask[:,:,:,None])
685
+
686
+ return (masked_image, mask)
687
+
688
+
689
+ def cut_with_mask(mask, width, height):
690
+ iy, ix = (mask == 1).nonzero(as_tuple=True)
691
+
692
+ h0, w0 = mask.shape
693
+
694
+ if iy.numel() == 0:
695
+ x_c = w0 / 2.0
696
+ y_c = h0 / 2.0
697
+ else:
698
+ x_min = ix.min().item()
699
+ x_max = ix.max().item()
700
+ y_min = iy.min().item()
701
+ y_max = iy.max().item()
702
+
703
+ if x_max - x_min > width or y_max - y_min > height:
704
+ raise Exception("Mask is bigger than provided dimensions")
705
+
706
+ x_c = (x_min + x_max) / 2.0
707
+ y_c = (y_min + y_max) / 2.0
708
+
709
+ width2 = width / 2.0
710
+ height2 = height / 2.0
711
+
712
+ if w0 <= width:
713
+ x0 = 0
714
+ w = w0
715
+ else:
716
+ x0 = max(0, x_c - width2)
717
+ w = width
718
+ if x0 + width > w0:
719
+ x0 = w0 - width
720
+
721
+ if h0 <= height:
722
+ y0 = 0
723
+ h = h0
724
+ else:
725
+ y0 = max(0, y_c - height2)
726
+ h = height
727
+ if y0 + height > h0:
728
+ y0 = h0 - height
729
+
730
+ return (int(x0), int(y0), int(w), int(h))
731
+
732
+
733
+ # Prepare conditioning_latents
734
+ @torch.inference_mode()
735
+ def get_image_latents(masked_image, mask, vae, scaling_factor):
736
+ processed_image = masked_image.to(vae.device)
737
+ image_latents = vae.encode(processed_image[:,:,:,:3]) * scaling_factor
738
+ processed_mask = 1. - mask[:,None,:,:]
739
+ interpolated_mask = torch.nn.functional.interpolate(
740
+ processed_mask,
741
+ size=(
742
+ image_latents.shape[-2],
743
+ image_latents.shape[-1]
744
+ )
745
+ )
746
+ interpolated_mask = interpolated_mask.to(image_latents.device)
747
+
748
+ conditioning_latents = [image_latents, interpolated_mask]
749
+
750
+ print('BrushNet CL: image_latents shape =', image_latents.shape, 'interpolated_mask shape =', interpolated_mask.shape)
751
+
752
+ return conditioning_latents
753
+
754
+
755
+ # Main function where magic happens
756
+ @torch.inference_mode()
757
+ def brushnet_inference(x, timesteps, transformer_options):
758
+ if 'model_patch' not in transformer_options:
759
+ print('BrushNet inference: there is no model_patch key in transformer_options')
760
+ return ([], 0, [])
761
+ mp = transformer_options['model_patch']
762
+ if 'brushnet' not in mp:
763
+ print('BrushNet inference: there is no brushnet key in mdel_patch')
764
+ return ([], 0, [])
765
+ bo = mp['brushnet']
766
+ if 'model' not in bo:
767
+ print('BrushNet inference: there is no model key in brushnet')
768
+ return ([], 0, [])
769
+ brushnet = bo['model']
770
+ if not (isinstance(brushnet, BrushNetModel) or isinstance(brushnet, PowerPaintModel)):
771
+ print('BrushNet model is not a BrushNetModel class')
772
+ return ([], 0, [])
773
+
774
+ torch_dtype = bo['dtype']
775
+ cl_list = bo['latents']
776
+ brushnet_conditioning_scale, control_guidance_start, control_guidance_end = bo['controls']
777
+ pe = bo['prompt_embeds']
778
+ npe = bo['negative_prompt_embeds']
779
+ ppe, nppe, time_ids = bo['add_embeds']
780
+
781
+ #do_classifier_free_guidance = mp['free_guidance']
782
+ do_classifier_free_guidance = len(transformer_options['cond_or_uncond']) > 1
783
+
784
+ x = x.detach().clone()
785
+ x = x.to(torch_dtype).to(brushnet.device)
786
+
787
+ timesteps = timesteps.detach().clone()
788
+ timesteps = timesteps.to(torch_dtype).to(brushnet.device)
789
+
790
+ total_steps = mp['total_steps']
791
+ step = mp['step']
792
+
793
+ added_cond_kwargs = {}
794
+
795
+ if do_classifier_free_guidance and step == 0:
796
+ print('BrushNet inference: do_classifier_free_guidance is True')
797
+
798
+ sub_idx = None
799
+ if 'ad_params' in transformer_options and 'sub_idxs' in transformer_options['ad_params']:
800
+ sub_idx = transformer_options['ad_params']['sub_idxs']
801
+
802
+ # we have batch input images
803
+ batch = cl_list[0].shape[0]
804
+ # we have incoming latents
805
+ latents_incoming = x.shape[0]
806
+ # and we already got some
807
+ latents_got = bo['latent_id']
808
+ if step == 0 or batch > 1:
809
+ print('BrushNet inference, step = %d: image batch = %d, got %d latents, starting from %d' \
810
+ % (step, batch, latents_incoming, latents_got))
811
+
812
+ image_latents = []
813
+ masks = []
814
+ prompt_embeds = []
815
+ negative_prompt_embeds = []
816
+ pooled_prompt_embeds = []
817
+ negative_pooled_prompt_embeds = []
818
+ if sub_idx:
819
+ # AnimateDiff indexes detected
820
+ if step == 0:
821
+ print('BrushNet inference: AnimateDiff indexes detected and applied')
822
+
823
+ batch = len(sub_idx)
824
+
825
+ if do_classifier_free_guidance:
826
+ for i in sub_idx:
827
+ image_latents.append(cl_list[0][i][None,:,:,:])
828
+ masks.append(cl_list[1][i][None,:,:,:])
829
+ prompt_embeds.append(pe)
830
+ negative_prompt_embeds.append(npe)
831
+ pooled_prompt_embeds.append(ppe)
832
+ negative_pooled_prompt_embeds.append(nppe)
833
+ for i in sub_idx:
834
+ image_latents.append(cl_list[0][i][None,:,:,:])
835
+ masks.append(cl_list[1][i][None,:,:,:])
836
+ else:
837
+ for i in sub_idx:
838
+ image_latents.append(cl_list[0][i][None,:,:,:])
839
+ masks.append(cl_list[1][i][None,:,:,:])
840
+ prompt_embeds.append(pe)
841
+ pooled_prompt_embeds.append(ppe)
842
+ else:
843
+ # do_classifier_free_guidance = 2 passes, 1st pass is cond, 2nd is uncond
844
+ continue_batch = True
845
+ for i in range(latents_incoming):
846
+ number = latents_got + i
847
+ if number < batch:
848
+ # 1st pass, cond
849
+ image_latents.append(cl_list[0][number][None,:,:,:])
850
+ masks.append(cl_list[1][number][None,:,:,:])
851
+ prompt_embeds.append(pe)
852
+ pooled_prompt_embeds.append(ppe)
853
+ elif do_classifier_free_guidance and number < batch * 2:
854
+ # 2nd pass, uncond
855
+ image_latents.append(cl_list[0][number-batch][None,:,:,:])
856
+ masks.append(cl_list[1][number-batch][None,:,:,:])
857
+ negative_prompt_embeds.append(npe)
858
+ negative_pooled_prompt_embeds.append(nppe)
859
+ else:
860
+ # latent batch
861
+ image_latents.append(cl_list[0][0][None,:,:,:])
862
+ masks.append(cl_list[1][0][None,:,:,:])
863
+ prompt_embeds.append(pe)
864
+ pooled_prompt_embeds.append(ppe)
865
+ latents_got = -i
866
+ continue_batch = False
867
+
868
+ if continue_batch:
869
+ # we don't have full batch yet
870
+ if do_classifier_free_guidance:
871
+ if number < batch * 2 - 1:
872
+ bo['latent_id'] = number + 1
873
+ else:
874
+ bo['latent_id'] = 0
875
+ else:
876
+ if number < batch - 1:
877
+ bo['latent_id'] = number + 1
878
+ else:
879
+ bo['latent_id'] = 0
880
+ else:
881
+ bo['latent_id'] = 0
882
+
883
+ cl = []
884
+ for il, m in zip(image_latents, masks):
885
+ cl.append(torch.concat([il, m], dim=1))
886
+ cl2apply = torch.concat(cl, dim=0)
887
+
888
+ conditioning_latents = cl2apply.to(torch_dtype).to(brushnet.device)
889
+
890
+ prompt_embeds.extend(negative_prompt_embeds)
891
+ prompt_embeds = torch.concat(prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
892
+
893
+ if ppe is not None:
894
+ added_cond_kwargs = {}
895
+ added_cond_kwargs['time_ids'] = torch.concat([time_ids] * latents_incoming, dim = 0).to(torch_dtype).to(brushnet.device)
896
+
897
+ pooled_prompt_embeds.extend(negative_pooled_prompt_embeds)
898
+ pooled_prompt_embeds = torch.concat(pooled_prompt_embeds, dim=0).to(torch_dtype).to(brushnet.device)
899
+ added_cond_kwargs['text_embeds'] = pooled_prompt_embeds
900
+ else:
901
+ added_cond_kwargs = None
902
+
903
+ if x.shape[2] != conditioning_latents.shape[2] or x.shape[3] != conditioning_latents.shape[3]:
904
+ if step == 0:
905
+ print('BrushNet inference: image', conditioning_latents.shape, 'and latent', x.shape, 'have different size, resizing image')
906
+ conditioning_latents = torch.nn.functional.interpolate(
907
+ conditioning_latents, size=(
908
+ x.shape[2],
909
+ x.shape[3],
910
+ ), mode='bicubic',
911
+ ).to(torch_dtype).to(brushnet.device)
912
+
913
+ if step == 0:
914
+ print('BrushNet inference: sample', x.shape, ', CL', conditioning_latents.shape)
915
+
916
+ if step < control_guidance_start or step > control_guidance_end:
917
+ cond_scale = 0.0
918
+ else:
919
+ cond_scale = brushnet_conditioning_scale
920
+
921
+ return brushnet(x,
922
+ encoder_hidden_states=prompt_embeds,
923
+ brushnet_cond=conditioning_latents,
924
+ timestep = timesteps,
925
+ conditioning_scale=cond_scale,
926
+ guess_mode=False,
927
+ added_cond_kwargs=added_cond_kwargs,
928
+ return_dict=False,
929
+ )
930
+
931
+
932
+ # This is main patch function
933
+ def add_brushnet_patch(model, brushnet, torch_dtype, conditioning_latents,
934
+ controls,
935
+ prompt_embeds, negative_prompt_embeds,
936
+ pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids):
937
+
938
+ is_SDXL = isinstance(model.model.model_config, comfy.supported_models.SDXL)
939
+
940
+ if is_SDXL:
941
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
942
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
943
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
944
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
945
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
946
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
947
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
948
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
949
+ [8, comfy.ldm.modules.attention.SpatialTransformer]]
950
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
951
+ output_blocks = [[0, comfy.ldm.modules.attention.SpatialTransformer],
952
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
953
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
954
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
955
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
956
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
957
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
958
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
959
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
960
+ [7, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
961
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
962
+ else:
963
+ input_blocks = [[0, comfy.ops.disable_weight_init.Conv2d],
964
+ [1, comfy.ldm.modules.attention.SpatialTransformer],
965
+ [2, comfy.ldm.modules.attention.SpatialTransformer],
966
+ [3, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
967
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
968
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
969
+ [6, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
970
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
971
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
972
+ [9, comfy.ldm.modules.diffusionmodules.openaimodel.Downsample],
973
+ [10, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
974
+ [11, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]]
975
+ middle_block = [0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock]
976
+ output_blocks = [[0, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
977
+ [1, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
978
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.ResBlock],
979
+ [2, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
980
+ [3, comfy.ldm.modules.attention.SpatialTransformer],
981
+ [4, comfy.ldm.modules.attention.SpatialTransformer],
982
+ [5, comfy.ldm.modules.attention.SpatialTransformer],
983
+ [5, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
984
+ [6, comfy.ldm.modules.attention.SpatialTransformer],
985
+ [7, comfy.ldm.modules.attention.SpatialTransformer],
986
+ [8, comfy.ldm.modules.attention.SpatialTransformer],
987
+ [8, comfy.ldm.modules.diffusionmodules.openaimodel.Upsample],
988
+ [9, comfy.ldm.modules.attention.SpatialTransformer],
989
+ [10, comfy.ldm.modules.attention.SpatialTransformer],
990
+ [11, comfy.ldm.modules.attention.SpatialTransformer]]
991
+
992
+ def last_layer_index(block, tp):
993
+ layer_list = []
994
+ for layer in block:
995
+ layer_list.append(type(layer))
996
+ layer_list.reverse()
997
+ if tp not in layer_list:
998
+ return -1, layer_list.reverse()
999
+ return len(layer_list) - 1 - layer_list.index(tp), layer_list
1000
+
1001
+ def brushnet_forward(model, x, timesteps, transformer_options, control):
1002
+ if 'brushnet' not in transformer_options['model_patch']:
1003
+ input_samples = []
1004
+ mid_sample = 0
1005
+ output_samples = []
1006
+ else:
1007
+ # brushnet inference
1008
+ input_samples, mid_sample, output_samples = brushnet_inference(x, timesteps, transformer_options)
1009
+
1010
+ # give additional samples to blocks
1011
+ for i, tp in input_blocks:
1012
+ idx, layer_list = last_layer_index(model.input_blocks[i], tp)
1013
+ if idx < 0:
1014
+ print("BrushNet can't find", tp, "layer in", i,"input block:", layer_list)
1015
+ continue
1016
+ model.input_blocks[i][idx].add_sample_after = input_samples.pop(0) if input_samples else 0
1017
+
1018
+ idx, layer_list = last_layer_index(model.middle_block, middle_block[1])
1019
+ if idx < 0:
1020
+ print("BrushNet can't find", middle_block[1], "layer in middle block", layer_list)
1021
+ model.middle_block[idx].add_sample_after = mid_sample
1022
+
1023
+ for i, tp in output_blocks:
1024
+ idx, layer_list = last_layer_index(model.output_blocks[i], tp)
1025
+ if idx < 0:
1026
+ print("BrushNet can't find", tp, "layer in", i,"outnput block:", layer_list)
1027
+ continue
1028
+ model.output_blocks[i][idx].add_sample_after = output_samples.pop(0) if output_samples else 0
1029
+
1030
+ patch_model_function_wrapper(model, brushnet_forward)
1031
+
1032
+ to = add_model_patch_option(model)
1033
+ mp = to['model_patch']
1034
+ if 'brushnet' not in mp:
1035
+ mp['brushnet'] = {}
1036
+ bo = mp['brushnet']
1037
+
1038
+ bo['model'] = brushnet
1039
+ bo['dtype'] = torch_dtype
1040
+ bo['latents'] = conditioning_latents
1041
+ bo['controls'] = controls
1042
+ bo['prompt_embeds'] = prompt_embeds
1043
+ bo['negative_prompt_embeds'] = negative_prompt_embeds
1044
+ bo['add_embeds'] = (pooled_prompt_embeds, negative_pooled_prompt_embeds, time_ids)
1045
+ bo['latent_id'] = 0
1046
+
1047
+ # patch layers `forward` so we can apply brushnet
1048
+ def forward_patched_by_brushnet(self, x, *args, **kwargs):
1049
+ h = self.original_forward(x, *args, **kwargs)
1050
+ if hasattr(self, 'add_sample_after') and type(self):
1051
+ to_add = self.add_sample_after
1052
+ if torch.is_tensor(to_add):
1053
+ # interpolate due to RAUNet
1054
+ if h.shape[2] != to_add.shape[2] or h.shape[3] != to_add.shape[3]:
1055
+ to_add = torch.nn.functional.interpolate(to_add, size=(h.shape[2], h.shape[3]), mode='bicubic')
1056
+ h += to_add.to(h.dtype).to(h.device)
1057
+ else:
1058
+ h += self.add_sample_after
1059
+ self.add_sample_after = 0
1060
+ return h
1061
+
1062
+ for i, block in enumerate(model.model.diffusion_model.input_blocks):
1063
+ for j, layer in enumerate(block):
1064
+ if not hasattr(layer, 'original_forward'):
1065
+ layer.original_forward = layer.forward
1066
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1067
+ layer.add_sample_after = 0
1068
+
1069
+ for j, layer in enumerate(model.model.diffusion_model.middle_block):
1070
+ if not hasattr(layer, 'original_forward'):
1071
+ layer.original_forward = layer.forward
1072
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1073
+ layer.add_sample_after = 0
1074
+
1075
+ for i, block in enumerate(model.model.diffusion_model.output_blocks):
1076
+ for j, layer in enumerate(block):
1077
+ if not hasattr(layer, 'original_forward'):
1078
+ layer.original_forward = layer.forward
1079
+ layer.forward = types.MethodType(forward_patched_by_brushnet, layer)
1080
+ layer.add_sample_after = 0
ComfyUI-BrushNet/model_patch.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import comfy
3
+
4
+
5
+ # Check and add 'model_patch' to model.model_options['transformer_options']
6
+ def add_model_patch_option(model):
7
+ if 'transformer_options' not in model.model_options:
8
+ model.model_options['transformer_options'] = {}
9
+ to = model.model_options['transformer_options']
10
+ if "model_patch" not in to:
11
+ to["model_patch"] = {}
12
+ return to
13
+
14
+
15
+ # Patch model with model_function_wrapper
16
+ def patch_model_function_wrapper(model, forward_patch):
17
+
18
+ def brushnet_model_function_wrapper(apply_model_method, options_dict):
19
+ to = options_dict['c']['transformer_options']
20
+
21
+ control = None
22
+ if 'control' in options_dict['c']:
23
+ control = options_dict['c']['control']
24
+
25
+ x = options_dict['input']
26
+ timestep = options_dict['timestep']
27
+
28
+ # check if there are patches to execute
29
+ if 'model_patch' not in to or 'forward' not in to['model_patch']:
30
+ return apply_model_method(x, timestep, **options_dict['c'])
31
+
32
+ mp = to['model_patch']
33
+ unet = mp['unet']
34
+
35
+ all_sigmas = mp['all_sigmas']
36
+ sigma = to['sigmas'][0].item()
37
+ total_steps = all_sigmas.shape[0] - 1
38
+ step = torch.argmin((all_sigmas - sigma).abs()).item()
39
+
40
+ mp['step'] = step
41
+ mp['total_steps'] = total_steps
42
+
43
+ # comfy.model_base.apply_model
44
+ xc = model.model.model_sampling.calculate_input(timestep, x)
45
+ if 'c_concat' in options_dict['c'] and options_dict['c']['c_concat'] is not None:
46
+ xc = torch.cat([xc] + [options_dict['c']['c_concat']], dim=1)
47
+ t = model.model.model_sampling.timestep(timestep).float()
48
+ # execute all patches
49
+ for method in mp['forward']:
50
+ method(unet, xc, t, to, control)
51
+
52
+ return apply_model_method(x, timestep, **options_dict['c'])
53
+
54
+ if "model_function_wrapper" in model.model_options and model.model_options["model_function_wrapper"]:
55
+ print('BrushNet is going to replace existing model_function_wrapper:', model.model_options["model_function_wrapper"])
56
+ model.set_model_unet_function_wrapper(brushnet_model_function_wrapper)
57
+
58
+ to = add_model_patch_option(model)
59
+ mp = to['model_patch']
60
+
61
+ if isinstance(model.model.model_config, comfy.supported_models.SD15):
62
+ mp['SDXL'] = False
63
+ elif isinstance(model.model.model_config, comfy.supported_models.SDXL):
64
+ mp['SDXL'] = True
65
+ else:
66
+ print('Base model type: ', type(model.model.model_config))
67
+ raise Exception("Unsupported model type: ", type(model.model.model_config))
68
+
69
+ if 'forward' not in mp:
70
+ mp['forward'] = [forward_patch]
71
+ else:
72
+ mp['forward'].append(forward_patch)
73
+
74
+ mp['unet'] = model.model.diffusion_model
75
+ mp['step'] = 0
76
+ mp['total_steps'] = 1
77
+
78
+ # apply patches to code
79
+ if comfy.samplers.sample.__doc__ is None or 'BrushNet' not in comfy.samplers.sample.__doc__:
80
+ comfy.samplers.original_sample = comfy.samplers.sample
81
+ comfy.samplers.sample = modified_sample
82
+
83
+ if comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__ is None or \
84
+ 'BrushNet' not in comfy.ldm.modules.diffusionmodules.openaimodel.apply_control.__doc__:
85
+ comfy.ldm.modules.diffusionmodules.openaimodel.original_apply_control = comfy.ldm.modules.diffusionmodules.openaimodel.apply_control
86
+ comfy.ldm.modules.diffusionmodules.openaimodel.apply_control = modified_apply_control
87
+
88
+
89
+ # Model needs current step number and cfg at inference step. It is possible to write a custom KSampler but I'd like to use ComfyUI's one.
90
+ # The first versions had modified_common_ksampler, but it broke custom KSampler nodes
91
+ def modified_sample(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={},
92
+ latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None):
93
+ '''
94
+ Modified by BrushNet nodes
95
+ '''
96
+ cfg_guider = comfy.samplers.CFGGuider(model)
97
+ cfg_guider.set_conds(positive, negative)
98
+ cfg_guider.set_cfg(cfg)
99
+
100
+ ### Modified part ######################################################################
101
+ #
102
+ to = add_model_patch_option(model)
103
+ to['model_patch']['all_sigmas'] = sigmas
104
+ #
105
+ #sigma_start = model.get_model_object("model_sampling").percent_to_sigma(start_at)
106
+ #sigma_end = model.get_model_object("model_sampling").percent_to_sigma(end_at)
107
+ #
108
+ #
109
+ #if math.isclose(cfg, 1.0) and model_options.get("disable_cfg1_optimization", False) == False:
110
+ # to['model_patch']['free_guidance'] = False
111
+ #else:
112
+ # to['model_patch']['free_guidance'] = True
113
+ #
114
+ #######################################################################################
115
+
116
+ return cfg_guider.sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
117
+
118
+
119
+ # To use Controlnet with RAUNet it is much easier to modify apply_control a little
120
+ def modified_apply_control(h, control, name):
121
+ '''
122
+ Modified by BrushNet nodes
123
+ '''
124
+ if control is not None and name in control and len(control[name]) > 0:
125
+ ctrl = control[name].pop()
126
+ if ctrl is not None:
127
+ if h.shape[2] != ctrl.shape[2] or h.shape[3] != ctrl.shape[3]:
128
+ ctrl = torch.nn.functional.interpolate(ctrl, size=(h.shape[2], h.shape[3]), mode='bicubic').to(h.dtype).to(h.device)
129
+ try:
130
+ h += ctrl
131
+ except:
132
+ print.warning("warning control could not be applied {} {}".format(h.shape, ctrl.shape))
133
+ return h
134
+
ComfyUI-BrushNet/raunet_nodes.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn.functional as F
2
+ import comfy
3
+
4
+ from .model_patch import add_model_patch_option, patch_model_function_wrapper
5
+
6
+
7
+
8
+ class RAUNet:
9
+
10
+ @classmethod
11
+ def INPUT_TYPES(s):
12
+ return {"required":
13
+ {
14
+ "model": ("MODEL",),
15
+ "du_start": ("INT", {"default": 0, "min": 0, "max": 10000}),
16
+ "du_end": ("INT", {"default": 4, "min": 0, "max": 10000}),
17
+ "xa_start": ("INT", {"default": 4, "min": 0, "max": 10000}),
18
+ "xa_end": ("INT", {"default": 10, "min": 0, "max": 10000}),
19
+ },
20
+ }
21
+
22
+ CATEGORY = "inpaint"
23
+ RETURN_TYPES = ("MODEL",)
24
+ RETURN_NAMES = ("model",)
25
+
26
+ FUNCTION = "model_update"
27
+
28
+ def model_update(self, model, du_start, du_end, xa_start, xa_end):
29
+
30
+ model = model.clone()
31
+
32
+ add_raunet_patch(model,
33
+ du_start,
34
+ du_end,
35
+ xa_start,
36
+ xa_end)
37
+
38
+ return (model,)
39
+
40
+
41
+ # This is main patch function
42
+ def add_raunet_patch(model, du_start, du_end, xa_start, xa_end):
43
+
44
+ def raunet_forward(model, x, timesteps, transformer_options, control):
45
+ if 'model_patch' not in transformer_options:
46
+ print("RAUNet: 'model_patch' not in transformer_options, skip")
47
+ return
48
+
49
+ mp = transformer_options['model_patch']
50
+ is_SDXL = mp['SDXL']
51
+
52
+ if is_SDXL and type(model.input_blocks[6][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
53
+ print('RAUNet: model is SDXL, but input[6] != Downsample, skip')
54
+ return
55
+
56
+ if not is_SDXL and type(model.input_blocks[3][0]) != comfy.ldm.modules.diffusionmodules.openaimodel.Downsample:
57
+ print('RAUNet: model is not SDXL, but input[3] != Downsample, skip')
58
+ return
59
+
60
+ if 'raunet' not in mp:
61
+ print('RAUNet: "raunet" not in model_patch options, skip')
62
+ return
63
+
64
+ if is_SDXL:
65
+ block = model.input_blocks[6][0]
66
+ else:
67
+ block = model.input_blocks[3][0]
68
+
69
+ total_steps = mp['total_steps']
70
+ step = mp['step']
71
+
72
+ ro = mp['raunet']
73
+ du_start = ro['du_start']
74
+ du_end = ro['du_end']
75
+
76
+ if step >= du_start and step < du_end:
77
+ block.op.stride = (4, 4)
78
+ block.op.padding = (2, 2)
79
+ block.op.dilation = (2, 2)
80
+ else:
81
+ block.op.stride = (2, 2)
82
+ block.op.padding = (1, 1)
83
+ block.op.dilation = (1, 1)
84
+
85
+ patch_model_function_wrapper(model, raunet_forward)
86
+ model.set_model_input_block_patch(in_xattn_patch)
87
+ model.set_model_output_block_patch(out_xattn_patch)
88
+
89
+ to = add_model_patch_option(model)
90
+ mp = to['model_patch']
91
+ if 'raunet' not in mp:
92
+ mp['raunet'] = {}
93
+ ro = mp['raunet']
94
+
95
+ ro['du_start'] = du_start
96
+ ro['du_end'] = du_end
97
+ ro['xa_start'] = xa_start
98
+ ro['xa_end'] = xa_end
99
+
100
+
101
+ def in_xattn_patch(h, transformer_options):
102
+ # both SDXL and SD15 = (input,4)
103
+ if transformer_options["block"] != ("input", 4):
104
+ # wrong block
105
+ return h
106
+ if 'model_patch' not in transformer_options:
107
+ print("RAUNet (i-x-p): 'model_patch' not in transformer_options")
108
+ return h
109
+ mp = transformer_options['model_patch']
110
+ if 'raunet' not in mp:
111
+ print("RAUNet (i-x-p): 'raunet' not in model_patch options")
112
+ return h
113
+
114
+ step = mp['step']
115
+ ro = mp['raunet']
116
+ xa_start = ro['xa_start']
117
+ xa_end = ro['xa_end']
118
+
119
+ if step < xa_start or step >= xa_end:
120
+ return h
121
+ h = F.avg_pool2d(h, kernel_size=(2,2))
122
+ return h
123
+
124
+
125
+ def out_xattn_patch(h, hsp, transformer_options):
126
+ if 'model_patch' not in transformer_options:
127
+ print("RAUNet (o-x-p): 'model_patch' not in transformer_options")
128
+ return h, hsp
129
+ mp = transformer_options['model_patch']
130
+ if 'raunet' not in mp:
131
+ print("RAUNet (o-x-p): 'raunet' not in model_patch options")
132
+ return h
133
+
134
+ step = mp['step']
135
+ is_SDXL = mp['SDXL']
136
+ ro = mp['raunet']
137
+ xa_start = ro['xa_start']
138
+ xa_end = ro['xa_end']
139
+
140
+ if is_SDXL:
141
+ if transformer_options["block"] != ("output", 5):
142
+ # wrong block
143
+ return h, hsp
144
+ else:
145
+ if transformer_options["block"] != ("output", 8):
146
+ # wrong block
147
+ return h, hsp
148
+
149
+ if step < xa_start or step >= xa_end:
150
+ return h, hsp
151
+ #error in hidiffusion codebase, size * 2 for particular sizes only
152
+ #re_size = (int(h.shape[-2] * 2), int(h.shape[-1] * 2))
153
+ re_size = (hsp.shape[-2], hsp.shape[-1])
154
+ h = F.interpolate(h, size=re_size, mode='bicubic')
155
+
156
+ return h, hsp
157
+
158
+
ComfyUI-BrushNet/requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ diffusers>=0.27.0
2
+ accelerate>=0.29.0
3
+ peft>=0.7.0
ComfyUI-Easy-Use/LICENSE ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ GNU GENERAL PUBLIC LICENSE
2
+ Version 3, 29 June 2007
3
+
4
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5
+ Everyone is permitted to copy and distribute verbatim copies
6
+ of this license document, but changing it is not allowed.
7
+
8
+ Preamble
9
+
10
+ The GNU General Public License is a free, copyleft license for
11
+ software and other kinds of works.
12
+
13
+ The licenses for most software and other practical works are designed
14
+ to take away your freedom to share and change the works. By contrast,
15
+ the GNU General Public License is intended to guarantee your freedom to
16
+ share and change all versions of a program--to make sure it remains free
17
+ software for all its users. We, the Free Software Foundation, use the
18
+ GNU General Public License for most of our software; it applies also to
19
+ any other work released this way by its authors. You can apply it to
20
+ your programs, too.
21
+
22
+ When we speak of free software, we are referring to freedom, not
23
+ price. Our General Public Licenses are designed to make sure that you
24
+ have the freedom to distribute copies of free software (and charge for
25
+ them if you wish), that you receive source code or can get it if you
26
+ want it, that you can change the software or use pieces of it in new
27
+ free programs, and that you know you can do these things.
28
+
29
+ To protect your rights, we need to prevent others from denying you
30
+ these rights or asking you to surrender the rights. Therefore, you have
31
+ certain responsibilities if you distribute copies of the software, or if
32
+ you modify it: responsibilities to respect the freedom of others.
33
+
34
+ For example, if you distribute copies of such a program, whether
35
+ gratis or for a fee, you must pass on to the recipients the same
36
+ freedoms that you received. You must make sure that they, too, receive
37
+ or can get the source code. And you must show them these terms so they
38
+ know their rights.
39
+
40
+ Developers that use the GNU GPL protect your rights with two steps:
41
+ (1) assert copyright on the software, and (2) offer you this License
42
+ giving you legal permission to copy, distribute and/or modify it.
43
+
44
+ For the developers' and authors' protection, the GPL clearly explains
45
+ that there is no warranty for this free software. For both users' and
46
+ authors' sake, the GPL requires that modified versions be marked as
47
+ changed, so that their problems will not be attributed erroneously to
48
+ authors of previous versions.
49
+
50
+ Some devices are designed to deny users access to install or run
51
+ modified versions of the software inside them, although the manufacturer
52
+ can do so. This is fundamentally incompatible with the aim of
53
+ protecting users' freedom to change the software. The systematic
54
+ pattern of such abuse occurs in the area of products for individuals to
55
+ use, which is precisely where it is most unacceptable. Therefore, we
56
+ have designed this version of the GPL to prohibit the practice for those
57
+ products. If such problems arise substantially in other domains, we
58
+ stand ready to extend this provision to those domains in future versions
59
+ of the GPL, as needed to protect the freedom of users.
60
+
61
+ Finally, every program is threatened constantly by software patents.
62
+ States should not allow patents to restrict development and use of
63
+ software on general-purpose computers, but in those that do, we wish to
64
+ avoid the special danger that patents applied to a free program could
65
+ make it effectively proprietary. To prevent this, the GPL assures that
66
+ patents cannot be used to render the program non-free.
67
+
68
+ The precise terms and conditions for copying, distribution and
69
+ modification follow.
70
+
71
+ TERMS AND CONDITIONS
72
+
73
+ 0. Definitions.
74
+
75
+ "This License" refers to version 3 of the GNU General Public License.
76
+
77
+ "Copyright" also means copyright-like laws that apply to other kinds of
78
+ works, such as semiconductor masks.
79
+
80
+ "The Program" refers to any copyrightable work licensed under this
81
+ License. Each licensee is addressed as "you". "Licensees" and
82
+ "recipients" may be individuals or organizations.
83
+
84
+ To "modify" a work means to copy from or adapt all or part of the work
85
+ in a fashion requiring copyright permission, other than the making of an
86
+ exact copy. The resulting work is called a "modified version" of the
87
+ earlier work or a work "based on" the earlier work.
88
+
89
+ A "covered work" means either the unmodified Program or a work based
90
+ on the Program.
91
+
92
+ To "propagate" a work means to do anything with it that, without
93
+ permission, would make you directly or secondarily liable for
94
+ infringement under applicable copyright law, except executing it on a
95
+ computer or modifying a private copy. Propagation includes copying,
96
+ distribution (with or without modification), making available to the
97
+ public, and in some countries other activities as well.
98
+
99
+ To "convey" a work means any kind of propagation that enables other
100
+ parties to make or receive copies. Mere interaction with a user through
101
+ a computer network, with no transfer of a copy, is not conveying.
102
+
103
+ An interactive user interface displays "Appropriate Legal Notices"
104
+ to the extent that it includes a convenient and prominently visible
105
+ feature that (1) displays an appropriate copyright notice, and (2)
106
+ tells the user that there is no warranty for the work (except to the
107
+ extent that warranties are provided), that licensees may convey the
108
+ work under this License, and how to view a copy of this License. If
109
+ the interface presents a list of user commands or options, such as a
110
+ menu, a prominent item in the list meets this criterion.
111
+
112
+ 1. Source Code.
113
+
114
+ The "source code" for a work means the preferred form of the work
115
+ for making modifications to it. "Object code" means any non-source
116
+ form of a work.
117
+
118
+ A "Standard Interface" means an interface that either is an official
119
+ standard defined by a recognized standards body, or, in the case of
120
+ interfaces specified for a particular programming language, one that
121
+ is widely used among developers working in that language.
122
+
123
+ The "System Libraries" of an executable work include anything, other
124
+ than the work as a whole, that (a) is included in the normal form of
125
+ packaging a Major Component, but which is not part of that Major
126
+ Component, and (b) serves only to enable use of the work with that
127
+ Major Component, or to implement a Standard Interface for which an
128
+ implementation is available to the public in source code form. A
129
+ "Major Component", in this context, means a major essential component
130
+ (kernel, window system, and so on) of the specific operating system
131
+ (if any) on which the executable work runs, or a compiler used to
132
+ produce the work, or an object code interpreter used to run it.
133
+
134
+ The "Corresponding Source" for a work in object code form means all
135
+ the source code needed to generate, install, and (for an executable
136
+ work) run the object code and to modify the work, including scripts to
137
+ control those activities. However, it does not include the work's
138
+ System Libraries, or general-purpose tools or generally available free
139
+ programs which are used unmodified in performing those activities but
140
+ which are not part of the work. For example, Corresponding Source
141
+ includes interface definition files associated with source files for
142
+ the work, and the source code for shared libraries and dynamically
143
+ linked subprograms that the work is specifically designed to require,
144
+ such as by intimate data communication or control flow between those
145
+ subprograms and other parts of the work.
146
+
147
+ The Corresponding Source need not include anything that users
148
+ can regenerate automatically from other parts of the Corresponding
149
+ Source.
150
+
151
+ The Corresponding Source for a work in source code form is that
152
+ same work.
153
+
154
+ 2. Basic Permissions.
155
+
156
+ All rights granted under this License are granted for the term of
157
+ copyright on the Program, and are irrevocable provided the stated
158
+ conditions are met. This License explicitly affirms your unlimited
159
+ permission to run the unmodified Program. The output from running a
160
+ covered work is covered by this License only if the output, given its
161
+ content, constitutes a covered work. This License acknowledges your
162
+ rights of fair use or other equivalent, as provided by copyright law.
163
+
164
+ You may make, run and propagate covered works that you do not
165
+ convey, without conditions so long as your license otherwise remains
166
+ in force. You may convey covered works to others for the sole purpose
167
+ of having them make modifications exclusively for you, or provide you
168
+ with facilities for running those works, provided that you comply with
169
+ the terms of this License in conveying all material for which you do
170
+ not control copyright. Those thus making or running the covered works
171
+ for you must do so exclusively on your behalf, under your direction
172
+ and control, on terms that prohibit them from making any copies of
173
+ your copyrighted material outside their relationship with you.
174
+
175
+ Conveying under any other circumstances is permitted solely under
176
+ the conditions stated below. Sublicensing is not allowed; section 10
177
+ makes it unnecessary.
178
+
179
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180
+
181
+ No covered work shall be deemed part of an effective technological
182
+ measure under any applicable law fulfilling obligations under article
183
+ 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184
+ similar laws prohibiting or restricting circumvention of such
185
+ measures.
186
+
187
+ When you convey a covered work, you waive any legal power to forbid
188
+ circumvention of technological measures to the extent such circumvention
189
+ is effected by exercising rights under this License with respect to
190
+ the covered work, and you disclaim any intention to limit operation or
191
+ modification of the work as a means of enforcing, against the work's
192
+ users, your or third parties' legal rights to forbid circumvention of
193
+ technological measures.
194
+
195
+ 4. Conveying Verbatim Copies.
196
+
197
+ You may convey verbatim copies of the Program's source code as you
198
+ receive it, in any medium, provided that you conspicuously and
199
+ appropriately publish on each copy an appropriate copyright notice;
200
+ keep intact all notices stating that this License and any
201
+ non-permissive terms added in accord with section 7 apply to the code;
202
+ keep intact all notices of the absence of any warranty; and give all
203
+ recipients a copy of this License along with the Program.
204
+
205
+ You may charge any price or no price for each copy that you convey,
206
+ and you may offer support or warranty protection for a fee.
207
+
208
+ 5. Conveying Modified Source Versions.
209
+
210
+ You may convey a work based on the Program, or the modifications to
211
+ produce it from the Program, in the form of source code under the
212
+ terms of section 4, provided that you also meet all of these conditions:
213
+
214
+ a) The work must carry prominent notices stating that you modified
215
+ it, and giving a relevant date.
216
+
217
+ b) The work must carry prominent notices stating that it is
218
+ released under this License and any conditions added under section
219
+ 7. This requirement modifies the requirement in section 4 to
220
+ "keep intact all notices".
221
+
222
+ c) You must license the entire work, as a whole, under this
223
+ License to anyone who comes into possession of a copy. This
224
+ License will therefore apply, along with any applicable section 7
225
+ additional terms, to the whole of the work, and all its parts,
226
+ regardless of how they are packaged. This License gives no
227
+ permission to license the work in any other way, but it does not
228
+ invalidate such permission if you have separately received it.
229
+
230
+ d) If the work has interactive user interfaces, each must display
231
+ Appropriate Legal Notices; however, if the Program has interactive
232
+ interfaces that do not display Appropriate Legal Notices, your
233
+ work need not make them do so.
234
+
235
+ A compilation of a covered work with other separate and independent
236
+ works, which are not by their nature extensions of the covered work,
237
+ and which are not combined with it such as to form a larger program,
238
+ in or on a volume of a storage or distribution medium, is called an
239
+ "aggregate" if the compilation and its resulting copyright are not
240
+ used to limit the access or legal rights of the compilation's users
241
+ beyond what the individual works permit. Inclusion of a covered work
242
+ in an aggregate does not cause this License to apply to the other
243
+ parts of the aggregate.
244
+
245
+ 6. Conveying Non-Source Forms.
246
+
247
+ You may convey a covered work in object code form under the terms
248
+ of sections 4 and 5, provided that you also convey the
249
+ machine-readable Corresponding Source under the terms of this License,
250
+ in one of these ways:
251
+
252
+ a) Convey the object code in, or embodied in, a physical product
253
+ (including a physical distribution medium), accompanied by the
254
+ Corresponding Source fixed on a durable physical medium
255
+ customarily used for software interchange.
256
+
257
+ b) Convey the object code in, or embodied in, a physical product
258
+ (including a physical distribution medium), accompanied by a
259
+ written offer, valid for at least three years and valid for as
260
+ long as you offer spare parts or customer support for that product
261
+ model, to give anyone who possesses the object code either (1) a
262
+ copy of the Corresponding Source for all the software in the
263
+ product that is covered by this License, on a durable physical
264
+ medium customarily used for software interchange, for a price no
265
+ more than your reasonable cost of physically performing this
266
+ conveying of source, or (2) access to copy the
267
+ Corresponding Source from a network server at no charge.
268
+
269
+ c) Convey individual copies of the object code with a copy of the
270
+ written offer to provide the Corresponding Source. This
271
+ alternative is allowed only occasionally and noncommercially, and
272
+ only if you received the object code with such an offer, in accord
273
+ with subsection 6b.
274
+
275
+ d) Convey the object code by offering access from a designated
276
+ place (gratis or for a charge), and offer equivalent access to the
277
+ Corresponding Source in the same way through the same place at no
278
+ further charge. You need not require recipients to copy the
279
+ Corresponding Source along with the object code. If the place to
280
+ copy the object code is a network server, the Corresponding Source
281
+ may be on a different server (operated by you or a third party)
282
+ that supports equivalent copying facilities, provided you maintain
283
+ clear directions next to the object code saying where to find the
284
+ Corresponding Source. Regardless of what server hosts the
285
+ Corresponding Source, you remain obligated to ensure that it is
286
+ available for as long as needed to satisfy these requirements.
287
+
288
+ e) Convey the object code using peer-to-peer transmission, provided
289
+ you inform other peers where the object code and Corresponding
290
+ Source of the work are being offered to the general public at no
291
+ charge under subsection 6d.
292
+
293
+ A separable portion of the object code, whose source code is excluded
294
+ from the Corresponding Source as a System Library, need not be
295
+ included in conveying the object code work.
296
+
297
+ A "User Product" is either (1) a "consumer product", which means any
298
+ tangible personal property which is normally used for personal, family,
299
+ or household purposes, or (2) anything designed or sold for incorporation
300
+ into a dwelling. In determining whether a product is a consumer product,
301
+ doubtful cases shall be resolved in favor of coverage. For a particular
302
+ product received by a particular user, "normally used" refers to a
303
+ typical or common use of that class of product, regardless of the status
304
+ of the particular user or of the way in which the particular user
305
+ actually uses, or expects or is expected to use, the product. A product
306
+ is a consumer product regardless of whether the product has substantial
307
+ commercial, industrial or non-consumer uses, unless such uses represent
308
+ the only significant mode of use of the product.
309
+
310
+ "Installation Information" for a User Product means any methods,
311
+ procedures, authorization keys, or other information required to install
312
+ and execute modified versions of a covered work in that User Product from
313
+ a modified version of its Corresponding Source. The information must
314
+ suffice to ensure that the continued functioning of the modified object
315
+ code is in no case prevented or interfered with solely because
316
+ modification has been made.
317
+
318
+ If you convey an object code work under this section in, or with, or
319
+ specifically for use in, a User Product, and the conveying occurs as
320
+ part of a transaction in which the right of possession and use of the
321
+ User Product is transferred to the recipient in perpetuity or for a
322
+ fixed term (regardless of how the transaction is characterized), the
323
+ Corresponding Source conveyed under this section must be accompanied
324
+ by the Installation Information. But this requirement does not apply
325
+ if neither you nor any third party retains the ability to install
326
+ modified object code on the User Product (for example, the work has
327
+ been installed in ROM).
328
+
329
+ The requirement to provide Installation Information does not include a
330
+ requirement to continue to provide support service, warranty, or updates
331
+ for a work that has been modified or installed by the recipient, or for
332
+ the User Product in which it has been modified or installed. Access to a
333
+ network may be denied when the modification itself materially and
334
+ adversely affects the operation of the network or violates the rules and
335
+ protocols for communication across the network.
336
+
337
+ Corresponding Source conveyed, and Installation Information provided,
338
+ in accord with this section must be in a format that is publicly
339
+ documented (and with an implementation available to the public in
340
+ source code form), and must require no special password or key for
341
+ unpacking, reading or copying.
342
+
343
+ 7. Additional Terms.
344
+
345
+ "Additional permissions" are terms that supplement the terms of this
346
+ License by making exceptions from one or more of its conditions.
347
+ Additional permissions that are applicable to the entire Program shall
348
+ be treated as though they were included in this License, to the extent
349
+ that they are valid under applicable law. If additional permissions
350
+ apply only to part of the Program, that part may be used separately
351
+ under those permissions, but the entire Program remains governed by
352
+ this License without regard to the additional permissions.
353
+
354
+ When you convey a copy of a covered work, you may at your option
355
+ remove any additional permissions from that copy, or from any part of
356
+ it. (Additional permissions may be written to require their own
357
+ removal in certain cases when you modify the work.) You may place
358
+ additional permissions on material, added by you to a covered work,
359
+ for which you have or can give appropriate copyright permission.
360
+
361
+ Notwithstanding any other provision of this License, for material you
362
+ add to a covered work, you may (if authorized by the copyright holders of
363
+ that material) supplement the terms of this License with terms:
364
+
365
+ a) Disclaiming warranty or limiting liability differently from the
366
+ terms of sections 15 and 16 of this License; or
367
+
368
+ b) Requiring preservation of specified reasonable legal notices or
369
+ author attributions in that material or in the Appropriate Legal
370
+ Notices displayed by works containing it; or
371
+
372
+ c) Prohibiting misrepresentation of the origin of that material, or
373
+ requiring that modified versions of such material be marked in
374
+ reasonable ways as different from the original version; or
375
+
376
+ d) Limiting the use for publicity purposes of names of licensors or
377
+ authors of the material; or
378
+
379
+ e) Declining to grant rights under trademark law for use of some
380
+ trade names, trademarks, or service marks; or
381
+
382
+ f) Requiring indemnification of licensors and authors of that
383
+ material by anyone who conveys the material (or modified versions of
384
+ it) with contractual assumptions of liability to the recipient, for
385
+ any liability that these contractual assumptions directly impose on
386
+ those licensors and authors.
387
+
388
+ All other non-permissive additional terms are considered "further
389
+ restrictions" within the meaning of section 10. If the Program as you
390
+ received it, or any part of it, contains a notice stating that it is
391
+ governed by this License along with a term that is a further
392
+ restriction, you may remove that term. If a license document contains
393
+ a further restriction but permits relicensing or conveying under this
394
+ License, you may add to a covered work material governed by the terms
395
+ of that license document, provided that the further restriction does
396
+ not survive such relicensing or conveying.
397
+
398
+ If you add terms to a covered work in accord with this section, you
399
+ must place, in the relevant source files, a statement of the
400
+ additional terms that apply to those files, or a notice indicating
401
+ where to find the applicable terms.
402
+
403
+ Additional terms, permissive or non-permissive, may be stated in the
404
+ form of a separately written license, or stated as exceptions;
405
+ the above requirements apply either way.
406
+
407
+ 8. Termination.
408
+
409
+ You may not propagate or modify a covered work except as expressly
410
+ provided under this License. Any attempt otherwise to propagate or
411
+ modify it is void, and will automatically terminate your rights under
412
+ this License (including any patent licenses granted under the third
413
+ paragraph of section 11).
414
+
415
+ However, if you cease all violation of this License, then your
416
+ license from a particular copyright holder is reinstated (a)
417
+ provisionally, unless and until the copyright holder explicitly and
418
+ finally terminates your license, and (b) permanently, if the copyright
419
+ holder fails to notify you of the violation by some reasonable means
420
+ prior to 60 days after the cessation.
421
+
422
+ Moreover, your license from a particular copyright holder is
423
+ reinstated permanently if the copyright holder notifies you of the
424
+ violation by some reasonable means, this is the first time you have
425
+ received notice of violation of this License (for any work) from that
426
+ copyright holder, and you cure the violation prior to 30 days after
427
+ your receipt of the notice.
428
+
429
+ Termination of your rights under this section does not terminate the
430
+ licenses of parties who have received copies or rights from you under
431
+ this License. If your rights have been terminated and not permanently
432
+ reinstated, you do not qualify to receive new licenses for the same
433
+ material under section 10.
434
+
435
+ 9. Acceptance Not Required for Having Copies.
436
+
437
+ You are not required to accept this License in order to receive or
438
+ run a copy of the Program. Ancillary propagation of a covered work
439
+ occurring solely as a consequence of using peer-to-peer transmission
440
+ to receive a copy likewise does not require acceptance. However,
441
+ nothing other than this License grants you permission to propagate or
442
+ modify any covered work. These actions infringe copyright if you do
443
+ not accept this License. Therefore, by modifying or propagating a
444
+ covered work, you indicate your acceptance of this License to do so.
445
+
446
+ 10. Automatic Licensing of Downstream Recipients.
447
+
448
+ Each time you convey a covered work, the recipient automatically
449
+ receives a license from the original licensors, to run, modify and
450
+ propagate that work, subject to this License. You are not responsible
451
+ for enforcing compliance by third parties with this License.
452
+
453
+ An "entity transaction" is a transaction transferring control of an
454
+ organization, or substantially all assets of one, or subdividing an
455
+ organization, or merging organizations. If propagation of a covered
456
+ work results from an entity transaction, each party to that
457
+ transaction who receives a copy of the work also receives whatever
458
+ licenses to the work the party's predecessor in interest had or could
459
+ give under the previous paragraph, plus a right to possession of the
460
+ Corresponding Source of the work from the predecessor in interest, if
461
+ the predecessor has it or can get it with reasonable efforts.
462
+
463
+ You may not impose any further restrictions on the exercise of the
464
+ rights granted or affirmed under this License. For example, you may
465
+ not impose a license fee, royalty, or other charge for exercise of
466
+ rights granted under this License, and you may not initiate litigation
467
+ (including a cross-claim or counterclaim in a lawsuit) alleging that
468
+ any patent claim is infringed by making, using, selling, offering for
469
+ sale, or importing the Program or any portion of it.
470
+
471
+ 11. Patents.
472
+
473
+ A "contributor" is a copyright holder who authorizes use under this
474
+ License of the Program or a work on which the Program is based. The
475
+ work thus licensed is called the contributor's "contributor version".
476
+
477
+ A contributor's "essential patent claims" are all patent claims
478
+ owned or controlled by the contributor, whether already acquired or
479
+ hereafter acquired, that would be infringed by some manner, permitted
480
+ by this License, of making, using, or selling its contributor version,
481
+ but do not include claims that would be infringed only as a
482
+ consequence of further modification of the contributor version. For
483
+ purposes of this definition, "control" includes the right to grant
484
+ patent sublicenses in a manner consistent with the requirements of
485
+ this License.
486
+
487
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
488
+ patent license under the contributor's essential patent claims, to
489
+ make, use, sell, offer for sale, import and otherwise run, modify and
490
+ propagate the contents of its contributor version.
491
+
492
+ In the following three paragraphs, a "patent license" is any express
493
+ agreement or commitment, however denominated, not to enforce a patent
494
+ (such as an express permission to practice a patent or covenant not to
495
+ sue for patent infringement). To "grant" such a patent license to a
496
+ party means to make such an agreement or commitment not to enforce a
497
+ patent against the party.
498
+
499
+ If you convey a covered work, knowingly relying on a patent license,
500
+ and the Corresponding Source of the work is not available for anyone
501
+ to copy, free of charge and under the terms of this License, through a
502
+ publicly available network server or other readily accessible means,
503
+ then you must either (1) cause the Corresponding Source to be so
504
+ available, or (2) arrange to deprive yourself of the benefit of the
505
+ patent license for this particular work, or (3) arrange, in a manner
506
+ consistent with the requirements of this License, to extend the patent
507
+ license to downstream recipients. "Knowingly relying" means you have
508
+ actual knowledge that, but for the patent license, your conveying the
509
+ covered work in a country, or your recipient's use of the covered work
510
+ in a country, would infringe one or more identifiable patents in that
511
+ country that you have reason to believe are valid.
512
+
513
+ If, pursuant to or in connection with a single transaction or
514
+ arrangement, you convey, or propagate by procuring conveyance of, a
515
+ covered work, and grant a patent license to some of the parties
516
+ receiving the covered work authorizing them to use, propagate, modify
517
+ or convey a specific copy of the covered work, then the patent license
518
+ you grant is automatically extended to all recipients of the covered
519
+ work and works based on it.
520
+
521
+ A patent license is "discriminatory" if it does not include within
522
+ the scope of its coverage, prohibits the exercise of, or is
523
+ conditioned on the non-exercise of one or more of the rights that are
524
+ specifically granted under this License. You may not convey a covered
525
+ work if you are a party to an arrangement with a third party that is
526
+ in the business of distributing software, under which you make payment
527
+ to the third party based on the extent of your activity of conveying
528
+ the work, and under which the third party grants, to any of the
529
+ parties who would receive the covered work from you, a discriminatory
530
+ patent license (a) in connection with copies of the covered work
531
+ conveyed by you (or copies made from those copies), or (b) primarily
532
+ for and in connection with specific products or compilations that
533
+ contain the covered work, unless you entered into that arrangement,
534
+ or that patent license was granted, prior to 28 March 2007.
535
+
536
+ Nothing in this License shall be construed as excluding or limiting
537
+ any implied license or other defenses to infringement that may
538
+ otherwise be available to you under applicable patent law.
539
+
540
+ 12. No Surrender of Others' Freedom.
541
+
542
+ If conditions are imposed on you (whether by court order, agreement or
543
+ otherwise) that contradict the conditions of this License, they do not
544
+ excuse you from the conditions of this License. If you cannot convey a
545
+ covered work so as to satisfy simultaneously your obligations under this
546
+ License and any other pertinent obligations, then as a consequence you may
547
+ not convey it at all. For example, if you agree to terms that obligate you
548
+ to collect a royalty for further conveying from those to whom you convey
549
+ the Program, the only way you could satisfy both those terms and this
550
+ License would be to refrain entirely from conveying the Program.
551
+
552
+ 13. Use with the GNU Affero General Public License.
553
+
554
+ Notwithstanding any other provision of this License, you have
555
+ permission to link or combine any covered work with a work licensed
556
+ under version 3 of the GNU Affero General Public License into a single
557
+ combined work, and to convey the resulting work. The terms of this
558
+ License will continue to apply to the part which is the covered work,
559
+ but the special requirements of the GNU Affero General Public License,
560
+ section 13, concerning interaction through a network will apply to the
561
+ combination as such.
562
+
563
+ 14. Revised Versions of this License.
564
+
565
+ The Free Software Foundation may publish revised and/or new versions of
566
+ the GNU General Public License from time to time. Such new versions will
567
+ be similar in spirit to the present version, but may differ in detail to
568
+ address new problems or concerns.
569
+
570
+ Each version is given a distinguishing version number. If the
571
+ Program specifies that a certain numbered version of the GNU General
572
+ Public License "or any later version" applies to it, you have the
573
+ option of following the terms and conditions either of that numbered
574
+ version or of any later version published by the Free Software
575
+ Foundation. If the Program does not specify a version number of the
576
+ GNU General Public License, you may choose any version ever published
577
+ by the Free Software Foundation.
578
+
579
+ If the Program specifies that a proxy can decide which future
580
+ versions of the GNU General Public License can be used, that proxy's
581
+ public statement of acceptance of a version permanently authorizes you
582
+ to choose that version for the Program.
583
+
584
+ Later license versions may give you additional or different
585
+ permissions. However, no additional obligations are imposed on any
586
+ author or copyright holder as a result of your choosing to follow a
587
+ later version.
588
+
589
+ 15. Disclaimer of Warranty.
590
+
591
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592
+ APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593
+ HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594
+ OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596
+ PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597
+ IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598
+ ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599
+
600
+ 16. Limitation of Liability.
601
+
602
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603
+ WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604
+ THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605
+ GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606
+ USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607
+ DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608
+ PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609
+ EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610
+ SUCH DAMAGES.
611
+
612
+ 17. Interpretation of Sections 15 and 16.
613
+
614
+ If the disclaimer of warranty and limitation of liability provided
615
+ above cannot be given local legal effect according to their terms,
616
+ reviewing courts shall apply local law that most closely approximates
617
+ an absolute waiver of all civil liability in connection with the
618
+ Program, unless a warranty or assumption of liability accompanies a
619
+ copy of the Program in return for a fee.
620
+
621
+ END OF TERMS AND CONDITIONS
622
+
623
+ How to Apply These Terms to Your New Programs
624
+
625
+ If you develop a new program, and you want it to be of the greatest
626
+ possible use to the public, the best way to achieve this is to make it
627
+ free software which everyone can redistribute and change under these terms.
628
+
629
+ To do so, attach the following notices to the program. It is safest
630
+ to attach them to the start of each source file to most effectively
631
+ state the exclusion of warranty; and each file should have at least
632
+ the "copyright" line and a pointer to where the full notice is found.
633
+
634
+ <one line to give the program's name and a brief idea of what it does.>
635
+ Copyright (C) <year> <name of author>
636
+
637
+ This program is free software: you can redistribute it and/or modify
638
+ it under the terms of the GNU General Public License as published by
639
+ the Free Software Foundation, either version 3 of the License, or
640
+ (at your option) any later version.
641
+
642
+ This program is distributed in the hope that it will be useful,
643
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
644
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645
+ GNU General Public License for more details.
646
+
647
+ You should have received a copy of the GNU General Public License
648
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
649
+
650
+ Also add information on how to contact you by electronic and paper mail.
651
+
652
+ If the program does terminal interaction, make it output a short
653
+ notice like this when it starts in an interactive mode:
654
+
655
+ <program> Copyright (C) <year> <name of author>
656
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657
+ This is free software, and you are welcome to redistribute it
658
+ under certain conditions; type `show c' for details.
659
+
660
+ The hypothetical commands `show w' and `show c' should show the appropriate
661
+ parts of the General Public License. Of course, your program's commands
662
+ might be different; for a GUI interface, you would use an "about box".
663
+
664
+ You should also get your employer (if you work as a programmer) or school,
665
+ if any, to sign a "copyright disclaimer" for the program, if necessary.
666
+ For more information on this, and how to apply and follow the GNU GPL, see
667
+ <https://www.gnu.org/licenses/>.
668
+
669
+ The GNU General Public License does not permit incorporating your program
670
+ into proprietary programs. If your program is a subroutine library, you
671
+ may consider it more useful to permit linking proprietary applications with
672
+ the library. If this is what you want to do, use the GNU Lesser General
673
+ Public License instead of this License. But first, please read
674
+ <https://www.gnu.org/licenses/why-not-lgpl.html>.
ComfyUI-Easy-Use/README.ZH_CN.md ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![comfyui-easy-use](https://github.com/user-attachments/assets/9b7a5e44-f5e2-4c27-aed2-d0e6b50c46bb)
2
+
3
+ <div align="center">
4
+ <a href="https://space.bilibili.com/1840885116">视频介绍</a> |
5
+ 文档 (康明孙) |
6
+ <a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">工作流合集</a> |
7
+ <a href="#%EF%B8%8F-donation">捐助</a>
8
+ <br><br>
9
+ <a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-e9e9e9"></a>
10
+ <a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-0b8cf5"></a>
11
+ </div>
12
+
13
+ **ComfyUI-Easy-Use** 是一个化繁为简的节点整合包, 在 [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) 的基础上进行延展,并针对了诸多主流的节点包做了整合与优化,以达到更快更方便使用ComfyUI的目的,在保证自由度的同时还原了本属于Stable Diffusion的极致畅快出图体验。
14
+
15
+ ## 👨🏻‍🎨 特色介绍
16
+
17
+ - 沿用了 [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) 的思路,大大减少了折腾工作流的时间成本。
18
+ - UI界面美化,首次安装的用户,如需使用UI主题,请在 Settings -> Color Palette 中自行切换主题并**刷新页面**即可
19
+ - 增加了预采样参数配置的节点,可与采样节点分离,更方便预览。
20
+ - 支持通配符与Lora的提示词节点,如需使用Lora Block Weight用法,需先保证自定义节点包中安装了 [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
21
+ - 可多选的风格化提示词选择器,默认是Fooocus的样式json,可自定义json放在styles底下,samples文件夹里可放预览图(名称和name一致,图片文件名如有空格需转为下划线'_')
22
+ - 加载器可开启A1111提示词风格模式,可重现与webui生成近乎相同的图像,需先安装 [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes)
23
+ - 可使用`easy latentNoisy`或`easy preSamplingNoiseIn`节点实现对潜空间的噪声注入
24
+ - 简化 SD1.x、SD2.x、SDXL、SVD、Zero123等流程
25
+ - 简化 Stable Cascade [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#1-13-stable-cascade)
26
+ - 简化 Layer Diffuse [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-3-layerdiffusion)
27
+ - 简化 InstantID [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid), 需先保证自定义节点包中安装了 [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
28
+ - 简化 IPAdapter, 需先保证自定义节点包中安装最新版v2的 [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus)
29
+ - 扩展 XYplot 的可用性
30
+ - 整合了Fooocus Inpaint功能
31
+ - 整合了常用的逻辑计算、转换类型、展示所有类型等
32
+ - 支持节点上checkpoint、lora模型子目录分类及预览图 (请在设置中开启上下文菜单嵌套子目录)
33
+ - 支持BriaAI的RMBG-1.4模型的背景去除节点,[技术参考](https://huggingface.co/briaai/RMBG-1.4)
34
+ - 支持 强制清理comfyUI模型显存占用
35
+ - 支持Stable Diffusion 3 多账号API节点
36
+ - 支持IC-Light的应用 [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-5-ic-light) | [代码整合来源](https://github.com/huchenlei/ComfyUI-IC-Light) | [技术参考](https://github.com/lllyasviel/IC-Light)
37
+ - 中文提示词自动识别,使用[opus-mt-zh-en模型](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en)
38
+ - 支持 sd3 模型
39
+ - 支持 kolors 模型
40
+ - 支持 flux 模型
41
+
42
+ ## 👨🏻‍🔧 安装
43
+
44
+ 1. 将存储库克隆到 **custom_nodes** 目录并安装依赖
45
+ ```shell
46
+ #1. git下载
47
+ git clone https://github.com/yolain/ComfyUI-Easy-Use
48
+ #2. 安装依赖
49
+ 双击install.bat安装依赖
50
+ ```
51
+
52
+ ## 👨🏻‍🚀 计划
53
+
54
+ - [x] 更新便于维护的新前端代码
55
+ - [x] 使用sass维护css样式
56
+ - [x] 对原有扩展进行优化
57
+ - [x] 增加新的组件(如节点时间统计等)
58
+ - [ ] 在[ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows)中上传更多的工作流(如kolors,sd3等),并更新english版本的readme
59
+ - [ ] 更详细功能介绍的 gitbook
60
+
61
+ ## 📜 更新日志
62
+
63
+ **v1.2.2**
64
+
65
+ - 增加 v2 版本新前端代码
66
+ - 增加 `easy fluxLoader`
67
+ - 增加 `controlnetApply` 相关节点对sd3和hunyuanDiT的支持
68
+
69
+ **v1.2.1**
70
+
71
+ - 增加 `easy ipadapterApplyFaceIDKolors`
72
+ - `easy ipadapterApply` 和 `easy ipadapterApplyADV` 增加 **PLUS (kolors genernal)** 和 **FACEID PLUS KOLORS** 预置项
73
+ - `easy imageRemBg` 增加 **inspyrenet** 选项
74
+ - 增加 `easy controlnetLoader++`
75
+ - 去除 `easy positive` `easy negative` 等prompt节点的自动将中文翻译功能,自动翻译仅在 `easy a1111Loader` 等不支持中文TE的加载器中生效
76
+ - 增加 `easy kolorsLoader` - 可灵加载器,参考了 [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) 和 [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper) 的代码。
77
+
78
+ **v1.2.0**
79
+
80
+ - 增加 `easy pulIDApply` 和 `easy pulIDApplyADV`
81
+ - 增加 `easy hunyuanDiTLoader` 和 `easy pixArtLoader`
82
+ - 当新菜单的位置在上或者下时增加上 crystools 的显示,推荐开两个就好(如果后续crystools有更新UI适配我可能会删除掉)
83
+ - 增加 **easy sliderControl** - 滑块控制节点,当前可用于控制ipadapterMS的参数 (双击滑块可重置为默认值)
84
+ - 增加 **layer_weights** 属性在 `easy ipadapterApplyADV` 节点
85
+
86
+ **v1.1.9**
87
+
88
+ - 增加 新的调度器 **gitsScheduler**
89
+ - 增加 `easy imageBatchToImageList` 和 `easy imageListToImageBatch` (修复Impact版的一点小问题)
90
+ - 递归模型子目录嵌套
91
+ - 支持 sd3 模型
92
+ - 增加 `easy applyInpaint` - 局部重绘全模式节点 (相比与之前的kSamplerInpating节点逻辑会更合理些)
93
+
94
+ **v1.1.8**
95
+
96
+ - 增加中文提示词自动翻译,使用[opus-mt-zh-en模型](https://huggingface.co/Helsinki-NLP/opus-mt-zh-en), 默认已对wildcard、lora正则处理, 其他需要保留的中文,可使用`@你的提示词@`包裹 (若依赖安装完成后报错, 请重启),测算大约会占0.3GB显存
97
+ - 增加 `easy controlnetStack` - controlnet堆
98
+ - 增加 `easy applyBrushNet` - [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
99
+ - 增加 `easy applyPowerPaint` - [示例参考](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
100
+
101
+ **v1.1.7**
102
+
103
+ - 修复 一些模型(如controlnet模型等)未成功写入缓存,导致修改前置节点束参数(如提示词)需要二次载入模型的问题
104
+ - 增加 `easy prompt` - 主体和光影预置项,后期可能会调整
105
+ - 增加 `easy icLightApply` - 重绘光影, 从[ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)优化
106
+ - 增加 `easy imageSplitGrid` - 图像网格拆分
107
+ - `easy kSamplerInpainting` 的 **additional** 属性增加差异扩散和brushnet等相关选项
108
+ - 增加 brushnet模型加载的支持 - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
109
+ - 增加 `easy applyFooocusInpaint` - Fooocus内补节点 替代原有的 FooocusInpaintLoader
110
+ - 移除 `easy fooocusInpaintLoader` - 容易bug,不再使用
111
+ - 修改 easy kSampler等采样器中并联的model 不再替换输出中pipe里的model
112
+
113
+ **v1.1.6**
114
+
115
+ - 增加步调齐整适配 - 在所有的预采样和全采样器节点中的 调度器(schedulder) 增加了 **alignYourSteps** 选项
116
+ - `easy kSampler` 和 `easy fullkSampler` 的 **image_output** 增加 **Preview&Choose**选项
117
+ - 增加 `easy styleAlignedBatchAlign` - 风格对齐 [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
118
+ - 增加 `easy ckptNames`
119
+ - 增加 `easy controlnetNames`
120
+ - 增加 `easy imagesSplitimage` - 批次图像拆分单张
121
+ - 增加 `easy imageCount` - 图像数量
122
+ - 增加 `easy textSwitch` - 文字切换
123
+
124
+ **v1.1.5**
125
+
126
+ - 重写 `easy cleanGPUUsed` - 可强制清理comfyUI的模型显存占用
127
+ - 增加 `easy humanSegmentation` - 多类分割、人像分割
128
+ - 增加 `easy imageColorMatch`
129
+ - 增加 `easy ipadapterApplyRegional`
130
+ - 增加 `easy ipadapterApplyFromParams`
131
+ - 增加 `easy imageInterrogator` - 图像反推
132
+ - 增加 `easy stableDiffusion3API` - 简易的Stable Diffusion 3 多账号API节点
133
+
134
+ **v1.1.4**
135
+
136
+ - 增加 `easy imageChooser` - 从[cg-image-picker](https://github.com/chrisgoringe/cg-image-picker)简化的图片选择器
137
+ - 增加 `easy preSamplingCustom` - 自定义预采样,可支持cosXL-edit
138
+ - 增加 `easy ipadapterStyleComposition`
139
+ - 增加 在Loaders上右键菜单可查看 checkpoints、lora 信息
140
+ - 修复 `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` 以兼容ComfyUI Revision>=2098 [0542088e] 以上版本
141
+ - 修复 FooocusInpaint修改ModelPatcher计算权重引发的问题,理应在生成model后重置ModelPatcher为默认值
142
+
143
+ **v1.1.3**
144
+
145
+ - `easy ipadapterApply` 增加 **COMPOSITION** 预置项
146
+ - 增加 对[ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) lora模型 的加载支持
147
+ - 增加 `easy promptLine`
148
+ - 增加 `easy promptReplace`
149
+ - 增加 `easy promptConcat`
150
+ - `easy wildcards` 增加 **multiline_mode**属性
151
+ - 增加 当节点需要下载模型时,若huggingface连接超时,会切换至镜像地址下载模型
152
+
153
+ <details>
154
+ <summary><b>v1.1.2</b></summary>
155
+
156
+ - 改写 EasyUse 相关节点的部分插槽推荐节点
157
+ - 增加 **启用上下文菜单自动嵌套子目录** 设置项,默认为启用状态,可分类子目录及checkpoints、loras预览图
158
+ - 增加 `easy sv3dLoader`
159
+ - 增加 `easy dynamiCrafterLoader`
160
+ - 增加 `easy ipadapterApply`
161
+ - 增加 `easy ipadapterApplyADV`
162
+ - 增加 `easy ipadapterApplyEncoder`
163
+ - 增加 `easy ipadapterApplyEmbeds`
164
+ - 增加 `easy preMaskDetailerFix`
165
+ - `easy kSamplerInpainting` 增加 **additional** 属性,可设置成 Differential Diffusion 或 Only InpaintModelConditioning
166
+ - 修复 `easy stylesSelector` 当未选择样式时,原���提示词发生了变化
167
+ - 修复 `easy pipeEdit` 提示词输入lora时报错
168
+ - 修复 layerDiffuse xyplot相关bug
169
+ </details>
170
+
171
+ <details>
172
+ <summary><b>v1.1.1/b></summary>
173
+
174
+ - 修复首次添加含seed的节点且当前模式为control_before_generate时,seed为0的问题
175
+ - `easy preSamplingAdvanced` 增加 **return_with_leftover_noise**
176
+ - 修复 `easy stylesSelector` 当选择自定义样式文件时运行队列报错
177
+ - `easy preSamplingLayerDiffusion` 增加 mask 可选传入参数
178
+ - 将所有 **seed_num** 调整回 **seed**
179
+ - 修补官方BUG: 当control_mode为before 在首次加载页面时未修改节点中widget名称为 control_before_generate
180
+ - 去除强制**control_before_generate**设定
181
+ - 增加 `easy imageRemBg` - 默认为BriaAI的RMBG-1.4模型, 移除背景效果更加,速度更快
182
+ </details>
183
+
184
+ <details>
185
+ <summary><b>v1.1.0</b></summary>
186
+
187
+ - 增加 `easy imageSplitList` - 拆分每 N 张图像
188
+ - 增加 `easy preSamplingDiffusionADDTL` - 可配置前景、背景、blended的additional_prompt等
189
+ - 增加 `easy preSamplingNoiseIn` 可替代需要前置的`easy latentNoisy`节点 实现效果更好的噪声注入
190
+ - `easy pipeEdit` 增加 条件拼接模式选择,可选择替换、合并、联结、平均、设置条件时间
191
+ - 增加 `easy pipeEdit` - 可编辑Pipe的节点(包含可重新输入提示词)
192
+ - 增加 `easy preSamplingLayerDiffusion` 与 `easy kSamplerLayerDiffusion` (连接 `easy kSampler` 也能通)
193
+ - 增加 在 加载器、预采样、采样器、Controlnet等节点上右键可快速替换同类型节点的便捷菜单
194
+ - 增加 `easy instantIDApplyADV` 可连入 positive 与 negative
195
+ - 修复 `easy wildcards` 读取lora未填写完整路径时未自动检索导致加载lora失败的问题
196
+ - 修复 `easy instantIDApply` mask 未传入正确值
197
+ - 修复 在 非a1111提示词风格下 BREAK 不生效的问题
198
+ </details>
199
+
200
+ <details>
201
+ <summary><b>v1.0.9</b></summary>
202
+
203
+ - 修复未安装 ComfyUI-Impack-Pack 和 ComfyUI_InstantID 时报错
204
+ - 修复 `easy pipeIn` - pipe设为可不必选
205
+ - 增加 `easy instantIDApply` - 需要先安装 [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID), 工作流参考[示例](https://github.com/yolain/ComfyUI-Yolain-Workflows?tab=readme-ov-file#2-2-instantid)
206
+ - 修复 `easy detailerFix` 未添加到保存图片格式化扩展名可用节点列表
207
+ - 修复 `easy XYInputs: PromptSR` 在替换负面提示词时报错
208
+ </details>
209
+
210
+ <details>
211
+ <summary><b>v1.0.8</b></summary>
212
+
213
+ - `easy cascadeLoader` stage_c 与 stage_b 支持checkpoint模型 (需要下载[checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints))
214
+ - `easy styleSelector` 搜索框修改为不区分大小写匹配
215
+ - `easy fullLoader` 增加 **positive**、**negative**、**latent** 输出项
216
+ - 修复 SDXLClipModel 在 ComfyUI 修订版本号 2016[c2cb8e88] 及以上的报错(判断了版本号可兼容老版本)
217
+ - 修复 `easy detailerFix` 批次大小大于1时生成出错
218
+ - 修复`easy preSampling`等 latent传入后无法根据批次索引生成的问题
219
+ - 修复 `easy svdLoader` 报错
220
+ - 优化代码,减少了诸多冗余,提升运行速度
221
+ - 去除中文翻译对照文本
222
+
223
+ (翻译对照已由 [AIGODLIKE-COMFYUI-TRANSLATION](https://github.com/AIGODLIKE/AIGODLIKE-ComfyUI-Translation) 统一维护啦!
224
+ 首次下载或者版本较早的朋友请更新 AIGODLIKE-COMFYUI-TRANSLATION 和本节点包至最新版本。)
225
+ </details>
226
+
227
+ <details>
228
+ <summary><b>v1.0.7</b></summary>
229
+
230
+ - 增加 `easy cascadeLoader` - stable cascade 加载器
231
+ - 增加 `easy preSamplingCascade` - stabled cascade stage_c 预采样参数
232
+ - 增加 `easy fullCascadeKSampler` - stable cascade stage_c 完整版采样器
233
+ - 增加 `easy cascadeKSampler` - stable cascade stage-c ksampler simple
234
+ </details>
235
+
236
+ <details>
237
+ <summary><b>v1.0.6</b></summary>
238
+
239
+ - 增加 `easy XYInputs: Checkpoint`
240
+ - 增加 `easy XYInputs: Lora`
241
+ - `easy seed` 增加固定种子值时可手动切换随机种
242
+ - 修复 `easy fullLoader`等加载器切换lora时自动调整节点大小的问题
243
+ - 去除原有ttn的图片保存逻辑并适配ComfyUI默认的图片保存格式化扩展
244
+ </details>
245
+
246
+ <details>
247
+ <summary><b>v1.0.5</b></summary>
248
+
249
+ - 增加 `easy isSDXL`
250
+ - `easy svdLoader` 增加提示词控制, 可配合open_clip模型进行使用
251
+ - `easy wildcards` 增加 **populated_text** 可输出通配填充后文本
252
+ </details>
253
+
254
+ <details>
255
+ <summary><b>v1.0.4</b></summary>
256
+
257
+ - 增加 `easy showLoaderSettingsNames` 可显示与输出加载器部件中的 模型与VAE名称
258
+ - 增加 `easy promptList` - 提示词列表
259
+ - 增加 `easy fooocusInpaintLoader` - Fooocus内补节点(仅支持XL模型的流程)
260
+ - 增加 **Logic** 逻辑类节点 - 包含类型、计算、判断和转换类型等
261
+ - 增加 `easy imageSave` - 带日期转换和宽高格式化的图像保存节点
262
+ - 增加 `easy joinImageBatch` - 合并图像批次
263
+ - `easy showAnything` 增加支持转换其他类型(如:tensor类型的条件、图像等)
264
+ - `easy kSamplerInpainting` 增加 **patch** 传入值,配合Fooocus内补节点使用
265
+ - `easy imageSave` 增加 **only_preivew**
266
+
267
+ - 修复 xyplot在pillow>9.5中报错
268
+ - 修复 `easy wildcards` 在使用PS扩展插件运行时报错
269
+ - 修复 `easy latentCompositeMaskedWithCond`
270
+ - 修复 `easy XYInputs: ControlNet` 报错
271
+ - 修复 `easy loraStack` **toggle** 为 disabled 时报错
272
+
273
+ - 修改首次安装节点包不再自动替换主题,需手动调整并刷新页面
274
+ </details>
275
+
276
+ <details>
277
+ <summary><b>v1.0.3</b></summary>
278
+
279
+ - 增加 `easy stylesSelector` 风格化提示词选择器
280
+ - 增加队列进度条设置项,默认为未启用状态
281
+ - `easy controlnetLoader` 和 `easy controlnetLoaderADV` 增加参数 **scale_soft_weights**
282
+
283
+
284
+ - 修复 `easy XYInputs: Sampler/Scheduler` 报错
285
+ - 修复 右侧菜单 点击按钮时老是跑位的问题
286
+ - 修复 styles 路径在其他环境报错
287
+ - 修复 `easy comfyLoader` 读取错误
288
+ - 修复 xyPlot 在连接 zero123 时报错
289
+ - 修复加载器中提示词为组件时报错
290
+ - 修复 `easy getNode` 和 `easy setNode` 加载时标题未更改
291
+ - 修复所有采样器中存储图片使用子目录前缀不生效的问题
292
+
293
+
294
+ - 调整UI主题
295
+ </details>
296
+
297
+ <details>
298
+ <summary><b>v1.0.2</b></summary>
299
+
300
+ - 增加 **autocomplete** 文件夹,如果您安装了 [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts), 将在启动时合并该文件夹下的所有txt文件并覆盖到pyssss包里的autocomplete.txt文件。
301
+ - 增加 `easy XYPlotAdvanced` 和 `easy XYInputs` 等相关节点
302
+ - 增加 **Alt+1到9** 快捷键,可快速粘贴 Node templates 的节点预设 (对应 1到9 顺序)
303
+
304
+ - 修复 `easy imageInsetCrop` 测量值为百分比时步进为1
305
+ - 修复 开启 `a1111_prompt_style` 时XY图表无法使用的问题
306
+ - 右键菜单中增加了一个 `📜Groups Map(EasyUse)`
307
+
308
+ - 修复在Comfy新版本中UI加载失败
309
+ - 修复 `easy pipeToBasicPipe` 报错
310
+ - 修改 `easy fullLoader` 和 `easy a1111Loader` 中的 **a1111_prompt_style** 默认值为 False
311
+ - `easy XYInputs ModelMergeBlocks` 支持csv文件导入数值
312
+
313
+ - 替换了XY图生成时的字体文件
314
+
315
+ - 移除 `easy imageRemBg`
316
+ - 移除包中的介绍图和工作流文件,减少包体积
317
+
318
+ </details>
319
+
320
+ <details>
321
+ <summary><b>v1.0.1</b></summary>
322
+
323
+ - 新增 `easy seed` - 简易随机种
324
+ - `easy preDetailerFix` 新增了 `optional_image` 传入图像可选,如未传默认取值为pipe里的图像
325
+ - 新增 `easy kSamplerInpainting` 用于内补潜空间的采样器
326
+ - 新增 `easy pipeToBasicPipe` 用于转换到Impact的某些节点上
327
+
328
+ - 修复 `easy comfyLoader` 报错
329
+ - 修复所有包含输出图片尺寸的节点取值方式无法批处理的问题
330
+ - 修复 `width` 和 `height` 无法在 `easy svdLoader` 自定义的报错问题
331
+ - 修复所有采样器预览图片的地址链接 (解决在 MACOS 系统中图片无法在采样器中预览的问题)
332
+ - 修复 `vae_name` 在 `easy fullLoader` 和 `easy a1111Loader` 和 `easy comfyLoader` 中选择但未替换原始vae问题
333
+ - 修复 `easy fullkSampler` 除pipe外其他输出值的报错
334
+ - 修复 `easy hiresFix` 输入连接pipe和image、vae同时存在时报错
335
+ - 修复 `easy fullLoader` 中 `model_override` 连接后未执行
336
+ - 修复 因新增`easy seed` 导致action错误
337
+ - 修复 `easy xyplot` 的字体文件路径读取错误
338
+ - 修复 convert 到 `easy seed` 随机种无法固定的问题
339
+ - 修复 `easy pipeIn` 值传入的报错问题
340
+ - 修复 `easy zero123Loader` 和 `easy svdLoader` 读取模型时将模型加入到缓存中
341
+ - 修复 `easy kSampler` `easy kSamplerTiled` `easy detailerFix` 的 `image_output` 默认值为 Preview
342
+ - `easy fullLoader` 和 `easy a1111Loader` 新增了 `a1111_prompt_style` 参数可以重现和webui生成相同的图像,当前您需要安装 [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) 才能使用此功能
343
+ </details>
344
+
345
+ <details>
346
+ <summary><b>v1.0.0</b></summary>
347
+
348
+ - 新增`easy positive` - 简易正面提示词文本
349
+ - 新增`easy negative` - 简易负面提示词文本
350
+ - 新增`easy wildcards` - 支持通配符和Lora选择的提示词文本
351
+ - 新增`easy portraitMaster` - 肖像大师v2.2
352
+ - 新增`easy loraStack` - Lora堆
353
+ - 新增`easy fullLoader` - 完整版的加载器
354
+ - 新增`easy zero123Loader` - 简易zero123加载器
355
+ - 新增`easy svdLoader` - 简易svd加载器
356
+ - 新增`easy fullkSampler` - 完整版的采样器(无分离)
357
+ - 新增`easy hiresFix` - 支持Pipe的高清修复
358
+ - 新增`easy predetailerFix` `easy DetailerFix` - 支持Pipe的细节修复
359
+ - 新增`easy ultralyticsDetectorPipe` `easy samLoaderPipe` - 检测加载器(细节修复的输入项)
360
+ - 新增`easy pipein` `easy pipeout` - Pipe的输入与输出
361
+ - 新增`easy xyPlot` - 简易的xyplot (后续会更新更多可控参数)
362
+ - 新增`easy imageRemoveBG` - 图像去除背景
363
+ - 新增`easy imagePixelPerfect` - 图像完美像素
364
+ - 新增`easy poseEditor` - 姿势编辑器
365
+ - 新增UI主题(黑曜石)- 默认自动加载UI, 也可在设置中自行更替
366
+
367
+ - 修复 `easy globalSeed` 不生效问题
368
+ - 修复所有的`seed_num` 因 [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) 实时更新图表导致值错乱的问题
369
+ - 修复`easy imageSize` `easy imageSizeBySide` `easy imageSizeByLongerSide` 可作为终节点
370
+ - 修复 `seed_num` (随机种子值) 在历史记录中读取无法一致的Bug
371
+ </details>
372
+
373
+
374
+ <details>
375
+ <summary><b>v0.5</b></summary>
376
+
377
+ - 新增 `easy controlnetLoaderADV` 节点
378
+ - 新增 `easy imageSizeBySide` 节点,可选输出为长边或短边
379
+ - 新增 `easy LLLiteLoader` 节点,如果您预先安装过 kohya-ss/ControlNet-LLLite-ComfyUI 包,请将 models 里的模型文件移动至 ComfyUI\models\controlnet\ (即comfy默认的controlnet路径里,请勿修改模型的文件名,不然会读取不到)。
380
+ - 新增 `easy imageSize` 和 `easy imageSizeByLongerSize` 输出的尺寸显示。
381
+ - 新增 `easy showSpentTime` 节点用于展示图片推理花费时间与VAE解码花费时间。
382
+ - `easy controlnetLoaderADV` 和 `easy controlnetLoader` 新增 `control_net` 可选传入参数
383
+ - `easy preSampling` 和 `easy preSamplingAdvanced` 新增 `image_to_latent` 可选传入参数
384
+ - `easy a1111Loader` 和 `easy comfyLoader` 新增 `batch_size` 传入参数
385
+
386
+ - 修改 `easy controlnetLoader` 到 loader 分类底下。
387
+ </details>
388
+
389
+ ## 整合参考到的相关节点包
390
+
391
+ 声明: 非常尊重这些原作者们的付出,开源不易,我仅仅只是做了一些整合与优化。
392
+
393
+ | 节点名 (搜索名) | 相关的库 | 库相关的节点 |
394
+ |:-------------------------------|:----------------------------------------------------------------------------|:------------------------|
395
+ | easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
396
+ | easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
397
+ | easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
398
+ | easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
399
+ | easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
400
+ | easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
401
+ | easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
402
+ | dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
403
+ | easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
404
+ | easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
405
+ | easy if | [ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) | IfExecute |
406
+ | easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply等 |
407
+ | easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
408
+ | easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
409
+ | easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
410
+ | easy icLightApply | [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light) | ICLightApply等 |
411
+ | easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
412
+
413
+ ## Credits
414
+
415
+ [ComfyUI](https://github.com/comfyanonymous/ComfyUI) - 功能强大且模块化的Stable Diffusion GUI
416
+
417
+ [ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI管理器
418
+
419
+ [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - 管道节点(节点束)让用户减少了不必要的连接
420
+
421
+ [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - diffus3的获取与设置点让用户可以分离工作流构成
422
+
423
+ [ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - 常规整合包1
424
+
425
+ [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - 常规整合包2
426
+
427
+ [ComfyUI-Logic](https://github.com/theUpsider/ComfyUI-Logic) - ComfyUI逻辑运算
428
+
429
+ [ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - 让模型生成不受训练分辨率限制
430
+
431
+ [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - 风格迁移
432
+
433
+ [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - 人脸迁移
434
+
435
+ [ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - 人脸迁移
436
+
437
+ [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss 小蛇🐍脚本
438
+
439
+ [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - 图片选择器
440
+
441
+ [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet) - BrushNet 内补节点
442
+
443
+ [ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT架构相关节点(Pixart、混元DiT等)
444
+
445
+ ## ☕️ Donation
446
+
447
+ **Comfyui-Easy-Use** 是一个 GPL 许可的开源项目。为了项目取得更好、可持续的发展,我希望能够获得更多的支持。 如果我的自定义节点为您的一天增添了价值,请考虑喝杯咖啡来进一步补充能量! 💖感谢您的支持,每一杯咖啡都是我创作的动力!
448
+
449
+ - [BiliBili充电](https://space.bilibili.com/1840885116)
450
+ - [爱发电](https://afdian.com/a/yolain)
451
+ - [Wechat/Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
452
+
453
+ 感谢您的捐助,我将用这些费用来租用 GPU 或购买其他 GPT 服务,以便更好地调试和完善 ComfyUI-Easy-Use 功能
454
+
455
+ ## 🌟Stargazers
456
+
457
+ My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
458
+
459
+ [![Stargazers repo roster for @yolain/ComfyUI-Easy-Use](https://reporoster.com/stars/yolain/ComfyUI-Easy-Use)](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
ComfyUI-Easy-Use/README.en.md ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="right">
2
+ <a href="./README.md">中文</a> | <strong>English</strong>
3
+ </p>
4
+
5
+ <div align="center">
6
+
7
+ # ComfyUI Easy Use
8
+ </div>
9
+
10
+ **ComfyUI-Easy-Use** is a simplified node integration package, which is extended on the basis of [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), and has been integrated and optimized for many mainstream node packages to achieve the purpose of faster and more convenient use of ComfyUI. While ensuring the degree of freedom, it restores the ultimate smooth image production experience that belongs to Stable Diffusion.
11
+
12
+ [![ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Easy-Use/assets/73304135/9a3f54bc-a677-4bf1-a196-8845dd57c942)](https://github.com/yolain/ComfyUI-Yolain-Workflows)
13
+
14
+ ## 👨🏻‍🎨 Introduce
15
+
16
+ - Inspire by [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), which greatly reduces the time cost of tossing workflows。
17
+ - UI interface beautification, the first time you install the user, if you need to use the UI theme, please switch the theme in Settings -> Color Palette and refresh page.
18
+ - Added a node for pre-sampling parameter configuration, which can be separated from the sampling node for easier previewing
19
+ - Wildcards and lora's are supported, for Lora Block Weight usage, ensure that the custom node package has the [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
20
+ - Multi-selectable styled cue word selector, default is Fooocus style json, custom json can be placed under styles, samples folder can be placed in the preview image (name and name consistent, image file name such as spaces need to be converted to underscores '_')
21
+ - The loader enables the A1111 prompt mode, which reproduces nearly identical images to those generated by webui, and needs to be installed [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) first.
22
+ - Noise injection into the latent space can be achieved using the `easy latentNoisy` or `easy preSamplingNoiseIn` node
23
+ - Simplified processes for SD1.x, SD2.x, SDXL, SVD, Zero123, etc. [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableDiffusion)
24
+ - Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableCascade)
25
+ - Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#LayerDiffusion),The first time you use it you may need to run `pip install -r requirements.txt` to install the required dependencies.
26
+ - Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#InstantID), You need to make sure that the custom node package has the [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
27
+ - Extending the usability of XYplot
28
+ - Fooocus Inpaint integration
29
+ - Integration of common logical calculations, conversion of types, display of all types, etc.
30
+ - Background removal nodes for the RMBG-1.4 model supporting BriaAI, [BriaAI Guide](https://huggingface.co/briaai/RMBG-1.4)
31
+ - Forcibly cleared the memory usage of the comfy UI model are supported
32
+ - Stable Diffusion 3 multi-account API nodes are supported
33
+ - Support Stable Diffusion 3 model
34
+ - Support Kolors model
35
+
36
+ ## 👨🏻‍🔧 Installation
37
+ Clone the repo into the **custom_nodes** directory and install the requirements:
38
+ ```shell
39
+ #1. Clone the repo
40
+ git clone https://github.com/yolain/ComfyUI-Easy-Use
41
+ #2. Install the requirements
42
+ Double-click install.bat to install the required dependencies
43
+ ```
44
+
45
+ ## ☕️ Plan
46
+
47
+ - [ ] Updated new front-end code for easier maintenance
48
+ - [x] Maintain css styles using sass
49
+ - [ ] Optimize existing extensions
50
+ - [ ] Add new components
51
+ - [ ] Add light theme
52
+ - [ ] Upload new workflows to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and translate readme to english version.
53
+ - [ ] Write gitbook with more detailed function introdution
54
+
55
+ ## 📜 Changelog
56
+
57
+ **v1.2.1**
58
+
59
+ - Added **inspyrenet** to `easy imageRemBg`
60
+ - Added `easy controlnetLoader++`
61
+ - Added **PLUS (kolors genernal)** preset to `easy ipadapterApply` and `easy ipadapterApplyADV` (Supported kolors ipadapter)
62
+ - Added `easy kolorsLoader` - Code based on [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ)'s and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)'s repo, thanks for their contribution.
63
+
64
+ **v1.2.0**
65
+
66
+ - Added `easy pulIDApply` and `easy pulIDApplyADV`
67
+ - Added `easy huanyuanDiTLoader` and `easy pixArtLoader`
68
+ - Added **easy sliderControl** - Slider control node, which can currently be used to control the parameters of ipadapterMS (double-click the slider to reset to default)
69
+ - Added **layer_weights** in `easy ipadapterApplyADV`
70
+
71
+ **v1.1.9**
72
+
73
+ - Added **gitsScheduler**
74
+ - Added `easy imageBatchToImageList` and `easy imageListToImageBatch`
75
+ - Recursive subcategories nested for models
76
+ - Support for Stable Diffusion 3 model
77
+ - Added `easy applyInpaint` - All inpainting mode in this node
78
+
79
+ **v1.1.8**
80
+
81
+ - Added `easy controlnetStack`
82
+ - Added `easy applyBrushNet` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
83
+ - Added `easy applyPowerPaint` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
84
+
85
+ **v1.1.7**
86
+
87
+ - Added `easy prompt` - Subject and light presets, maybe adjusted later
88
+ - Added `easy icLightApply` - Light and shadow migration, Code based on [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
89
+ - Added `easy imageSplitGrid`
90
+ - `easy kSamplerInpainting` added options such as different diffusion and brushnet in **additional** widget
91
+ - Support for brushnet model loading - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
92
+ - Added `easy applyFooocusInpaint` - Replace FooocusInpaintLoader
93
+ - Removed `easy fooocusInpaintLoader`
94
+
95
+ **v1.1.6**
96
+
97
+ - Added **alignYourSteps** to **schedulder** widget in all `easy preSampling` and `easy fullkSampler`
98
+ - Added **Preview&Choose** to **image_output** widget in `easy kSampler` & `easy fullkSampler`
99
+ - Added `easy styleAlignedBatchAlign` - Credit of [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
100
+ - Added `easy ckptNames`
101
+ - Added `easy controlnetNames`
102
+ - Added `easy imagesSplitimage` - Batch images split into single images
103
+ - Added `easy imageCount` - Get Image Count
104
+ - Added `easy textSwitch` - Text Switch
105
+
106
+ **v1.1.5**
107
+
108
+ - Rewrite `easy cleanGPUUsed` - the memory usage of the comfyUI can to be cleared
109
+ - Added `easy humanSegmentation` - Human Part Segmentation
110
+ - Added `easy imageColorMatch`
111
+ - Added `easy ipadapterApplyRegional`
112
+ - Added `easy ipadapterApplyFromParams`
113
+ - Added `easy imageInterrogator` - Image To Prompt
114
+ - Added `easy stableDiffusion3API` - Easy Stable Diffusion 3 Multiple accounts API Node
115
+
116
+ **v1.1.4**
117
+
118
+ - Added `easy preSamplingCustom` - Custom-PreSampling, can be supported cosXL-edit
119
+ - Added `easy ipadapterStyleComposition`
120
+ - Added the right-click menu to view checkpoints and lora information in all Loaders
121
+ - Fixed `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` compatible with ComfyUI Revision>=2098 [0542088e] or later
122
+
123
+
124
+ **v1.1.3**
125
+
126
+ - `easy ipadapterApply` Added **COMPOSITION** preset
127
+ - Supported [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) when load ResAdapter lora
128
+ - Added `easy promptLine`
129
+ - Added `easy promptReplace`
130
+ - Added `easy promptConcat`
131
+ - `easy wildcards` Added **multiline_mode**
132
+
133
+ **v1.1.2**
134
+
135
+ - Optimized some of the recommended nodes for slots related to EasyUse
136
+ - Added **Enable ContextMenu Auto Nest Subdirectories** The setting item is enabled by default, and it can be classified into subdirectories, checkpoints and loras previews
137
+ - Added `easy sv3dLoader`
138
+ - Added `easy dynamiCrafterLoader`
139
+ - Added `easy ipadapterApply`
140
+ - Added `easy ipadapterApplyADV`
141
+ - Added `easy ipadapterApplyEncoder`
142
+ - Added `easy ipadapterApplyEmbeds`
143
+ - Added `easy preMaskDetailerFix`
144
+ - Fixed `easy stylesSelector` is change the prompt when not select the style
145
+ - Fixed `easy pipeEdit` error when add lora to prompt
146
+ - Fixed layerDiffuse xyplot bug
147
+ - `easy kSamplerInpainting` add *additional* widget,you can choose 'Differential Diffusion' or 'Only InpaintModelConditioning'
148
+
149
+ **v1.1.1**
150
+
151
+ - The issue that the seed is 0 when a node with a seed control is added and **control before generate** is fixed for the first time run queue prompt.
152
+ - `easy preSamplingAdvanced` Added **return_with_leftover_noise**
153
+ - Fixed `easy stylesSelector` error when choose the custom file
154
+ - `easy preSamplingLayerDiffusion` Added optional input parameter for mask
155
+ - Renamed all nodes widget name named seed_num to seed
156
+ - Remove forced **control_before_generate** settings。 If you want to use control_before_generate, change widget_value_control_mode to before in system settings
157
+ - Added `easy imageRemBg` - The default is BriaAI's RMBG-1.4 model, which removes the background effect more and faster
158
+
159
+ <details>
160
+ <summary><b>v1.1.0</b></summary>
161
+
162
+ - Added `easy imageSplitList` - to split every N images
163
+ - Added `easy preSamplingDiffusionADDTL` - It can modify foreground、background or blended additional prompt
164
+ - Added `easy preSamplingNoiseIn` It can replace the `easy latentNoisy` node that needs to be fronted to achieve better noise injection
165
+ - `easy pipeEdit` Added conditioning splicing mode selection, you can choose to replace, concat, combine, average, and set timestep range
166
+ - Added `easy pipeEdit` - nodes that can edit pipes (including re-enterable prompts)
167
+ - Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion`
168
+ - Added a convenient menu to right-click on nodes such as Loader, Presampler, Sampler, Controlnet, etc. to quickly replace nodes of the same type
169
+ - Added `easy instantIDApplyADV` can link positive and negative
170
+ - Fixed layerDiffusion error when batch size greater than 1
171
+ - Fixed `easy wildcards` When LoRa is not filled in completely, LoRa is not automatically retrieved, resulting in failure to load LoRa
172
+ - Fixed the issue that 'BREAK' non-initiation when didn't use a1111 prompt style
173
+ - Fixed `easy instantIDApply` mask not input right
174
+ </details>
175
+
176
+ <details>
177
+ <summary><b>v1.0.9</b></summary>
178
+
179
+ - Fixed the error when ComfyUI-Impack-Pack and ComfyUI_InstantID were not installed
180
+ - Fixed `easy pipeIn`
181
+ - Added `easy instantIDApply` - you need installed [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) fisrt, Workflow[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#InstantID)
182
+ - Fixed `easy detailerFix` not added to the list of nodes available for saving images formatting extensions
183
+ - Fixed `easy XYInputs: PromptSR` errors are reported when replacing negative prompts
184
+ </details>
185
+
186
+ <details>
187
+ <summary><b>v1.0.8</b></summary>
188
+
189
+ - `easy cascadeLoader` stage_c and stage_b support the checkpoint model (Download [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints) models)
190
+ - `easy styleSelector` The search box is modified to be case-insensitive
191
+ - `easy fullLoader` **positive**、**negative**、**latent** added to the output items
192
+ - Fixed the issue that 'easy preSampling' and other similar node, latent could not be generated based on the batch index after passing in
193
+ - Fixed `easy svdLoader` error when the positive or negative is empty
194
+ - Fixed the error of SDXLClipModel in ComfyUI revision 2016[c2cb8e88] and above (the revision number was judged to be compatible with the old revision)
195
+ - Fixed `easy detailerFix` generation error when batch size is greater than 1
196
+ - Optimize the code, reduce a lot of redundant code and improve the running speed
197
+ </details>
198
+
199
+ <details>
200
+ <summary><b>v1.0.7</b></summary>
201
+
202
+ - Added `easy cascadeLoader` - stable cascade Loader
203
+ - Added `easy preSamplingCascade` - stable cascade preSampling Settings
204
+ - Added `easy fullCascadeKSampler` - stable cascade stage-c ksampler full
205
+ - Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
206
+ -
207
+ - Optimize the image to image[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#image-to-image)
208
+ </details>
209
+
210
+ <details>
211
+ <summary><b>v1.0.6</b></summary>
212
+
213
+ - Added `easy XYInputs: Checkpoint`
214
+ - Added `easy XYInputs: Lora`
215
+ - `easy seed` can manually switch the random seed when increasing the fixed seed value
216
+ - Fixed `easy fullLoader` and all loaders to automatically adjust the node size when switching LoRa
217
+ - Removed the original ttn image saving logic and adapted to the default image saving format extension of ComfyUI
218
+ </details>
219
+
220
+ <details>
221
+ <summary><b>v1.0.5</b></summary>
222
+
223
+ - Added `easy isSDXL`
224
+ - Added prompt word control on `easy svdLoader`, which can be used with open_clip model
225
+ - Added **populated_text** on `easy wildcards`, wildcard populated text can be output
226
+ </details>
227
+
228
+ <details>
229
+ <summary><b>v1.0.4</b></summary>
230
+
231
+ - `easy showAnything` added support for converting other types (e.g., tensor conditions, images, etc.)
232
+ - Added `easy showLoaderSettingsNames` can display the model and VAE name in the output loader assembly
233
+ - Added `easy promptList`
234
+ - Added `easy fooocusInpaintLoader` (only the process of SDXLModel is supported)
235
+ - Added **Logic** nodes
236
+ - Added `easy imageSave` - Image saving node with date conversion and aspect and height formatting
237
+ - Added `easy joinImageBatch`
238
+ - `easy kSamplerInpainting` Added the **patch** input value to be used with the FooocusInpaintLoader node
239
+
240
+ - Fixed xyplot error when with Pillow>9.5
241
+ - Fixed `easy wildcards` An error is reported when running with the PS extension
242
+ - Fixed `easy XYInputs: ControlNet` Error
243
+ - Fixed `easy loraStack` error when **toggle** is disabled
244
+
245
+
246
+ - Changing the first-time install node package no longer automatically replaces the theme, you need to manually adjust and refresh the page
247
+ - `easy imageSave` added **only_preivew**
248
+ - Adjust the `easy latentCompositeMaskedWithCond` node
249
+ </details>
250
+
251
+ <details>
252
+ <summary><b>v1.0.3</b></summary>
253
+
254
+ - Added `easy stylesSelector`
255
+ - Added **scale_soft_weights** in `easy controlnetLoader` and `easy controlnetLoaderADV`
256
+ - Added the queue progress bar setting item, which is not enabled by default
257
+
258
+
259
+ - Fixed `easy XYInputs: Sampler/Scheduler` Error
260
+ - Fixed the right menu has a problem when clicking the button
261
+ - Fixed `easy comfyLoader` error
262
+ - Fixed xyPlot error when connecting to zero123
263
+ - Fixed the error message in the loader when the prompt word was component
264
+ - Fixed `easy getNode` and `easy setNode` the title does not change when loading
265
+ - Fixed all samplers using subdirectories to store images
266
+
267
+
268
+ - Adjust the UI theme, divided into two sets of styles: the official default background and the dark black background, which can be switched in the color palette in the settings
269
+ - Modify the styles path to be compatible with other environments
270
+ </details>
271
+
272
+ <details>
273
+ <summary><b>v1.0.2</b></summary>
274
+
275
+ - Added `easy XYPlotAdvanced` and some nodes about `easy XYInputs`
276
+ - Added **Alt+1-Alt+9** Shortcut keys to quickly paste node presets for Node templates (corresponding to 1~9 sequences)
277
+ - Added a `📜Groups Map(EasyUse)` to the context menu.
278
+ - An `autocomplete` folder has been added, If you have [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) installed, the txt files in that folder will be merged and overwritten to the autocomplete .txt file of the pyssss package at startup.
279
+
280
+
281
+ - Fixed XYPlot is not working when `a1111_prompt_style` is True
282
+ - Fixed UI loading failure in the new version of ComfyUI
283
+ - `easy XYInputs ModelMergeBlocks` Values can be imported from CSV files
284
+ - Fixed `easy pipeToBasicPipe` Bug
285
+
286
+
287
+ - Removed `easy imageRemBg`
288
+ - Remove the introductory diagram and workflow files from the package to reduce the package size
289
+ - Replaced the font file used in the generation of XY diagrams
290
+ </details>
291
+
292
+ <details>
293
+ <summary><b>v1.0.1</b></summary>
294
+
295
+ - Fixed `easy comfyLoader` error
296
+ - Fixed All nodes that contain the value of the image size
297
+ - Added `easy kSamplerInpainting`
298
+ - Added `easy pipeToBasicPipe`
299
+ - Fixed `width` and `height` can not customize in `easy svdLoader`
300
+ - Fixed all preview image path (Previously, it was not possible to preview the image on the Mac system)
301
+ - Fixed `vae_name` is not working in `easy fullLoader` and `easy a1111Loader` and `easy comfyLoader`
302
+ - Fixed `easy fullkSampler` outputs error
303
+ - Fixed `model_override` is not working in `easy fullLoader`
304
+ - Fixed `easy hiresFix` error
305
+ - Fixed `easy xyplot` font file path error
306
+ - Fixed seed that cannot be fixed when you convert `seed_num` to `easy seed`
307
+ - Fixed `easy pipeIn` inputs bug
308
+ - `easy preDetailerFix` have added a new parameter `optional_image`
309
+ - Fixed `easy zero123Loader` and `easy svdLoader` model into cache.
310
+ - Added `easy seed`
311
+ - Fixed `image_output` default value is "Preview"
312
+ - `easy fullLoader` and `easy a1111Loader` have added a new parameter `a1111_prompt_style`,that can reproduce the same image generated from stable-diffusion-webui on comfyui, but you need to install [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) to use this feature in the current version
313
+ </details>
314
+
315
+ <details>
316
+ <summary><b>v1.0.0</b></summary>
317
+
318
+ - Added `easy positive` - simple positive prompt text
319
+ - Added `easy negative` - simple negative prompt text
320
+ - Added `easy wildcards` - support for wildcards and hint text selected by Lora
321
+ - Added `easy portraitMaster` - PortraitMaster v2.2
322
+ - Added `easy loraStack` - Lora stack
323
+ - Added `easy fullLoader` - full version of the loader
324
+ - Added `easy zero123Loader` - simple zero123 loader
325
+ - Added `easy svdLoader` - easy svd loader
326
+ - Added `easy fullkSampler` - full version of the sampler (no separation)
327
+ - Added `easy hiresFix` - support for HD repair of Pipe
328
+ - Added `easy predetailerFix` and `easy DetailerFix` - support for Pipe detail fixing
329
+ - Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - Detect loader (detail fixed input)
330
+ - Added `easy pipein` `easy pipeout` - Pipe input and output
331
+ - Added `easy xyPlot` - simple xyplot (more controllable parameters will be updated in the future)
332
+ - Added `easy imageRemoveBG` - image to remove background
333
+ - Added `easy imagePixelPerfect` - image pixel perfect
334
+ - Added `easy poseEditor` - Pose editor
335
+ - New UI Theme (Obsidian) - Auto-load UI by default, which can also be changed in the settings
336
+
337
+ - Fixed `easy globalSeed` is not working
338
+ - Fixed an issue where all `seed_num` values were out of order due to [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updating the chart in real time
339
+ - Fixed `easy imageSize`, `easy imageSizeBySide`, `easy imageSizeByLongerSide` as end nodes
340
+ - Fixed the bug that `seed_num` (random seed value) could not be read consistently in history
341
+ </details>
342
+
343
+ <details>
344
+ <summary><b>Updated at 12/14/2023</b></summary>
345
+
346
+ - `easy a1111Loader` and `easy comfyLoader` added `batch_size` of required input parameters
347
+ - Added the `easy controlnetLoaderADV` node
348
+ - `easy controlnetLoaderADV` and `easy controlnetLoader` added `control_net ` of optional input parameters
349
+ - `easy preSampling` and `easy preSamplingAdvanced` added `image_to_latent` optional input parameters
350
+ - Added the `easy imageSizeBySide` node, which can be output as a long side or a short side
351
+ </details>
352
+
353
+ <details>
354
+ <summary><b>Updated at 12/13/2023</b></summary>
355
+
356
+ - Added the `easy LLLiteLoader` node, if you have pre-installed the kohya-ss/ControlNet-LLLite-ComfyUI package, please move the model files in the models to `ComfyUI\models\controlnet\` (i.e. in the default controlnet path of comfy, please do not change the file name of the model, otherwise it will not be read).
357
+ - Modify `easy controlnetLoader` to the bottom of the loader category.
358
+ - Added size display for `easy imageSize` and `easy imageSizeByLongerSize` outputs.
359
+ </details>
360
+
361
+ <details>
362
+ <summary><b>Updated at 12/11/2023</b></summary>
363
+ - Added the `showSpentTime` node to display the time spent on image diffusion and the time spent on VAE decoding images
364
+ </details>
365
+
366
+ ## The relevant node package involved
367
+
368
+ Disclaimer: Opened source was not easy. I have a lot of respect for the contributions of these original authors. I just did some integration and optimization.
369
+
370
+ | Nodes Name(Search Name) | Related libraries | Library-related node |
371
+ |:-------------------------------|:----------------------------------------------------------------------------|:-------------------------|
372
+ | easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
373
+ | easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
374
+ | easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
375
+ | easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
376
+ | easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
377
+ | easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
378
+ | easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
379
+ | dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
380
+ | easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
381
+ | easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
382
+ | easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply... |
383
+ | easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
384
+ | easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
385
+ | easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
386
+ | easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
387
+
388
+
389
+ ## Credits
390
+
391
+ [ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Powerful and modular Stable Diffusion GUI
392
+
393
+ [ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI Manager
394
+
395
+ [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) allow users to reduce unnecessary connections
396
+
397
+ [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - Diffus3 gets and sets points that allow the user to detach the composition of the workflow
398
+
399
+ [ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General modpack 1
400
+
401
+ [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General Modpack 2
402
+
403
+ [ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Make model generation independent of training resolution
404
+
405
+ [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style migration
406
+
407
+ [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face migration
408
+
409
+ [ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face migration
410
+
411
+ [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss🐍
412
+
413
+ [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image Preview Chooser
414
+
415
+ [ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT custom nodes
416
+
417
+
418
+ ## 🌟Stargazers
419
+
420
+ My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
421
+
422
+ [![Stargazers repo roster for @yolain/ComfyUI-Easy-Use](https://reporoster.com/stars/yolain/ComfyUI-Easy-Use)](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
ComfyUI-Easy-Use/README.md ADDED
@@ -0,0 +1,448 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![comfyui-easy-use](https://github.com/user-attachments/assets/9b7a5e44-f5e2-4c27-aed2-d0e6b50c46bb)
2
+
3
+ <div align="center">
4
+ <a href="https://space.bilibili.com/1840885116">Video Tutorial</a> |
5
+ Docs (Cooming Soon) |
6
+ <a href="https://github.com/yolain/ComfyUI-Yolain-Workflows">Workflow Collection</a> |
7
+ <a href="#%EF%B8%8F-donation">Donation</a>
8
+ <br><br>
9
+ <a href="./README.md"><img src="https://img.shields.io/badge/🇬🇧English-0b8cf5"></a>
10
+ <a href="./README.ZH_CN.md"><img src="https://img.shields.io/badge/🇨🇳中文简体-e9e9e9"></a>
11
+ </div>
12
+
13
+ **ComfyUI-Easy-Use** is an efficiency custom nodes integration package, which is extended on the basis of [TinyTerraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes). It has been integrated and optimized for many popular awesome custom nodes to achieve the purpose of faster and more convenient use of ComfyUI. While ensuring the degree of freedom, it restores the ultimate smooth image production experience that belongs to Stable Diffusion.
14
+
15
+ ## 👨🏻‍🎨 Introduce
16
+
17
+ - Inspire by [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes), which greatly reduces the time cost of tossing workflows。
18
+ - UI interface beautification, the first time you install the user, if you need to use the UI theme, please switch the theme in Settings -> Color Palette and refresh page.
19
+ - Added a node for pre-sampling parameter configuration, which can be separated from the sampling node for easier previewing
20
+ - Wildcards and lora's are supported, for Lora Block Weight usage, ensure that the custom node package has the [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack)
21
+ - Multi-selectable styled cue word selector, default is Fooocus style json, custom json can be placed under styles, samples folder can be placed in the preview image (name and name consistent, image file name such as spaces need to be converted to underscores '_')
22
+ - The loader enables the A1111 prompt mode, which reproduces nearly identical images to those generated by webui, and needs to be installed [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) first.
23
+ - Noise injection into the latent space can be achieved using the `easy latentNoisy` or `easy preSamplingNoiseIn` node
24
+ - Simplified processes for SD1.x, SD2.x, SDXL, SVD, Zero123, etc. [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableDiffusion)
25
+ - Simplified Stable Cascade [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#StableCascade)
26
+ - Simplified Layer Diffuse [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#LayerDiffusion),The first time you use it you may need to run `pip install -r requirements.txt` to install the required dependencies.
27
+ - Simplified InstantID [Example](https://github.com/yolain/ComfyUI-Easy-Use?tab=readme-ov-file#InstantID), You need to make sure that the custom node package has the [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID)
28
+ - Extending the usability of XYplot
29
+ - Fooocus Inpaint integration
30
+ - Integration of common logical calculations, conversion of types, display of all types, etc.
31
+ - Background removal nodes for the RMBG-1.4 model supporting BriaAI, [BriaAI Guide](https://huggingface.co/briaai/RMBG-1.4)
32
+ - Forcibly cleared the memory usage of the comfy UI model are supported
33
+ - Stable Diffusion 3 multi-account API nodes are supported
34
+ - Support SD3's model
35
+ - Support Kolors‘s model
36
+ - Support Flux's model
37
+
38
+ ## 👨🏻‍🔧 Installation
39
+ Clone the repo into the **custom_nodes** directory and install the requirements:
40
+ ```shell
41
+ #1. Clone the repo
42
+ git clone https://github.com/yolain/ComfyUI-Easy-Use
43
+ #2. Install the requirements
44
+ Double-click install.bat to install the required dependencies
45
+ ```
46
+
47
+ ## 👨🏻‍🚀 Plan
48
+
49
+ - [x] Updated new front-end code for easier maintenance
50
+ - [x] Maintain css styles using sass
51
+ - [x] Optimize existing extensions
52
+ - [x] Add new components
53
+ - [ ] Upload new workflows to [ComfyUI-Yolain-Workflows](https://github.com/yolain/ComfyUI-Yolain-Workflows) and translate readme to english version.
54
+ - [ ] Write gitbook with more detailed function introdution
55
+
56
+ ## 📜 Changelog
57
+
58
+ **v1.2.2**
59
+
60
+ - Added v2 web frond-end code
61
+ - Added `easy fluxLoader`
62
+ - Added support for `controlnetApply` Related nodes with SD3 and hunyuanDiT
63
+
64
+ **v1.2.1**
65
+
66
+ - Added `easy ipadapterApplyFaceIDKolors`
67
+ - Added **inspyrenet** to `easy imageRemBg`
68
+ - Added `easy controlnetLoader++`
69
+ - Added **PLUS (kolors genernal)** and **FACEID PLUS KOLORS** preset to `easy ipadapterApply` and `easy ipadapterApplyADV` (Supported kolors ipadapter)
70
+ - Added `easy kolorsLoader` - Code based on [MinusZoneAI](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ)'s and [kijai](https://github.com/kijai/ComfyUI-KwaiKolorsWrapper)'s repo, thanks for their contribution.
71
+
72
+ **v1.2.0**
73
+
74
+ - Added `easy pulIDApply` and `easy pulIDApplyADV`
75
+ - Added `easy huanyuanDiTLoader` and `easy pixArtLoader`
76
+ - Added **easy sliderControl** - Slider control node, which can currently be used to control the parameters of ipadapterMS (double-click the slider to reset to default)
77
+ - Added **layer_weights** in `easy ipadapterApplyADV`
78
+
79
+ **v1.1.9**
80
+
81
+ - Added **gitsScheduler**
82
+ - Added `easy imageBatchToImageList` and `easy imageListToImageBatch`
83
+ - Recursive subcategories nested for models
84
+ - Support for Stable Diffusion 3 model
85
+ - Added `easy applyInpaint` - All inpainting mode in this node
86
+
87
+ **v1.1.8**
88
+
89
+ - Added `easy controlnetStack`
90
+ - Added `easy applyBrushNet` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4brushnet_1.1.8.json)
91
+ - Added `easy applyPowerPaint` - [Workflow Example](https://github.com/yolain/ComfyUI-Yolain-Workflows/blob/main/workflows/2_advanced/2-4inpainting/2-4powerpaint_outpaint_1.1.8.json)
92
+
93
+ **v1.1.7**
94
+
95
+ - Added `easy prompt` - Subject and light presets, maybe adjusted later
96
+ - Added `easy icLightApply` - Light and shadow migration, Code based on [ComfyUI-IC-Light](https://github.com/huchenlei/ComfyUI-IC-Light)
97
+ - Added `easy imageSplitGrid`
98
+ - `easy kSamplerInpainting` added options such as different diffusion and brushnet in **additional** widget
99
+ - Support for brushnet model loading - [ComfyUI-BrushNet](https://github.com/nullquant/ComfyUI-BrushNet)
100
+ - Added `easy applyFooocusInpaint` - Replace FooocusInpaintLoader
101
+ - Removed `easy fooocusInpaintLoader`
102
+
103
+ **v1.1.6**
104
+
105
+ - Added **alignYourSteps** to **schedulder** widget in all `easy preSampling` and `easy fullkSampler`
106
+ - Added **Preview&Choose** to **image_output** widget in `easy kSampler` & `easy fullkSampler`
107
+ - Added `easy styleAlignedBatchAlign` - Credit of [style_aligned_comfy](https://github.com/brianfitzgerald/style_aligned_comfy)
108
+ - Added `easy ckptNames`
109
+ - Added `easy controlnetNames`
110
+ - Added `easy imagesSplitimage` - Batch images split into single images
111
+ - Added `easy imageCount` - Get Image Count
112
+ - Added `easy textSwitch` - Text Switch
113
+
114
+ **v1.1.5**
115
+
116
+ - Rewrite `easy cleanGPUUsed` - the memory usage of the comfyUI can to be cleared
117
+ - Added `easy humanSegmentation` - Human Part Segmentation
118
+ - Added `easy imageColorMatch`
119
+ - Added `easy ipadapterApplyRegional`
120
+ - Added `easy ipadapterApplyFromParams`
121
+ - Added `easy imageInterrogator` - Image To Prompt
122
+ - Added `easy stableDiffusion3API` - Easy Stable Diffusion 3 Multiple accounts API Node
123
+
124
+ **v1.1.4**
125
+
126
+ - Added `easy preSamplingCustom` - Custom-PreSampling, can be supported cosXL-edit
127
+ - Added `easy ipadapterStyleComposition`
128
+ - Added the right-click menu to view checkpoints and lora information in all Loaders
129
+ - Fixed `easy preSamplingNoiseIn`、`easy latentNoisy`、`east Unsampler` compatible with ComfyUI Revision>=2098 [0542088e] or later
130
+
131
+
132
+ **v1.1.3**
133
+
134
+ - `easy ipadapterApply` Added **COMPOSITION** preset
135
+ - Supported [ResAdapter](https://huggingface.co/jiaxiangc/res-adapter) when load ResAdapter lora
136
+ - Added `easy promptLine`
137
+ - Added `easy promptReplace`
138
+ - Added `easy promptConcat`
139
+ - `easy wildcards` Added **multiline_mode**
140
+
141
+ <details>
142
+ <summary><b>v1.1.2</b></summary>
143
+
144
+ - Optimized some of the recommended nodes for slots related to EasyUse
145
+ - Added **Enable ContextMenu Auto Nest Subdirectories** The setting item is enabled by default, and it can be classified into subdirectories, checkpoints and loras previews
146
+ - Added `easy sv3dLoader`
147
+ - Added `easy dynamiCrafterLoader`
148
+ - Added `easy ipadapterApply`
149
+ - Added `easy ipadapterApplyADV`
150
+ - Added `easy ipadapterApplyEncoder`
151
+ - Added `easy ipadapterApplyEmbeds`
152
+ - Added `easy preMaskDetailerFix`
153
+ - Fixed `easy stylesSelector` is change the prompt when not select the style
154
+ - Fixed `easy pipeEdit` error when add lora to prompt
155
+ - Fixed layerDiffuse xyplot bug
156
+ - `easy kSamplerInpainting` add *additional* widget,you can choose 'Differential Diffusion' or 'Only InpaintModelConditioning'
157
+ </details>
158
+
159
+ <details>
160
+ <summary><b>v1.1.1</b></summary>
161
+
162
+ - The issue that the seed is 0 when a node with a seed control is added and **control before generate** is fixed for the first time run queue prompt.
163
+ - `easy preSamplingAdvanced` Added **return_with_leftover_noise**
164
+ - Fixed `easy stylesSelector` error when choose the custom file
165
+ - `easy preSamplingLayerDiffusion` Added optional input parameter for mask
166
+ - Renamed all nodes widget name named seed_num to seed
167
+ - Remove forced **control_before_generate** settings。 If you want to use control_before_generate, change widget_value_control_mode to before in system settings
168
+ - Added `easy imageRemBg` - The default is BriaAI's RMBG-1.4 model, which removes the background effect more and faster
169
+ </details>
170
+
171
+ <details>
172
+ <summary><b>v1.1.0</b></summary>
173
+
174
+ - Added `easy imageSplitList` - to split every N images
175
+ - Added `easy preSamplingDiffusionADDTL` - It can modify foreground、background or blended additional prompt
176
+ - Added `easy preSamplingNoiseIn` It can replace the `easy latentNoisy` node that needs to be fronted to achieve better noise injection
177
+ - `easy pipeEdit` Added conditioning splicing mode selection, you can choose to replace, concat, combine, average, and set timestep range
178
+ - Added `easy pipeEdit` - nodes that can edit pipes (including re-enterable prompts)
179
+ - Added `easy preSamplingLayerDiffusion` and `easy kSamplerLayerDiffusion`
180
+ - Added a convenient menu to right-click on nodes such as Loader, Presampler, Sampler, Controlnet, etc. to quickly replace nodes of the same type
181
+ - Added `easy instantIDApplyADV` can link positive and negative
182
+ - Fixed layerDiffusion error when batch size greater than 1
183
+ - Fixed `easy wildcards` When LoRa is not filled in completely, LoRa is not automatically retrieved, resulting in failure to load LoRa
184
+ - Fixed the issue that 'BREAK' non-initiation when didn't use a1111 prompt style
185
+ - Fixed `easy instantIDApply` mask not input right
186
+ </details>
187
+
188
+ <details>
189
+ <summary><b>v1.0.9</b></summary>
190
+
191
+ - Fixed the error when ComfyUI-Impack-Pack and ComfyUI_InstantID were not installed
192
+ - Fixed `easy pipeIn`
193
+ - Added `easy instantIDApply` - you need installed [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) fisrt, Workflow[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#InstantID)
194
+ - Fixed `easy detailerFix` not added to the list of nodes available for saving images formatting extensions
195
+ - Fixed `easy XYInputs: PromptSR` errors are reported when replacing negative prompts
196
+ </details>
197
+
198
+ <details>
199
+ <summary><b>v1.0.8</b></summary>
200
+
201
+ - `easy cascadeLoader` stage_c and stage_b support the checkpoint model (Download [checkpoints](https://huggingface.co/stabilityai/stable-cascade/tree/main/comfyui_checkpoints) models)
202
+ - `easy styleSelector` The search box is modified to be case-insensitive
203
+ - `easy fullLoader` **positive**、**negative**、**latent** added to the output items
204
+ - Fixed the issue that 'easy preSampling' and other similar node, latent could not be generated based on the batch index after passing in
205
+ - Fixed `easy svdLoader` error when the positive or negative is empty
206
+ - Fixed the error of SDXLClipModel in ComfyUI revision 2016[c2cb8e88] and above (the revision number was judged to be compatible with the old revision)
207
+ - Fixed `easy detailerFix` generation error when batch size is greater than 1
208
+ - Optimize the code, reduce a lot of redundant code and improve the running speed
209
+ </details>
210
+
211
+ <details>
212
+ <summary><b>v1.0.7</b></summary>
213
+
214
+ - Added `easy cascadeLoader` - stable cascade Loader
215
+ - Added `easy preSamplingCascade` - stable cascade preSampling Settings
216
+ - Added `easy fullCascadeKSampler` - stable cascade stage-c ksampler full
217
+ - Added `easy cascadeKSampler` - stable cascade stage-c ksampler simple
218
+ -
219
+ - Optimize the image to image[Example](https://github.com/yolain/ComfyUI-Easy-Use/blob/main/README.en.md#image-to-image)
220
+ </details>
221
+
222
+ <details>
223
+ <summary><b>v1.0.6</b></summary>
224
+
225
+ - Added `easy XYInputs: Checkpoint`
226
+ - Added `easy XYInputs: Lora`
227
+ - `easy seed` can manually switch the random seed when increasing the fixed seed value
228
+ - Fixed `easy fullLoader` and all loaders to automatically adjust the node size when switching LoRa
229
+ - Removed the original ttn image saving logic and adapted to the default image saving format extension of ComfyUI
230
+ </details>
231
+
232
+ <details>
233
+ <summary><b>v1.0.5</b></summary>
234
+
235
+ - Added `easy isSDXL`
236
+ - Added prompt word control on `easy svdLoader`, which can be used with open_clip model
237
+ - Added **populated_text** on `easy wildcards`, wildcard populated text can be output
238
+ </details>
239
+
240
+ <details>
241
+ <summary><b>v1.0.4</b></summary>
242
+
243
+ - `easy showAnything` added support for converting other types (e.g., tensor conditions, images, etc.)
244
+ - Added `easy showLoaderSettingsNames` can display the model and VAE name in the output loader assembly
245
+ - Added `easy promptList`
246
+ - Added `easy fooocusInpaintLoader` (only the process of SDXLModel is supported)
247
+ - Added **Logic** nodes
248
+ - Added `easy imageSave` - Image saving node with date conversion and aspect and height formatting
249
+ - Added `easy joinImageBatch`
250
+ - `easy kSamplerInpainting` Added the **patch** input value to be used with the FooocusInpaintLoader node
251
+
252
+ - Fixed xyplot error when with Pillow>9.5
253
+ - Fixed `easy wildcards` An error is reported when running with the PS extension
254
+ - Fixed `easy XYInputs: ControlNet` Error
255
+ - Fixed `easy loraStack` error when **toggle** is disabled
256
+
257
+
258
+ - Changing the first-time install node package no longer automatically replaces the theme, you need to manually adjust and refresh the page
259
+ - `easy imageSave` added **only_preivew**
260
+ - Adjust the `easy latentCompositeMaskedWithCond` node
261
+ </details>
262
+
263
+ <details>
264
+ <summary><b>v1.0.3</b></summary>
265
+
266
+ - Added `easy stylesSelector`
267
+ - Added **scale_soft_weights** in `easy controlnetLoader` and `easy controlnetLoaderADV`
268
+ - Added the queue progress bar setting item, which is not enabled by default
269
+
270
+
271
+ - Fixed `easy XYInputs: Sampler/Scheduler` Error
272
+ - Fixed the right menu has a problem when clicking the button
273
+ - Fixed `easy comfyLoader` error
274
+ - Fixed xyPlot error when connecting to zero123
275
+ - Fixed the error message in the loader when the prompt word was component
276
+ - Fixed `easy getNode` and `easy setNode` the title does not change when loading
277
+ - Fixed all samplers using subdirectories to store images
278
+
279
+
280
+ - Adjust the UI theme, divided into two sets of styles: the official default background and the dark black background, which can be switched in the color palette in the settings
281
+ - Modify the styles path to be compatible with other environments
282
+ </details>
283
+
284
+ <details>
285
+ <summary><b>v1.0.2</b></summary>
286
+
287
+ - Added `easy XYPlotAdvanced` and some nodes about `easy XYInputs`
288
+ - Added **Alt+1-Alt+9** Shortcut keys to quickly paste node presets for Node templates (corresponding to 1~9 sequences)
289
+ - Added a `📜Groups Map(EasyUse)` to the context menu.
290
+ - An `autocomplete` folder has been added, If you have [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) installed, the txt files in that folder will be merged and overwritten to the autocomplete .txt file of the pyssss package at startup.
291
+
292
+
293
+ - Fixed XYPlot is not working when `a1111_prompt_style` is True
294
+ - Fixed UI loading failure in the new version of ComfyUI
295
+ - `easy XYInputs ModelMergeBlocks` Values can be imported from CSV files
296
+ - Fixed `easy pipeToBasicPipe` Bug
297
+
298
+
299
+ - Removed `easy imageRemBg`
300
+ - Remove the introductory diagram and workflow files from the package to reduce the package size
301
+ - Replaced the font file used in the generation of XY diagrams
302
+ </details>
303
+
304
+ <details>
305
+ <summary><b>v1.0.1</b></summary>
306
+
307
+ - Fixed `easy comfyLoader` error
308
+ - Fixed All nodes that contain the value of the image size
309
+ - Added `easy kSamplerInpainting`
310
+ - Added `easy pipeToBasicPipe`
311
+ - Fixed `width` and `height` can not customize in `easy svdLoader`
312
+ - Fixed all preview image path (Previously, it was not possible to preview the image on the Mac system)
313
+ - Fixed `vae_name` is not working in `easy fullLoader` and `easy a1111Loader` and `easy comfyLoader`
314
+ - Fixed `easy fullkSampler` outputs error
315
+ - Fixed `model_override` is not working in `easy fullLoader`
316
+ - Fixed `easy hiresFix` error
317
+ - Fixed `easy xyplot` font file path error
318
+ - Fixed seed that cannot be fixed when you convert `seed_num` to `easy seed`
319
+ - Fixed `easy pipeIn` inputs bug
320
+ - `easy preDetailerFix` have added a new parameter `optional_image`
321
+ - Fixed `easy zero123Loader` and `easy svdLoader` model into cache.
322
+ - Added `easy seed`
323
+ - Fixed `image_output` default value is "Preview"
324
+ - `easy fullLoader` and `easy a1111Loader` have added a new parameter `a1111_prompt_style`,that can reproduce the same image generated from stable-diffusion-webui on comfyui, but you need to install [ComfyUI_smZNodes](https://github.com/shiimizu/ComfyUI_smZNodes) to use this feature in the current version
325
+ </details>
326
+
327
+ <details>
328
+ <summary><b>v1.0.0</b></summary>
329
+
330
+ - Added `easy positive` - simple positive prompt text
331
+ - Added `easy negative` - simple negative prompt text
332
+ - Added `easy wildcards` - support for wildcards and hint text selected by Lora
333
+ - Added `easy portraitMaster` - PortraitMaster v2.2
334
+ - Added `easy loraStack` - Lora stack
335
+ - Added `easy fullLoader` - full version of the loader
336
+ - Added `easy zero123Loader` - simple zero123 loader
337
+ - Added `easy svdLoader` - easy svd loader
338
+ - Added `easy fullkSampler` - full version of the sampler (no separation)
339
+ - Added `easy hiresFix` - support for HD repair of Pipe
340
+ - Added `easy predetailerFix` and `easy DetailerFix` - support for Pipe detail fixing
341
+ - Added `easy ultralyticsDetectorPipe` and `easy samLoaderPipe` - Detect loader (detail fixed input)
342
+ - Added `easy pipein` `easy pipeout` - Pipe input and output
343
+ - Added `easy xyPlot` - simple xyplot (more controllable parameters will be updated in the future)
344
+ - Added `easy imageRemoveBG` - image to remove background
345
+ - Added `easy imagePixelPerfect` - image pixel perfect
346
+ - Added `easy poseEditor` - Pose editor
347
+ - New UI Theme (Obsidian) - Auto-load UI by default, which can also be changed in the settings
348
+
349
+ - Fixed `easy globalSeed` is not working
350
+ - Fixed an issue where all `seed_num` values were out of order due to [cg-use-everywhere](https://github.com/chrisgoringe/cg-use-everywhere) updating the chart in real time
351
+ - Fixed `easy imageSize`, `easy imageSizeBySide`, `easy imageSizeByLongerSide` as end nodes
352
+ - Fixed the bug that `seed_num` (random seed value) could not be read consistently in history
353
+ </details>
354
+
355
+ <details>
356
+ <summary><b>Updated at 12/14/2023</b></summary>
357
+
358
+ - `easy a1111Loader` and `easy comfyLoader` added `batch_size` of required input parameters
359
+ - Added the `easy controlnetLoaderADV` node
360
+ - `easy controlnetLoaderADV` and `easy controlnetLoader` added `control_net ` of optional input parameters
361
+ - `easy preSampling` and `easy preSamplingAdvanced` added `image_to_latent` optional input parameters
362
+ - Added the `easy imageSizeBySide` node, which can be output as a long side or a short side
363
+ </details>
364
+
365
+ <details>
366
+ <summary><b>Updated at 12/13/2023</b></summary>
367
+
368
+ - Added the `easy LLLiteLoader` node, if you have pre-installed the kohya-ss/ControlNet-LLLite-ComfyUI package, please move the model files in the models to `ComfyUI\models\controlnet\` (i.e. in the default controlnet path of comfy, please do not change the file name of the model, otherwise it will not be read).
369
+ - Modify `easy controlnetLoader` to the bottom of the loader category.
370
+ - Added size display for `easy imageSize` and `easy imageSizeByLongerSize` outputs.
371
+ </details>
372
+
373
+ <details>
374
+ <summary><b>Updated at 12/11/2023</b></summary>
375
+ - Added the `showSpentTime` node to display the time spent on image diffusion and the time spent on VAE decoding images
376
+ </details>
377
+
378
+ ## The relevant node package involved
379
+
380
+ Disclaimer: Opened source was not easy. I have a lot of respect for the contributions of these original authors. I just did some integration and optimization.
381
+
382
+ | Nodes Name(Search Name) | Related libraries | Library-related node |
383
+ |:-------------------------------|:----------------------------------------------------------------------------|:-------------------------|
384
+ | easy setNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.SetNode |
385
+ | easy getNode | [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) | diffus3.GetNode |
386
+ | easy bookmark | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | Bookmark 🔖 |
387
+ | easy portraitMarker | [comfyui-portrait-master](https://github.com/florestefano1975/comfyui-portrait-master) | Portrait Master |
388
+ | easy LLLiteLoader | [ControlNet-LLLite-ComfyUI](https://github.com/kohya-ss/ControlNet-LLLite-ComfyUI) | LLLiteLoader |
389
+ | easy globalSeed | [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) | Global Seed (Inspire) |
390
+ | easy preSamplingDynamicCFG | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
391
+ | dynamicThresholdingFull | [sd-dynamic-thresholding](https://github.com/mcmonkeyprojects/sd-dynamic-thresholding) | DynamicThresholdingFull |
392
+ | easy imageInsetCrop | [rgthree-comfy](https://github.com/rgthree/rgthree-comfy) | ImageInsetCrop |
393
+ | easy poseEditor | [ComfyUI_Custom_Nodes_AlekPet](https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet) | poseNode |
394
+ | easy preSamplingLayerDiffusion | [ComfyUI-layerdiffusion](https://github.com/huchenlei/ComfyUI-layerdiffusion) | LayeredDiffusionApply... |
395
+ | easy dynamiCrafterLoader | [ComfyUI-layerdiffusion](https://github.com/ExponentialML/ComfyUI_Native_DynamiCrafter) | Apply Dynamicrafter |
396
+ | easy imageChooser | [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) | Preview Chooser |
397
+ | easy styleAlignedBatchAlign | [style_aligned_comfy](https://github.com/chrisgoringe/cg-image-picker) | styleAlignedBatchAlign |
398
+ | easy kolorsLoader | [ComfyUI-Kolors-MZ](https://github.com/MinusZoneAI/ComfyUI-Kolors-MZ) | kolorsLoader |
399
+
400
+
401
+ ## Credits
402
+
403
+ [ComfyUI](https://github.com/comfyanonymous/ComfyUI) - Powerful and modular Stable Diffusion GUI
404
+
405
+ [ComfyUI-ComfyUI-Manager](https://github.com/ltdrdata/ComfyUI-Manager) - ComfyUI Manager
406
+
407
+ [tinyterraNodes](https://github.com/TinyTerra/ComfyUI_tinyterraNodes) - Pipe nodes (node bundles) allow users to reduce unnecessary connections
408
+
409
+ [ComfyUI-extensions](https://github.com/diffus3/ComfyUI-extensions) - Diffus3 gets and sets points that allow the user to detach the composition of the workflow
410
+
411
+ [ComfyUI-Impact-Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack) - General modpack 1
412
+
413
+ [ComfyUI-Inspire-Pack](https://github.com/ltdrdata/ComfyUI-Inspire-Pack) - General Modpack 2
414
+
415
+ [ComfyUI-ResAdapter](https://github.com/jiaxiangc/ComfyUI-ResAdapter) - Make model generation independent of training resolution
416
+
417
+ [ComfyUI_IPAdapter_plus](https://github.com/cubiq/ComfyUI_IPAdapter_plus) - Style migration
418
+
419
+ [ComfyUI_InstantID](https://github.com/cubiq/ComfyUI_InstantID) - Face migration
420
+
421
+ [ComfyUI_PuLID](https://github.com/cubiq/PuLID_ComfyUI) - Face migration
422
+
423
+ [ComfyUI-Custom-Scripts](https://github.com/pythongosssss/ComfyUI-Custom-Scripts) - pyssss🐍
424
+
425
+ [cg-image-picker](https://github.com/chrisgoringe/cg-image-picker) - Image Preview Chooser
426
+
427
+ [ComfyUI_ExtraModels](https://github.com/city96/ComfyUI_ExtraModels) - DiT custom nodes
428
+
429
+ ## ☕️ Donation
430
+
431
+ **Comfyui-Easy-Use** is an GPL-licensed open source project. In order to achieve better and sustainable development of the project, i expect to gain more backers. <br>
432
+ If my custom nodes has added value to your day, consider indulging in a coffee to fuel it further! <br>
433
+ 💖You can support me in any of the following ways:
434
+
435
+ - [BiliBili](https://space.bilibili.com/1840885116)
436
+ - [Afdian](https://afdian.com/a/yolain)
437
+ - [Wechat / Alipay](https://github.com/user-attachments/assets/803469bd-ed6a-4fab-932d-50e5088a2d03)
438
+ - 🪙 Wallet Address:
439
+ - ETH: 0x01f7CEd3245CaB3891A0ec8f528178db352EaC74
440
+ - USDT(tron): TP3AnJXkAzfebL2GKmFAvQvXgsxzivweV6
441
+
442
+ (This is a newly created wallet, and if it receives sponsorship, I'll use it to rent GPUs or other GPT services for better debugging and refinement of ComfyUI-Easy-Use features.)
443
+
444
+ ## 🌟Stargazers
445
+
446
+ My gratitude extends to the generous souls who bestow a star. Your support is much appreciated!
447
+
448
+ [![Stargazers repo roster for @yolain/ComfyUI-Easy-Use](https://reporoster.com/stars/yolain/ComfyUI-Easy-Use)](https://github.com/yolain/ComfyUI-Easy-Use/stargazers)
ComfyUI-Easy-Use/__init__.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "1.2.2"
2
+
3
+ import yaml
4
+ import os
5
+ import folder_paths
6
+ import importlib
7
+ from pathlib import Path
8
+
9
+ node_list = [
10
+ "server",
11
+ "api",
12
+ "easyNodes",
13
+ "image",
14
+ "logic"
15
+ ]
16
+
17
+ NODE_CLASS_MAPPINGS = {}
18
+ NODE_DISPLAY_NAME_MAPPINGS = {}
19
+
20
+ for module_name in node_list:
21
+ imported_module = importlib.import_module(".py.{}".format(module_name), __name__)
22
+ NODE_CLASS_MAPPINGS = {**NODE_CLASS_MAPPINGS, **imported_module.NODE_CLASS_MAPPINGS}
23
+ NODE_DISPLAY_NAME_MAPPINGS = {**NODE_DISPLAY_NAME_MAPPINGS, **imported_module.NODE_DISPLAY_NAME_MAPPINGS}
24
+
25
+ cwd_path = os.path.dirname(os.path.realpath(__file__))
26
+ comfy_path = folder_paths.base_path
27
+
28
+ #Wildcards
29
+ from .py.libs.wildcards import read_wildcard_dict
30
+ wildcards_path = os.path.join(os.path.dirname(__file__), "wildcards")
31
+ if os.path.exists(wildcards_path):
32
+ read_wildcard_dict(wildcards_path)
33
+ else:
34
+ os.mkdir(wildcards_path)
35
+
36
+ #Styles
37
+ styles_path = os.path.join(os.path.dirname(__file__), "styles")
38
+ samples_path = os.path.join(os.path.dirname(__file__), "styles", "samples")
39
+ if os.path.exists(styles_path):
40
+ if not os.path.exists(samples_path):
41
+ os.mkdir(samples_path)
42
+ else:
43
+ os.mkdir(styles_path)
44
+ os.mkdir(samples_path)
45
+
46
+ # Model thumbnails
47
+ from .py.libs.add_resources import add_static_resource
48
+ from .py.libs.model import easyModelManager
49
+ model_config = easyModelManager().models_config
50
+ for model in model_config:
51
+ paths = folder_paths.get_folder_paths(model)
52
+ for path in paths:
53
+ if not Path(path).exists():
54
+ continue
55
+ add_static_resource(path, path, limit=True)
56
+
57
+ # get comfyui revision
58
+ from .py.libs.utils import compare_revision
59
+
60
+ new_frontend_revision = 2546
61
+ web_default_version = 'v2' if compare_revision(new_frontend_revision) else 'v1'
62
+ # web directory
63
+ config_path = os.path.join(cwd_path, "config.yaml")
64
+ if os.path.isfile(config_path):
65
+ with open(config_path, 'r') as f:
66
+ data = yaml.load(f, Loader=yaml.FullLoader)
67
+ if data and "WEB_VERSION" in data:
68
+ directory = f"web_version/{data['WEB_VERSION']}"
69
+ with open(config_path, 'w') as f:
70
+ yaml.dump(data, f)
71
+ elif web_default_version != 'v1':
72
+ if not data:
73
+ data = {'WEB_VERSION': web_default_version}
74
+ elif 'WEB_VERSION' not in data:
75
+ data = {**data, 'WEB_VERSION': web_default_version}
76
+ with open(config_path, 'w') as f:
77
+ yaml.dump(data, f)
78
+ directory = f"web_version/{web_default_version}"
79
+ else:
80
+ directory = f"web_version/v1"
81
+ if not os.path.exists(os.path.join(cwd_path, directory)):
82
+ print(f"web root {data['WEB_VERSION']} not found, using default")
83
+ directory = f"web_version/{web_default_version}"
84
+ WEB_DIRECTORY = directory
85
+ else:
86
+ directory = f"web_version/{web_default_version}"
87
+ WEB_DIRECTORY = directory
88
+
89
+ __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS', "WEB_DIRECTORY"]
90
+
91
+ print(f'\033[34m[ComfyUI-Easy-Use] server: \033[0mv{__version__} \033[92mLoaded\033[0m')
92
+ print(f'\033[34m[ComfyUI-Easy-Use] web root: \033[0m{os.path.join(cwd_path, directory)} \033[92mLoaded\033[0m')
ComfyUI-Easy-Use/config.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ STABILITY_API_DEFAULT: 0
2
+ STABILITY_API_KEY:
3
+ - key: ''
4
+ name: Default
5
+ WEB_VERSION: v2
ComfyUI-Easy-Use/install.bat ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ @echo off
2
+
3
+ set "requirements_txt=%~dp0\requirements.txt"
4
+ set "requirements_repair_txt=%~dp0\repair_dependency_list.txt"
5
+ set "python_exec=..\..\..\python_embedded\python.exe"
6
+ set "aki_python_exec=..\..\python\python.exe"
7
+
8
+ echo Installing EasyUse Requirements...
9
+
10
+ if exist "%python_exec%" (
11
+ echo Installing with ComfyUI Portable
12
+ "%python_exec%" -s -m pip install -r "%requirements_txt%"
13
+ )^
14
+ else (
15
+ echo Installing with Python
16
+ pip install -r "%requirements_txt%"
17
+ )
18
+ pause
ComfyUI-Easy-Use/prestartup_script.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import folder_paths
2
+ import os
3
+ def add_folder_path_and_extensions(folder_name, full_folder_paths, extensions):
4
+ for full_folder_path in full_folder_paths:
5
+ folder_paths.add_model_folder_path(folder_name, full_folder_path)
6
+ if folder_name in folder_paths.folder_names_and_paths:
7
+ current_paths, current_extensions = folder_paths.folder_names_and_paths[folder_name]
8
+ updated_extensions = current_extensions | extensions
9
+ folder_paths.folder_names_and_paths[folder_name] = (current_paths, updated_extensions)
10
+ else:
11
+ folder_paths.folder_names_and_paths[folder_name] = (full_folder_paths, extensions)
12
+
13
+ image_suffixs = set([".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".svg", ".ico", ".apng", ".tif", ".hdr", ".exr"])
14
+
15
+ model_path = folder_paths.models_dir
16
+ add_folder_path_and_extensions("ultralytics_bbox", [os.path.join(model_path, "ultralytics", "bbox")], folder_paths.supported_pt_extensions)
17
+ add_folder_path_and_extensions("ultralytics_segm", [os.path.join(model_path, "ultralytics", "segm")], folder_paths.supported_pt_extensions)
18
+ add_folder_path_and_extensions("ultralytics", [os.path.join(model_path, "ultralytics")], folder_paths.supported_pt_extensions)
19
+ add_folder_path_and_extensions("mmdets_bbox", [os.path.join(model_path, "mmdets", "bbox")], folder_paths.supported_pt_extensions)
20
+ add_folder_path_and_extensions("mmdets_segm", [os.path.join(model_path, "mmdets", "segm")], folder_paths.supported_pt_extensions)
21
+ add_folder_path_and_extensions("mmdets", [os.path.join(model_path, "mmdets")], folder_paths.supported_pt_extensions)
22
+ add_folder_path_and_extensions("sams", [os.path.join(model_path, "sams")], folder_paths.supported_pt_extensions)
23
+ add_folder_path_and_extensions("onnx", [os.path.join(model_path, "onnx")], {'.onnx'})
24
+ add_folder_path_and_extensions("instantid", [os.path.join(model_path, "instantid")], folder_paths.supported_pt_extensions)
25
+ add_folder_path_and_extensions("pulid", [os.path.join(model_path, "pulid")], folder_paths.supported_pt_extensions)
26
+ add_folder_path_and_extensions("layer_model", [os.path.join(model_path, "layer_model")], folder_paths.supported_pt_extensions)
27
+ add_folder_path_and_extensions("rembg", [os.path.join(model_path, "rembg")], folder_paths.supported_pt_extensions)
28
+ add_folder_path_and_extensions("ipadapter", [os.path.join(model_path, "ipadapter")], folder_paths.supported_pt_extensions)
29
+ add_folder_path_and_extensions("dynamicrafter_models", [os.path.join(model_path, "dynamicrafter_models")], folder_paths.supported_pt_extensions)
30
+ add_folder_path_and_extensions("mediapipe", [os.path.join(model_path, "mediapipe")], set(['.tflite','.pth']))
31
+ add_folder_path_and_extensions("inpaint", [os.path.join(model_path, "inpaint")], folder_paths.supported_pt_extensions)
32
+ add_folder_path_and_extensions("prompt_generator", [os.path.join(model_path, "prompt_generator")], folder_paths.supported_pt_extensions)
33
+ add_folder_path_and_extensions("t5", [os.path.join(model_path, "t5")], folder_paths.supported_pt_extensions)
34
+ add_folder_path_and_extensions("llm", [os.path.join(model_path, "LLM")], folder_paths.supported_pt_extensions)
35
+
36
+ add_folder_path_and_extensions("checkpoints_thumb", [os.path.join(model_path, "checkpoints")], image_suffixs)
37
+ add_folder_path_and_extensions("loras_thumb", [os.path.join(model_path, "loras")], image_suffixs)
ComfyUI-Easy-Use/py/__init__.py ADDED
File without changes
ComfyUI-Easy-Use/py/api.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ import sys
4
+ import json
5
+ import shutil
6
+ import folder_paths
7
+ from folder_paths import get_directory_by_type
8
+ from server import PromptServer
9
+ from .config import RESOURCES_DIR, FOOOCUS_STYLES_DIR, FOOOCUS_STYLES_SAMPLES
10
+ from .libs.model import easyModelManager
11
+ from .libs.utils import getMetadata, cleanGPUUsedForce, get_local_filepath
12
+ from .libs.cache import remove_cache
13
+ from .libs.translate import has_chinese, zh_to_en
14
+
15
+ try:
16
+ import aiohttp
17
+ from aiohttp import web
18
+ except ImportError:
19
+ print("Module 'aiohttp' not installed. Please install it via:")
20
+ print("pip install aiohttp")
21
+ sys.exit()
22
+
23
+ @PromptServer.instance.routes.post("/easyuse/cleangpu")
24
+ def cleanGPU(request):
25
+ try:
26
+ cleanGPUUsedForce()
27
+ remove_cache('*')
28
+ return web.Response(status=200)
29
+ except Exception as e:
30
+ return web.Response(status=500)
31
+ pass
32
+
33
+ @PromptServer.instance.routes.post("/easyuse/translate")
34
+ async def translate(request):
35
+ post = await request.post()
36
+ text = post.get("text")
37
+ if has_chinese(text):
38
+ return web.json_response({"text": zh_to_en([text])[0]})
39
+ else:
40
+ return web.json_response({"text": text})
41
+
42
+ @PromptServer.instance.routes.get("/easyuse/reboot")
43
+ def reboot(request):
44
+ try:
45
+ sys.stdout.close_log()
46
+ except Exception as e:
47
+ pass
48
+
49
+ return os.execv(sys.executable, [sys.executable] + sys.argv)
50
+
51
+ # parse csv
52
+ @PromptServer.instance.routes.post("/easyuse/upload/csv")
53
+ async def parse_csv(request):
54
+ post = await request.post()
55
+ csv = post.get("csv")
56
+ if csv and csv.file:
57
+ file = csv.file
58
+ text = ''
59
+ for line in file.readlines():
60
+ line = str(line.strip())
61
+ line = line.replace("'", "").replace("b",'')
62
+ text += line + '; \n'
63
+ return web.json_response(text)
64
+
65
+ #get style list
66
+ @PromptServer.instance.routes.get("/easyuse/prompt/styles")
67
+ async def getStylesList(request):
68
+ if "name" in request.rel_url.query:
69
+ name = request.rel_url.query["name"]
70
+ if name == 'fooocus_styles':
71
+ file = os.path.join(RESOURCES_DIR, name+'.json')
72
+ cn_file = os.path.join(RESOURCES_DIR, name + '_cn.json')
73
+ else:
74
+ file = os.path.join(FOOOCUS_STYLES_DIR, name+'.json')
75
+ cn_file = os.path.join(FOOOCUS_STYLES_DIR, name + '_cn.json')
76
+ cn_data = None
77
+ if os.path.isfile(cn_file):
78
+ f = open(cn_file, 'r', encoding='utf-8')
79
+ cn_data = json.load(f)
80
+ f.close()
81
+ if os.path.isfile(file):
82
+ f = open(file, 'r', encoding='utf-8')
83
+ data = json.load(f)
84
+ f.close()
85
+ if data:
86
+ ndata = []
87
+ for d in data:
88
+ nd = {}
89
+ name = d['name'].replace('-', ' ')
90
+ words = name.split(' ')
91
+ key = ' '.join(
92
+ word.upper() if word.lower() in ['mre', 'sai', '3d'] else word.capitalize() for word in
93
+ words)
94
+ img_name = '_'.join(words).lower()
95
+ if "name_cn" in d:
96
+ nd['name_cn'] = d['name_cn']
97
+ elif cn_data:
98
+ nd['name_cn'] = cn_data[key] if key in cn_data else key
99
+ nd["name"] = d['name']
100
+ nd['imgName'] = img_name
101
+ if "prompt" in d:
102
+ nd['prompt'] = d['prompt']
103
+ if "negative_prompt" in d:
104
+ nd['negative_prompt'] = d['negative_prompt']
105
+ ndata.append(nd)
106
+ return web.json_response(ndata)
107
+ return web.Response(status=400)
108
+
109
+ # get style preview image
110
+ @PromptServer.instance.routes.get("/easyuse/prompt/styles/image")
111
+ async def getStylesImage(request):
112
+ styles_name = request.rel_url.query["styles_name"] if "styles_name" in request.rel_url.query else None
113
+ if "name" in request.rel_url.query:
114
+ name = request.rel_url.query["name"]
115
+ if os.path.exists(os.path.join(FOOOCUS_STYLES_DIR, 'samples')):
116
+ file = os.path.join(FOOOCUS_STYLES_DIR, 'samples', name + '.jpg')
117
+ if os.path.isfile(file):
118
+ return web.FileResponse(file)
119
+ elif styles_name == 'fooocus_styles':
120
+ return web.Response(text=FOOOCUS_STYLES_SAMPLES + name + '.jpg')
121
+ elif styles_name == 'fooocus_styles':
122
+ return web.Response(text=FOOOCUS_STYLES_SAMPLES + name + '.jpg')
123
+ return web.Response(status=400)
124
+
125
+ # get models lists
126
+ @PromptServer.instance.routes.get("/easyuse/models/list")
127
+ async def getModelsList(request):
128
+ if "type" in request.rel_url.query:
129
+ type = request.rel_url.query["type"]
130
+ if type not in ['checkpoints', 'loras']:
131
+ return web.Response(status=400)
132
+ manager = easyModelManager()
133
+ return web.json_response(manager.get_model_lists(type))
134
+ else:
135
+ return web.Response(status=400)
136
+
137
+ # get models thumbnails
138
+ @PromptServer.instance.routes.get("/easyuse/models/thumbnail")
139
+ async def getModelsThumbnail(request):
140
+ limit = 500
141
+ if "limit" in request.rel_url.query:
142
+ limit = request.rel_url.query.get("limit")
143
+ limit = int(limit)
144
+ checkpoints = folder_paths.get_filename_list("checkpoints_thumb")
145
+ loras = folder_paths.get_filename_list("loras_thumb")
146
+ checkpoints_full = []
147
+ loras_full = []
148
+ if len(checkpoints) + len(loras) >= limit:
149
+ return web.Response(status=400)
150
+ for index, i in enumerate(checkpoints):
151
+ full_path = folder_paths.get_full_path('checkpoints_thumb', str(i))
152
+ if full_path:
153
+ checkpoints_full.append(full_path)
154
+ for index, i in enumerate(loras):
155
+ full_path = folder_paths.get_full_path('loras_thumb', str(i))
156
+ if full_path:
157
+ loras_full.append(full_path)
158
+ return web.json_response(checkpoints_full + loras_full)
159
+
160
+ @PromptServer.instance.routes.post("/easyuse/metadata/notes/{name}")
161
+ async def save_notes(request):
162
+ name = request.match_info["name"]
163
+ pos = name.index("/")
164
+ type = name[0:pos]
165
+ name = name[pos+1:]
166
+
167
+ file_path = None
168
+ if type == "embeddings" or type == "loras":
169
+ name = name.lower()
170
+ files = folder_paths.get_filename_list(type)
171
+ for f in files:
172
+ lower_f = f.lower()
173
+ if lower_f == name:
174
+ file_path = folder_paths.get_full_path(type, f)
175
+ else:
176
+ n = os.path.splitext(f)[0].lower()
177
+ if n == name:
178
+ file_path = folder_paths.get_full_path(type, f)
179
+
180
+ if file_path is not None:
181
+ break
182
+ else:
183
+ file_path = folder_paths.get_full_path(
184
+ type, name)
185
+ if not file_path:
186
+ return web.Response(status=404)
187
+
188
+ file_no_ext = os.path.splitext(file_path)[0]
189
+ info_file = file_no_ext + ".txt"
190
+ with open(info_file, "w") as f:
191
+ f.write(await request.text())
192
+
193
+ return web.Response(status=200)
194
+
195
+ @PromptServer.instance.routes.get("/easyuse/metadata/{name}")
196
+ async def load_metadata(request):
197
+ name = request.match_info["name"]
198
+ pos = name.index("/")
199
+ type = name[0:pos]
200
+ name = name[pos+1:]
201
+
202
+ file_path = None
203
+ if type == "embeddings":
204
+ name = name.lower()
205
+ files = folder_paths.get_filename_list(type)
206
+ for f in files:
207
+ lower_f = f.lower()
208
+ if lower_f == name:
209
+ file_path = folder_paths.get_full_path(type, f)
210
+ else:
211
+ n = os.path.splitext(f)[0].lower()
212
+ if n == name:
213
+ file_path = folder_paths.get_full_path(type, f)
214
+
215
+ if file_path is not None:
216
+ break
217
+ else:
218
+ file_path = folder_paths.get_full_path(type, name)
219
+ if not file_path:
220
+ return web.Response(status=404)
221
+
222
+ try:
223
+ header = getMetadata(file_path)
224
+ header_json = json.loads(header)
225
+ meta = header_json["__metadata__"] if "__metadata__" in header_json else None
226
+ except:
227
+ meta = None
228
+
229
+ if meta is None:
230
+ meta = {}
231
+
232
+ file_no_ext = os.path.splitext(file_path)[0]
233
+
234
+ info_file = file_no_ext + ".txt"
235
+ if os.path.isfile(info_file):
236
+ with open(info_file, "r") as f:
237
+ meta["easyuse.notes"] = f.read()
238
+
239
+ hash_file = file_no_ext + ".sha256"
240
+ if os.path.isfile(hash_file):
241
+ with open(hash_file, "rt") as f:
242
+ meta["easyuse.sha256"] = f.read()
243
+ else:
244
+ with open(file_path, "rb") as f:
245
+ meta["easyuse.sha256"] = hashlib.sha256(f.read()).hexdigest()
246
+ with open(hash_file, "wt") as f:
247
+ f.write(meta["easyuse.sha256"])
248
+
249
+ return web.json_response(meta)
250
+
251
+ @PromptServer.instance.routes.post("/easyuse/save/{name}")
252
+ async def save_preview(request):
253
+ name = request.match_info["name"]
254
+ pos = name.index("/")
255
+ type = name[0:pos]
256
+ name = name[pos+1:]
257
+
258
+ body = await request.json()
259
+
260
+ dir = get_directory_by_type(body.get("type", "output"))
261
+ subfolder = body.get("subfolder", "")
262
+ full_output_folder = os.path.join(dir, os.path.normpath(subfolder))
263
+
264
+ if os.path.commonpath((dir, os.path.abspath(full_output_folder))) != dir:
265
+ return web.Response(status=400)
266
+
267
+ filepath = os.path.join(full_output_folder, body.get("filename", ""))
268
+ image_path = folder_paths.get_full_path(type, name)
269
+ image_path = os.path.splitext(
270
+ image_path)[0] + os.path.splitext(filepath)[1]
271
+
272
+ shutil.copyfile(filepath, image_path)
273
+
274
+ return web.json_response({
275
+ "image": type + "/" + os.path.basename(image_path)
276
+ })
277
+
278
+ @PromptServer.instance.routes.post("/easyuse/model/download")
279
+ async def download_model(request):
280
+ post = await request.post()
281
+ url = post.get("url")
282
+ local_dir = post.get("local_dir")
283
+ if local_dir not in ['checkpoints', 'loras', 'controlnet', 'onnx', 'instantid', 'ipadapter', 'dynamicrafter_models', 'mediapipe', 'rembg', 'layer_model']:
284
+ return web.Response(status=400)
285
+ local_path = os.path.join(folder_paths.models_dir, local_dir)
286
+ try:
287
+ get_local_filepath(url, local_path)
288
+ return web.Response(status=200)
289
+ except:
290
+ return web.Response(status=500)
291
+
292
+ NODE_CLASS_MAPPINGS = {}
293
+ NODE_DISPLAY_NAME_MAPPINGS = {}
ComfyUI-Easy-Use/py/bitsandbytes_NF4/__init__.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #credit to comfyanonymous for this module
2
+ #from https://github.com/comfyanonymous/ComfyUI_bitsandbytes_NF4
3
+ import comfy.ops
4
+ import torch
5
+ import folder_paths
6
+ from ..libs.utils import install_package
7
+
8
+ try:
9
+ from bitsandbytes.nn.modules import Params4bit, QuantState
10
+ except ImportError:
11
+ Params4bit = torch.nn.Parameter
12
+ raise ImportError("Please install bitsandbytes>=0.43.3")
13
+
14
+ def functional_linear_4bits(x, weight, bias):
15
+ try:
16
+ install_package("bitsandbytes", "0.43.3", True, "0.43.3")
17
+ import bitsandbytes as bnb
18
+ except ImportError:
19
+ raise ImportError("Please install bitsandbytes>=0.43.3")
20
+
21
+ out = bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=weight.quant_state)
22
+ out = out.to(x)
23
+ return out
24
+
25
+
26
+ def copy_quant_state(state, device: torch.device = None):
27
+ if state is None:
28
+ return None
29
+
30
+ device = device or state.absmax.device
31
+
32
+ state2 = (
33
+ QuantState(
34
+ absmax=state.state2.absmax.to(device),
35
+ shape=state.state2.shape,
36
+ code=state.state2.code.to(device),
37
+ blocksize=state.state2.blocksize,
38
+ quant_type=state.state2.quant_type,
39
+ dtype=state.state2.dtype,
40
+ )
41
+ if state.nested
42
+ else None
43
+ )
44
+
45
+ return QuantState(
46
+ absmax=state.absmax.to(device),
47
+ shape=state.shape,
48
+ code=state.code.to(device),
49
+ blocksize=state.blocksize,
50
+ quant_type=state.quant_type,
51
+ dtype=state.dtype,
52
+ offset=state.offset.to(device) if state.nested else None,
53
+ state2=state2,
54
+ )
55
+
56
+
57
+ class ForgeParams4bit(Params4bit):
58
+
59
+ def to(self, *args, **kwargs):
60
+ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
61
+ if device is not None and device.type == "cuda" and not self.bnb_quantized:
62
+ return self._quantize(device)
63
+ else:
64
+ n = ForgeParams4bit(
65
+ torch.nn.Parameter.to(self, device=device, dtype=dtype, non_blocking=non_blocking),
66
+ requires_grad=self.requires_grad,
67
+ quant_state=copy_quant_state(self.quant_state, device),
68
+ blocksize=self.blocksize,
69
+ compress_statistics=self.compress_statistics,
70
+ quant_type=self.quant_type,
71
+ quant_storage=self.quant_storage,
72
+ bnb_quantized=self.bnb_quantized,
73
+ module=self.module
74
+ )
75
+ self.module.quant_state = n.quant_state
76
+ self.data = n.data
77
+ self.quant_state = n.quant_state
78
+ return n
79
+
80
+ class ForgeLoader4Bit(torch.nn.Module):
81
+ def __init__(self, *, device, dtype, quant_type, **kwargs):
82
+ super().__init__()
83
+ self.dummy = torch.nn.Parameter(torch.empty(1, device=device, dtype=dtype))
84
+ self.weight = None
85
+ self.quant_state = None
86
+ self.bias = None
87
+ self.quant_type = quant_type
88
+
89
+ def _save_to_state_dict(self, destination, prefix, keep_vars):
90
+ super()._save_to_state_dict(destination, prefix, keep_vars)
91
+ quant_state = getattr(self.weight, "quant_state", None)
92
+ if quant_state is not None:
93
+ for k, v in quant_state.as_dict(packed=True).items():
94
+ destination[prefix + "weight." + k] = v if keep_vars else v.detach()
95
+ return
96
+
97
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
98
+ quant_state_keys = {k[len(prefix + "weight."):] for k in state_dict.keys() if k.startswith(prefix + "weight.")}
99
+
100
+ if any('bitsandbytes' in k for k in quant_state_keys):
101
+ quant_state_dict = {k: state_dict[prefix + "weight." + k] for k in quant_state_keys}
102
+
103
+ self.weight = ForgeParams4bit().from_prequantized(
104
+ data=state_dict[prefix + 'weight'],
105
+ quantized_stats=quant_state_dict,
106
+ requires_grad=False,
107
+ device=self.dummy.device,
108
+ module=self
109
+ )
110
+ self.quant_state = self.weight.quant_state
111
+
112
+ if prefix + 'bias' in state_dict:
113
+ self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
114
+
115
+ del self.dummy
116
+ elif hasattr(self, 'dummy'):
117
+ if prefix + 'weight' in state_dict:
118
+ self.weight = ForgeParams4bit(
119
+ state_dict[prefix + 'weight'].to(self.dummy),
120
+ requires_grad=False,
121
+ compress_statistics=True,
122
+ quant_type=self.quant_type,
123
+ quant_storage=torch.uint8,
124
+ module=self,
125
+ )
126
+ self.quant_state = self.weight.quant_state
127
+
128
+ if prefix + 'bias' in state_dict:
129
+ self.bias = torch.nn.Parameter(state_dict[prefix + 'bias'].to(self.dummy))
130
+
131
+ del self.dummy
132
+ else:
133
+ super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
134
+
135
+ current_device = None
136
+ current_dtype = None
137
+ current_manual_cast_enabled = False
138
+ current_bnb_dtype = None
139
+
140
+ class OPS(comfy.ops.manual_cast):
141
+ class Linear(ForgeLoader4Bit):
142
+ def __init__(self, *args, device=None, dtype=None, **kwargs):
143
+ super().__init__(device=device, dtype=dtype, quant_type=current_bnb_dtype)
144
+ self.parameters_manual_cast = current_manual_cast_enabled
145
+
146
+ def forward(self, x):
147
+ self.weight.quant_state = self.quant_state
148
+
149
+ if self.bias is not None and self.bias.dtype != x.dtype:
150
+ # Maybe this can also be set to all non-bnb ops since the cost is very low.
151
+ # And it only invokes one time, and most linear does not have bias
152
+ self.bias.data = self.bias.data.to(x.dtype)
153
+
154
+ if not self.parameters_manual_cast:
155
+ return functional_linear_4bits(x, self.weight, self.bias)
156
+ elif not self.weight.bnb_quantized:
157
+ assert x.device.type == 'cuda', 'BNB Must Use CUDA as Computation Device!'
158
+ layer_original_device = self.weight.device
159
+ self.weight = self.weight._quantize(x.device)
160
+ bias = self.bias.to(x.device) if self.bias is not None else None
161
+ out = functional_linear_4bits(x, self.weight, bias)
162
+ self.weight = self.weight.to(layer_original_device)
163
+ return out
164
+ else:
165
+ weight, bias, signal = weights_manual_cast(self, x, skip_weight_dtype=True, skip_bias_dtype=True)
166
+ with main_stream_worker(weight, bias, signal):
167
+ return functional_linear_4bits(x, weight, bias)