HaoranChu committed
Commit e9ad0c9 · verified · 1 Parent(s): f7e973d

Upload 47 files

Files changed (47)
  1. GPTQ-for-Qwen_hf/.gitignore +1 -0
  2. GPTQ-for-Qwen_hf/.style.yapf +3 -0
  3. GPTQ-for-Qwen_hf/LICENSE.txt +201 -0
  4. GPTQ-for-Qwen_hf/README.md +188 -0
  5. GPTQ-for-Qwen_hf/__pycache__/categories.cpython-310.pyc +0 -0
  6. GPTQ-for-Qwen_hf/__pycache__/datautils.cpython-310.pyc +0 -0
  7. GPTQ-for-Qwen_hf/__pycache__/evaluate_.cpython-310.pyc +0 -0
  8. GPTQ-for-Qwen_hf/__pycache__/gptq.cpython-310.pyc +0 -0
  9. GPTQ-for-Qwen_hf/__pycache__/gptq.cpython-38.pyc +0 -0
  10. GPTQ-for-Qwen_hf/pip.log +1 -0
  11. GPTQ-for-Qwen_hf/quant/__init__.py +5 -0
  12. GPTQ-for-Qwen_hf/quant/__pycache__/__init__.cpython-310.pyc +0 -0
  13. GPTQ-for-Qwen_hf/quant/__pycache__/__init__.cpython-38.pyc +0 -0
  14. GPTQ-for-Qwen_hf/quant/__pycache__/custom_autotune.cpython-310.pyc +0 -0
  15. GPTQ-for-Qwen_hf/quant/__pycache__/custom_autotune.cpython-38.pyc +0 -0
  16. GPTQ-for-Qwen_hf/quant/__pycache__/fused_attn.cpython-310.pyc +0 -0
  17. GPTQ-for-Qwen_hf/quant/__pycache__/fused_attn.cpython-38.pyc +0 -0
  18. GPTQ-for-Qwen_hf/quant/__pycache__/fused_mlp.cpython-310.pyc +0 -0
  19. GPTQ-for-Qwen_hf/quant/__pycache__/fused_mlp.cpython-38.pyc +0 -0
  20. GPTQ-for-Qwen_hf/quant/__pycache__/quant_linear.cpython-310.pyc +0 -0
  21. GPTQ-for-Qwen_hf/quant/__pycache__/quant_linear.cpython-38.pyc +0 -0
  22. GPTQ-for-Qwen_hf/quant/__pycache__/quantizer.cpython-310.pyc +0 -0
  23. GPTQ-for-Qwen_hf/quant/__pycache__/quantizer.cpython-38.pyc +0 -0
  24. GPTQ-for-Qwen_hf/quant/__pycache__/triton_norm.cpython-310.pyc +0 -0
  25. GPTQ-for-Qwen_hf/quant/__pycache__/triton_norm.cpython-38.pyc +0 -0
  26. GPTQ-for-Qwen_hf/quant/custom_autotune.py +194 -0
  27. GPTQ-for-Qwen_hf/quant/fused_attn.py +204 -0
  28. GPTQ-for-Qwen_hf/quant/fused_mlp.py +288 -0
  29. GPTQ-for-Qwen_hf/quant/quant_linear.py +423 -0
  30. GPTQ-for-Qwen_hf/quant/quantizer.py +127 -0
  31. GPTQ-for-Qwen_hf/quant/triton_norm.py +92 -0
  32. GPTQ-for-Qwen_hf/qwen.py +292 -0
  33. GPTQ-for-Qwen_hf/qwen_gptq_0.6B_loadtest.log +37 -0
  34. GPTQ-for-Qwen_hf/requirements.txt +11 -0
  35. GPTQ-for-Qwen_hf/test.sh +1 -0
  36. GPTQ-for-Qwen_hf/utils/__init__.py +3 -0
  37. GPTQ-for-Qwen_hf/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  38. GPTQ-for-Qwen_hf/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  39. GPTQ-for-Qwen_hf/utils/__pycache__/datautils.cpython-310.pyc +0 -0
  40. GPTQ-for-Qwen_hf/utils/__pycache__/datautils.cpython-38.pyc +0 -0
  41. GPTQ-for-Qwen_hf/utils/__pycache__/export.cpython-310.pyc +0 -0
  42. GPTQ-for-Qwen_hf/utils/__pycache__/export.cpython-38.pyc +0 -0
  43. GPTQ-for-Qwen_hf/utils/__pycache__/modelutils.cpython-310.pyc +0 -0
  44. GPTQ-for-Qwen_hf/utils/__pycache__/modelutils.cpython-38.pyc +0 -0
  45. GPTQ-for-Qwen_hf/utils/datautils.py +234 -0
  46. GPTQ-for-Qwen_hf/utils/export.py +37 -0
  47. GPTQ-for-Qwen_hf/utils/modelutils.py +83 -0
GPTQ-for-Qwen_hf/.gitignore ADDED
@@ -0,0 +1 @@
1
+ __pycache__/
GPTQ-for-Qwen_hf/.style.yapf ADDED
@@ -0,0 +1,3 @@
1
+ [style]
2
+ based_on_style = pep8
3
+ column_limit = 200
GPTQ-for-Qwen_hf/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
GPTQ-for-Qwen_hf/README.md ADDED
@@ -0,0 +1,188 @@
1
+ # GPTQ-for-LLaMA
2
+
3
+ **I am currently focusing on [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) and recommend using [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ) instead of GPTQ for Llama.**
4
+
5
+ <img src = https://user-images.githubusercontent.com/64115820/235287009-2d07bba8-9b85-4973-9e06-2a3c28777f06.png width="50%" height="50%">
6
+
7
+ 4-bit quantization of [LLaMA](https://arxiv.org/abs/2302.13971) using [GPTQ](https://arxiv.org/abs/2210.17323)
8
+
9
+ GPTQ is a SOTA one-shot weight quantization method.
10
+
11
+ **It can be used universally, but it is not the [fastest](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/old-cuda) option, and it only supports Linux.**
12
+
13
+ **Triton only supports Linux, so if you are a Windows user, please use [WSL2](https://learn.microsoft.com/en-us/windows/wsl/install).**
14
+
15
+ ## News or Update
16
+ **AutoGPTQ-triton, a packaged version of GPTQ with triton, has been integrated into [AutoGPTQ](https://github.com/PanQiWei/AutoGPTQ).**
17
+ ## Result
18
+ <details>
19
+ <summary>LLaMA-7B(click me)</summary>
20
+
21
+ | [LLaMA-7B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
22
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
23
+ | FP16 | 16 | - | 13940 | 5.68 | 12.5 |
24
+ | RTN | 4 | - | - | 6.29 | - |
25
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 4740 | 6.09 | 3.5 |
26
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 4891 | 5.85 | 3.6 |
27
+ | RTN | 3 | - | - | 25.54 | - |
28
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 3852 | 8.07 | 2.7 |
29
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 4116 | 6.61 | 3.0 |
30
+
31
+ </details>
32
+
33
+ <details>
34
+ <summary>LLaMA-13B</summary>
35
+
36
+ | [LLaMA-13B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
37
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
38
+ | FP16 | 16 | - | OOM | 5.09 | 24.2 |
39
+ | RTN | 4 | - | - | 5.53 | - |
40
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 8410 | 5.36 | 6.5 |
41
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 8747 | 5.20 | 6.7 |
42
+ | RTN | 3 | - | - | 11.40 | - |
43
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 6870 | 6.63 | 5.1 |
44
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 7277 | 5.62 | 5.4 |
45
+
46
+ </details>
47
+
48
+ <details>
49
+ <summary>LLaMA-33B</summary>
50
+
51
+ | [LLaMA-33B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
52
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
53
+ | FP16 | 16 | - | OOM | 4.10 | 60.5 |
54
+ | RTN | 4 | - | - | 4.54 | - |
55
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | 19493 | 4.45 | 15.7 |
56
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | 20570 | 4.23 | 16.3 |
57
+ | RTN | 3 | - | - | 14.89 | - |
58
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | 15493 | 5.69 | 12.0 |
59
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | 16566 | 4.80 | 13.0 |
60
+
61
+ </details>
62
+
63
+ <details>
64
+ <summary>LLaMA-65B</summary>
65
+
66
+ | [LLaMA-65B](https://arxiv.org/abs/2302.13971) | Bits | group-size | memory(MiB) | Wikitext2 | checkpoint size(GB) |
67
+ | -------------------------------------------------- | ---- | ---------- | ----------- | --------- | ------------------- |
68
+ | FP16 | 16 | - | OOM | 3.53 | 121.0 |
69
+ | RTN | 4 | - | - | 3.92 | - |
70
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | - | OOM | 3.84 | 31.1 |
71
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 4 | 128 | OOM | 3.65 | 32.3 |
72
+ | RTN | 3 | - | - | 10.59 | - |
73
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | - | OOM | 5.04 | 23.6 |
74
+ | [GPTQ](https://arxiv.org/abs/2210.17323) | 3 | 128 | OOM | 4.17 | 25.6 |
75
+ </details>
76
+
77
+ Quantization requires a large amount of CPU memory. However, the memory required can be reduced by using swap memory.
78
+
79
+ Depending on the GPUs/drivers, there may be a difference in performance, and the difference decreases as the model size increases (see https://github.com/IST-DASLab/gptq/issues/1).
80
+
81
+ According to the [GPTQ paper](https://arxiv.org/abs/2210.17323), the difference in performance between FP16 and GPTQ decreases as the model size increases.
82
+
83
+ ## GPTQ vs [bitsandbytes](https://github.com/TimDettmers/bitsandbytes)
84
+
85
+ <details>
86
+ <summary>LLaMA-7B(click me)</summary>
87
+
88
+ | [LLaMA-7B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
89
+ | --------------------------------------------------------------- | ------------------- | ----------- | --------- |
90
+ | FP16 | 16 | 13948 | 5.22 |
91
+ | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 4781 | 5.30 |
92
+ | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 4804 | 5.30 |
93
+ | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 5102 | 5.30 |
94
+ | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 5102 | 5.33 |
95
+
96
+ </details>
97
+
98
+ <details>
99
+ <summary>LLaMA-13B</summary>
100
+
101
+ | [LLaMA-13B(seqlen=2048)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
102
+ | ---------------------------------------------------------------- | ------------------- | ----------- | --------- |
103
+ | FP16 | 16 | OOM | - |
104
+ | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 8589 | 5.02 |
105
+ | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 8581 | 5.04 |
106
+ | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 9170 | 5.04 |
107
+ | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 9170 | 5.11 |
108
+ </details>
109
+
110
+ <details>
111
+ <summary>LLaMA-33B</summary>
112
+
113
+ | [LLaMA-33B(seqlen=1024)](https://arxiv.org/abs/2302.13971) | Bits Per Weight(BPW)| memory(MiB) | c4(ppl) |
114
+ | ---------------------------------------------------------------- | ------------------- | ----------- | --------- |
115
+ | FP16 | 16 | OOM | - |
116
+ | [GPTQ-128g](https://arxiv.org/abs/2210.17323) | 4.15 | 18441 | 3.71 |
117
+ | [nf4-double_quant](https://arxiv.org/abs/2305.14314) | 4.127 | 18313 | 3.76 |
118
+ | [nf4](https://arxiv.org/abs/2305.14314) | 4.5 | 19729 | 3.75 |
119
+ | [fp4](https://arxiv.org/abs/2212.09720) | 4.5 | 19729 | 3.75 |
120
+
121
+ </details>
122
+
123
+ ## Installation
124
+ If you don't have [conda](https://docs.conda.io/en/latest/miniconda.html), install it first.
125
+ ```
126
+ conda create --name gptq python=3.9 -y
127
+ conda activate gptq
128
+ conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia
129
+ # Or, if you're having trouble with conda, use pip with python3.9:
130
+ # pip3 install torch torchvision torchaudio
131
+
132
+ git clone https://github.com/qwopqwop200/GPTQ-for-LLaMa
133
+ cd GPTQ-for-LLaMa
134
+ pip install -r requirements.txt
135
+ ```
136
+ ## Dependencies
137
+
138
+ * `torch`: tested on v2.0.0+cu117
139
+ * `transformers`: tested on v4.28.0.dev0
140
+ * `datasets`: tested on v2.10.1
141
+ * `safetensors`: tested on v0.3.0
142
+
143
+ All experiments were run on a single NVIDIA RTX3090.
144
+
145
+ # Language Generation
146
+ ## LLaMA
147
+
148
+ ```
149
+ #convert LLaMA to hf
150
+ python convert_llama_weights_to_hf.py --input_dir /path/to/downloaded/llama/weights --model_size 7B --output_dir ./llama-hf
151
+
152
+ # Benchmark language generation with 4-bit LLaMA-7B:
153
+
154
+ # Save compressed model
155
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save llama7b-4bit-128g.pt
156
+
157
+ # Or save compressed `.safetensors` model
158
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --save_safetensors llama7b-4bit-128g.safetensors
159
+
160
+ # Benchmark generating a 2048 token sequence with the saved model
161
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --benchmark 2048 --check
162
+
163
+ # Benchmark FP16 baseline, note that the model will be split across all listed GPUs
164
+ CUDA_VISIBLE_DEVICES=0,1,2,3,4 python llama.py ${MODEL_DIR} c4 --benchmark 2048 --check
165
+
166
+ # model inference with the saved model
167
+ CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama"
168
+
169
+ # model inference with the saved model using safetensors loaded direct to gpu
170
+ CUDA_VISIBLE_DEVICES=0 python llama_inference.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.safetensors --text "this is llama" --device=0
171
+
172
+ # model inference with the saved model with offload(This is very slow).
173
+ CUDA_VISIBLE_DEVICES=0 python llama_inference_offload.py ${MODEL_DIR} --wbits 4 --groupsize 128 --load llama7b-4bit-128g.pt --text "this is llama" --pre_layer 16
174
+ # It takes about 180 seconds to generate 45 tokens (5 -> 50 tokens) with LLaMA-65B on a single RTX 3090 when pre_layer is set to 50.
175
+ ```
176
+ Basically, 4-bit quantization and 128 groupsize are recommended.
177
+
178
+ You can also export quantization parameters with toml+numpy format.
179
+ ```
180
+ CUDA_VISIBLE_DEVICES=0 python llama.py ${MODEL_DIR} c4 --wbits 4 --true-sequential --act-order --groupsize 128 --quant-directory ${TOML_DIR}
181
+ ```
182
+
183
+ # Acknowledgements
184
+ This code is based on [GPTQ](https://github.com/IST-DASLab/gptq)
185
+
186
+ Thanks to Meta AI for releasing [LLaMA](https://arxiv.org/abs/2302.13971), a powerful LLM.
187
+
188
+ Triton GPTQ kernel code is based on [GPTQ-triton](https://github.com/fpgaminer/GPTQ-triton)
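
The README above describes GPTQ as a one-shot weight quantization method. Concretely, the GPTQ paper it cites frames quantization as a per-layer reconstruction problem; a rough sketch of the objective (notation is illustrative, not taken from this upload):

```latex
% Layer-wise objective solved by GPTQ (sketch): find quantized weights W_q that
% best reproduce the layer's outputs on calibration activations X.
\hat{W} \;=\; \operatorname*{arg\,min}_{W_q}\; \lVert W X - W_q X \rVert_2^2,
\qquad W_q \ \text{constrained to a $b$-bit grid with per-group scale and zero-point}
```

The `--groupsize 128` flag in the commands above corresponds to sharing a scale and zero-point across groups of 128 input channels, which is why the effective bits per weight in the tables sits slightly above 4 (e.g. 4.15 BPW for GPTQ-128g).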
GPTQ-for-Qwen_hf/__pycache__/categories.cpython-310.pyc ADDED
Binary file (2.32 kB).
GPTQ-for-Qwen_hf/__pycache__/datautils.cpython-310.pyc ADDED
Binary file (15.7 kB).
GPTQ-for-Qwen_hf/__pycache__/evaluate_.cpython-310.pyc ADDED
Binary file (4.83 kB).
GPTQ-for-Qwen_hf/__pycache__/gptq.cpython-310.pyc ADDED
Binary file (6.31 kB).
GPTQ-for-Qwen_hf/__pycache__/gptq.cpython-38.pyc ADDED
Binary file (6.3 kB).
GPTQ-for-Qwen_hf/pip.log ADDED
@@ -0,0 +1 @@
1
+ pip install lm_eval==0.4.7
GPTQ-for-Qwen_hf/quant/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .quantizer import Quantizer
2
+ from .fused_attn import QuantLlamaAttention, make_quant_attn
3
+ from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused
4
+ from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear
5
+ from .triton_norm import TritonLlamaRMSNorm, make_quant_norm
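
These five lines are the package's entire public surface. A rough sketch of how they are typically wired together when loading a GPTQ checkpoint follows; it is an illustration only (the actual load path lives in qwen.py in this upload), and the helper name `load_quantized`, the layer-collection line, and the exact `make_quant_linear` signature are assumptions:

```python
import torch
from quant import (make_quant_linear, make_quant_attn, make_quant_norm,
                   make_fused_mlp, autotune_warmup_linear, autotune_warmup_fused)

def load_quantized(model, checkpoint_path, bits=4, groupsize=128, warmup_autotune=True):
    # Hypothetical sketch: swap nn.Linear layers for QuantLinear shells, then load
    # the packed GPTQ weights into them.
    layers = {name: m for name, m in model.named_modules() if isinstance(m, torch.nn.Linear)}
    make_quant_linear(model, layers, bits, groupsize)   # signature assumed
    model.load_state_dict(torch.load(checkpoint_path), strict=False)

    # Optional fused/Triton paths exported above:
    make_quant_attn(model)          # fuse q/k/v into one QuantLinear (fused_attn.py)
    make_quant_norm(model)          # swap RMSNorm for the Triton version (triton_norm.py)
    model = make_fused_mlp(model)   # fuse gate/up projections + SiLU (fused_mlp.py)

    if warmup_autotune:
        autotune_warmup_linear(model)   # pre-populate the Triton autotune caches
        autotune_warmup_fused(model)
    return model
```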
GPTQ-for-Qwen_hf/quant/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (575 Bytes).
GPTQ-for-Qwen_hf/quant/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (578 Bytes).
GPTQ-for-Qwen_hf/quant/__pycache__/custom_autotune.cpython-310.pyc ADDED
Binary file (7.94 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/custom_autotune.cpython-38.pyc ADDED
Binary file (7.95 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/fused_attn.cpython-310.pyc ADDED
Binary file (5.15 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/fused_attn.cpython-38.pyc ADDED
Binary file (5.11 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/fused_mlp.cpython-310.pyc ADDED
Binary file (7.2 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/fused_mlp.cpython-38.pyc ADDED
Binary file (7.29 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/quant_linear.cpython-310.pyc ADDED
Binary file (11 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/quant_linear.cpython-38.pyc ADDED
Binary file (10.9 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/quantizer.cpython-310.pyc ADDED
Binary file (3.48 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/quantizer.cpython-38.pyc ADDED
Binary file (3.47 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/triton_norm.cpython-310.pyc ADDED
Binary file (2.67 kB).
GPTQ-for-Qwen_hf/quant/__pycache__/triton_norm.cpython-38.pyc ADDED
Binary file (2.64 kB).
GPTQ-for-Qwen_hf/quant/custom_autotune.py ADDED
@@ -0,0 +1,194 @@
1
+ #https://github.com/fpgaminer/GPTQ-triton
2
+ """
3
+ Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
4
+ """
5
+
6
+ import builtins
7
+ import math
8
+ import time
9
+ from typing import Dict
10
+
11
+ import triton
12
+
13
+
14
+ class Autotuner(triton.KernelInterface):
15
+
16
+ def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False):
17
+ '''
18
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
19
+ 'perf_model': performance model used to predict running time with different configs, returns running time
20
+ 'top_k': number of configs to bench
21
+ 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs: List[Config] as input and returns pruned configs.
22
+ 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results
23
+ '''
24
+ if not configs:
25
+ self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
26
+ else:
27
+ self.configs = configs
28
+ self.key_idx = [arg_names.index(k) for k in key]
29
+ self.nearest_power_of_two = nearest_power_of_two
30
+ self.cache = {}
31
+ # hook to reset all required tensor to zeros before relaunching a kernel
32
+ self.hook = lambda args: 0
33
+ if reset_to_zero is not None:
34
+ self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
35
+
36
+ def _hook(args):
37
+ for i in self.reset_idx:
38
+ args[i].zero_()
39
+
40
+ self.hook = _hook
41
+ self.arg_names = arg_names
42
+ # prune configs
43
+ if prune_configs_by:
44
+ perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k']
45
+ if 'early_config_prune' in prune_configs_by:
46
+ early_config_prune = prune_configs_by['early_config_prune']
47
+ else:
48
+ perf_model, top_k, early_config_prune = None, None, None
49
+ self.perf_model, self.configs_top_k = perf_model, top_k
50
+ self.early_config_prune = early_config_prune
51
+ self.fn = fn
52
+
53
+ def _bench(self, *args, config, **meta):
54
+ # check for conflicts, i.e. meta-parameters both provided
55
+ # as kwargs and by the autotuner
56
+ conflicts = meta.keys() & config.kwargs.keys()
57
+ if conflicts:
58
+ raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}."
59
+ " Make sure that you don't re-define auto-tuned symbols.")
60
+ # augment meta-parameters with tunable ones
61
+ current = dict(meta, **config.kwargs)
62
+
63
+ def kernel_call():
64
+ if config.pre_hook:
65
+ config.pre_hook(self.nargs)
66
+ self.hook(args)
67
+ self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
68
+
69
+ try:
70
+ # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
71
+ # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
72
+ # return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
73
+ return tuple(triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8),rep=40))
74
+ except triton.OutOfResources:
75
+ return (float('inf'), float('inf'), float('inf'))
76
+
77
+ def run(self, *args, **kwargs):
78
+ self.nargs = dict(zip(self.arg_names, args))
79
+ if len(self.configs) > 1:
80
+ key = tuple(args[i] for i in self.key_idx)
81
+
82
+ # This reduces the amount of autotuning by rounding the keys to the nearest power of two
83
+ # In my testing this gives decent results, and greatly reduces the amount of tuning required
84
+ if self.nearest_power_of_two:
85
+ key = tuple([2**int(math.log2(x) + 0.5) for x in key])
86
+
87
+ if key not in self.cache:
88
+ # prune configs
89
+ pruned_configs = self.prune_configs(kwargs)
90
+ bench_start = time.time()
91
+ timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
92
+ bench_end = time.time()
93
+ self.bench_time = bench_end - bench_start
94
+ self.cache[key] = builtins.min(timings, key=timings.get)
95
+ self.hook(args)
96
+ self.configs_timings = timings
97
+ config = self.cache[key]
98
+ else:
99
+ config = self.configs[0]
100
+ self.best_config = config
101
+ if config.pre_hook is not None:
102
+ config.pre_hook(self.nargs)
103
+ return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)
104
+
105
+ def prune_configs(self, kwargs):
106
+ pruned_configs = self.configs
107
+ if self.early_config_prune:
108
+ pruned_configs = self.early_config_prune(self.configs, self.nargs)
109
+ if self.perf_model:
110
+ top_k = self.configs_top_k
111
+ if isinstance(top_k, float) and top_k <= 1.0:
112
+ top_k = int(len(self.configs) * top_k)
113
+ if len(pruned_configs) > top_k:
114
+ est_timing = {config: self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs}
115
+ pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
116
+ return pruned_configs
117
+
118
+ def warmup(self, *args, **kwargs):
119
+ self.nargs = dict(zip(self.arg_names, args))
120
+ for config in self.prune_configs(kwargs):
121
+ self.fn.warmup(
122
+ *args,
123
+ num_warps=config.num_warps,
124
+ num_stages=config.num_stages,
125
+ **kwargs,
126
+ **config.kwargs,
127
+ )
128
+ self.nargs = None
129
+
130
+
131
+ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
132
+ """
133
+ Decorator for auto-tuning a :code:`triton.jit`'d function.
134
+ .. highlight:: python
135
+ .. code-block:: python
136
+ @triton.autotune(configs=[
137
+ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
138
+ triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
139
+ ],
140
+ key=['x_size'] # the two above configs will be evaluated anytime
141
+ # the value of x_size changes
142
+ )
143
+ @triton.jit
144
+ def kernel(x_ptr, x_size, **META):
145
+ BLOCK_SIZE = META['BLOCK_SIZE']
146
+ :note: When all the configurations are evaluated, the kernel will run multiple times.
147
+ This means that whatever value the kernel updates will be updated multiple times.
148
+ To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
149
+ resets the value of the provided tensor to `zero` before running any configuration.
150
+ :param configs: a list of :code:`triton.Config` objects
151
+ :type configs: list[triton.Config]
152
+ :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
153
+ :type key: list[str]
154
+ :param prune_configs_by: a dict of functions that are used to prune configs, fields:
155
+ 'perf_model': performance model used to predict running time with different configs, returns running time
156
+ 'top_k': number of configs to bench
157
+ 'early_config_prune'(optional): a function used for early pruning (e.g., by num_stages). It takes configs: List[Config] as input and returns pruned configs.
158
+ :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
159
+ :type reset_to_zero: list[str]
160
+ """
161
+
162
+ def decorator(fn):
163
+ return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two)
164
+
165
+ return decorator
166
+
167
+
168
+ def matmul248_kernel_config_pruner(configs, nargs):
169
+ """
170
+ The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
171
+ """
172
+ m = max(2**int(math.ceil(math.log2(nargs['M']))), 16)
173
+ n = max(2**int(math.ceil(math.log2(nargs['N']))), 16)
174
+ k = max(2**int(math.ceil(math.log2(nargs['K']))), 16)
175
+
176
+ used = set()
177
+ for config in configs:
178
+ block_size_m = min(m, config.kwargs['BLOCK_SIZE_M'])
179
+ block_size_n = min(n, config.kwargs['BLOCK_SIZE_N'])
180
+ block_size_k = min(k, config.kwargs['BLOCK_SIZE_K'])
181
+ group_size_m = config.kwargs['GROUP_SIZE_M']
182
+
183
+ if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
184
+ continue
185
+
186
+ used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
187
+ yield triton.Config({
188
+ 'BLOCK_SIZE_M': block_size_m,
189
+ 'BLOCK_SIZE_N': block_size_n,
190
+ 'BLOCK_SIZE_K': block_size_k,
191
+ 'GROUP_SIZE_M': group_size_m
192
+ },
193
+ num_stages=config.num_stages,
194
+ num_warps=config.num_warps)
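
To make the pruner's effect concrete, here is a small, hypothetical example (it assumes the repo's requirements are installed and that it is run from the GPTQ-for-Qwen_hf directory so that `quant.custom_autotune` is importable). With M=1, as in a single-token decode step, the M dimension is clamped to the pruner's floor of 16, both configs become identical, and the duplicate is dropped:

```python
import triton
from quant.custom_autotune import matmul248_kernel_config_pruner

configs = [
    triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                  num_stages=4, num_warps=4),
    triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8},
                  num_stages=4, num_warps=4),
]
pruned = list(matmul248_kernel_config_pruner(configs, {'M': 1, 'N': 4096, 'K': 4096}))
print(len(pruned))        # 1 -- both configs collapsed after BLOCK_SIZE_M was clamped to 16
print(pruned[0].kwargs)   # {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
```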
GPTQ-for-Qwen_hf/quant/fused_attn.py ADDED
@@ -0,0 +1,204 @@
1
+ from torch.nn import functional as F
2
+ from transformers.models.llama.modeling_llama import LlamaAttention
3
+ from .quant_linear import *
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def rotate_half_kernel(
10
+ qk_seq_ptr,
11
+ position_ids_ptr,
12
+ qk_seq_stride,
13
+ position_ids_batch_stride,
14
+ seq_len,
15
+ HEAD_DIM: tl.constexpr,
16
+ BLOCK_HEIGHT: tl.constexpr,
17
+ BLOCK_WIDTH: tl.constexpr,
18
+ INV_BASE: tl.constexpr
19
+ ):
20
+ # qk_seq_ptr: (bsz, seq_len, 2, num_heads, head_dim) -- OK to be discontinuous in 2nd dimension.
21
+ # position ids: (bsz, seq_len) -- must be contiguous in the last dimension.
22
+
23
+ HALF_HEAD: tl.constexpr = HEAD_DIM // 2
24
+ STEPS_PER_ROW: tl.constexpr = HALF_HEAD // BLOCK_WIDTH
25
+
26
+ batch_seq = tl.program_id(axis=0)
27
+ row_blk_x_col_blk = tl.program_id(axis=1)
28
+
29
+ row_blk = row_blk_x_col_blk // STEPS_PER_ROW
30
+ row = row_blk * BLOCK_HEIGHT
31
+ if BLOCK_WIDTH < HALF_HEAD:
32
+ col_blk = row_blk_x_col_blk % STEPS_PER_ROW
33
+ col = col_blk * BLOCK_WIDTH
34
+ else:
35
+ col: tl.constexpr = 0
36
+
37
+ # A block will never cross a sequence boundary, which simplifies things a lot.
38
+ batch = batch_seq // seq_len
39
+ seq = batch_seq % seq_len
40
+ position_id = tl.load(position_ids_ptr + batch * position_ids_batch_stride + seq)
41
+ # As sometimes happens, just calculating this on the fly is faster than loading it from memory.
42
+ # Use `tl.libdevice.exp` rather than `tl.exp` -- the latter is less accurate.
43
+ freq = tl.libdevice.exp((col + tl.arange(0, BLOCK_WIDTH)).to(tl.float32) * INV_BASE) * position_id
44
+ cos = tl.cos(freq).to(tl.float32)
45
+ sin = tl.sin(freq).to(tl.float32)
46
+
47
+ col_offsets: tl.constexpr = tl.arange(0, BLOCK_WIDTH)
48
+ embed_offsets = (row * HEAD_DIM + col) + col_offsets
49
+ x_ptrs = (qk_seq_ptr + batch_seq * qk_seq_stride) + embed_offsets
50
+
51
+ for k in range(0, BLOCK_HEIGHT):
52
+ x = tl.load(x_ptrs).to(tl.float32)
53
+ y = tl.load(x_ptrs + HALF_HEAD).to(tl.float32)
54
+ out_x = x * cos - y * sin
55
+ tl.store(x_ptrs, out_x)
56
+ out_y = x * sin + y * cos
57
+ tl.store(x_ptrs + HALF_HEAD, out_y)
58
+ x_ptrs += HEAD_DIM
59
+
60
+
61
+ def triton_rotate_half_(qk, position_ids, config=None):
62
+ with torch.cuda.device(qk.device):
63
+ batch_size, seq_len, qandk, num_heads, head_dim = qk.shape
64
+
65
+ # This default is the fastest for most job sizes, at least on my RTX 4090, and when it's not it's within spitting distance of the best option. There are some odd cases where having a block height of 2 or 4 helps but the difference is within 5%. It makes sense that this configuration is fast from a memory bandwidth and caching perspective.
66
+ config = config or {'BLOCK_HEIGHT': 1, 'BLOCK_WIDTH': min(128, head_dim // 2), 'num_warps': 1}
67
+ config['BLOCK_HEIGHT'] = min(config['BLOCK_HEIGHT'], 2 * num_heads)
68
+
69
+ assert qk.stride(3) == head_dim
70
+ assert qk.stride(4) == 1
71
+ assert position_ids.shape == (batch_size, seq_len)
72
+ assert position_ids.stride(1) == 1, 'position_ids must be contiguous in the last dimension'
73
+ assert (2 * num_heads) % config['BLOCK_HEIGHT'] == 0, f'number of rows not evenly divisible by {config["BLOCK_HEIGHT"]}'
74
+ assert (head_dim // 2) % config['BLOCK_WIDTH'] == 0, f'number of columns ({head_dim // 2}) not evenly divisible by {config["BLOCK_WIDTH"]}'
75
+
76
+ qk_by_seq = qk.view(batch_size * seq_len, 2 * num_heads * head_dim)
77
+ grid = (qk_by_seq.shape[0], (2 * num_heads // config['BLOCK_HEIGHT']) * (head_dim // 2 // config['BLOCK_WIDTH']))
78
+
79
+ # Must be the same as the theta of the frequencies used to train the model.
80
+ BASE = 10000.0
81
+
82
+ rotate_half_kernel[grid](
83
+ qk_by_seq,
84
+ position_ids,
85
+ qk_by_seq.stride(0),
86
+ position_ids.stride(0),
87
+ seq_len,
88
+ HEAD_DIM=head_dim,
89
+ BLOCK_HEIGHT=config['BLOCK_HEIGHT'],
90
+ BLOCK_WIDTH=config['BLOCK_WIDTH'],
91
+ INV_BASE=-2.0 * math.log(BASE) / head_dim,
92
+ num_warps=config['num_warps']
93
+ )
94
+
95
+
96
+ class QuantLlamaAttention(nn.Module):
97
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
98
+
99
+ def __init__(
100
+ self,
101
+ hidden_size,
102
+ num_heads,
103
+ qkv_proj,
104
+ o_proj
105
+ ):
106
+ super().__init__()
107
+ self.hidden_size = hidden_size
108
+ self.num_heads = num_heads
109
+ self.head_dim = hidden_size // num_heads
110
+
111
+ if (self.head_dim * num_heads) != self.hidden_size:
112
+ raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
113
+ f" and `num_heads`: {num_heads}).")
114
+ self.qkv_proj = qkv_proj
115
+ self.o_proj = o_proj
116
+
117
+ def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
118
+ """Input shape: Batch x Time x Channel"""
119
+
120
+ bsz, q_len, _ = hidden_states.size()
121
+
122
+ qkv_states = self.qkv_proj(hidden_states)
123
+ qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
124
+
125
+ # This updates the query and key states in-place, saving VRAM.
126
+ triton_rotate_half_(qkv_states[:, :, :2], position_ids)
127
+
128
+ query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
129
+ del qkv_states
130
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
131
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
132
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
133
+
134
+ is_causal = past_key_value is None
135
+
136
+ kv_seq_len = q_len
137
+ if past_key_value is not None:
138
+ kv_seq_len += past_key_value[0].shape[-2]
139
+
140
+ if past_key_value is not None:
141
+ # reuse k, v, self_attention
142
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
143
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
144
+
145
+ if use_cache:
146
+ # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
147
+ # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
148
+ key_states = key_states.contiguous()
149
+ value_states = value_states.contiguous()
150
+ query_states = query_states.contiguous()
151
+
152
+ past_key_value = (key_states, value_states) if use_cache else None
153
+
154
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
155
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
156
+ del query_states, key_states, value_states
157
+
158
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
159
+ attn_output = self.o_proj(attn_output)
160
+
161
+ return attn_output, None, past_key_value
162
+
163
+
164
+ def make_quant_attn(model):
165
+ """
166
+ Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
167
+ """
168
+
169
+ for name, m in model.named_modules():
170
+ if not isinstance(m, LlamaAttention):
171
+ continue
172
+
173
+ q_proj = m.q_proj
174
+ k_proj = m.k_proj
175
+ v_proj = m.v_proj
176
+
177
+ qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
178
+ qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
179
+ scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
180
+ g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
181
+ bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
182
+
183
+ qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False)
184
+ qkv_layer.qweight = qweights
185
+ qkv_layer.qzeros = qzeros
186
+ qkv_layer.scales = scales
187
+ qkv_layer.g_idx = g_idx
188
+ qkv_layer.bias = bias
189
+ # We're dropping the rotary embedding layer m.rotary_emb here. We don't need it in the triton branch.
190
+
191
+ attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj)
192
+
193
+ if '.' in name:
194
+ parent_name = name.rsplit('.', 1)[0]
195
+ child_name = name[len(parent_name) + 1:]
196
+ parent = model.get_submodule(parent_name)
197
+ else:
198
+ parent_name = ''
199
+ parent = model
200
+ child_name = name
201
+
202
+ #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
203
+
204
+ setattr(parent, child_name, attn)
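
For comparison with the stock Hugging Face attention, the following is a plain-PyTorch sketch (not part of the upload) of what `rotate_half_kernel` / `triton_rotate_half_` compute: an in-place rotary embedding applied to the fused query/key slice, with the base fixed at 10000.0 as in the kernel above.

```python
import torch

def rotate_half_reference(qk: torch.Tensor, position_ids: torch.Tensor, base: float = 10000.0) -> None:
    """In-place RoPE on the fused query/key slice, mirroring triton_rotate_half_.

    qk: (bsz, seq_len, 2, num_heads, head_dim) -- e.g. qkv_states[:, :, :2]
    position_ids: (bsz, seq_len)
    """
    head_dim = qk.shape[-1]
    half = head_dim // 2
    # Same frequencies as the kernel: exp(j * INV_BASE) with INV_BASE = -2*ln(base)/head_dim.
    inv_freq = base ** (-2.0 * torch.arange(half, dtype=torch.float32, device=qk.device) / head_dim)
    freq = position_ids.to(torch.float32)[:, :, None] * inv_freq      # (bsz, seq_len, half)
    cos = freq.cos()[:, :, None, None, :]                             # broadcast over (q/k, heads)
    sin = freq.sin()[:, :, None, None, :]
    x = qk[..., :half].float()
    y = qk[..., half:].float()
    new_x = x * cos - y * sin
    new_y = x * sin + y * cos
    qk[..., :half] = new_x.to(qk.dtype)
    qk[..., half:] = new_y.to(qk.dtype)
```

Called as `rotate_half_reference(qkv_states[:, :, :2], position_ids)`, this should match the Triton path up to floating-point precision.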
GPTQ-for-Qwen_hf/quant/fused_mlp.py ADDED
@@ -0,0 +1,288 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch.cuda.amp import custom_bwd, custom_fwd
5
+ from transformers.models.llama.modeling_llama import LlamaMLP
6
+
7
+ try:
8
+ import triton
9
+ import triton.language as tl
10
+ from . import custom_autotune
11
+
12
+ # code based on https://github.com/fpgaminer/GPTQ-triton
13
+ @custom_autotune.autotune(
14
+ configs=[
15
+ triton.Config({
16
+ 'BLOCK_SIZE_M': 256,
17
+ 'BLOCK_SIZE_N': 64,
18
+ 'BLOCK_SIZE_K': 32,
19
+ 'GROUP_SIZE_M': 8
20
+ }, num_stages=4, num_warps=4),
21
+ triton.Config({
22
+ 'BLOCK_SIZE_M': 64,
23
+ 'BLOCK_SIZE_N': 256,
24
+ 'BLOCK_SIZE_K': 32,
25
+ 'GROUP_SIZE_M': 8
26
+ }, num_stages=4, num_warps=4),
27
+ triton.Config({
28
+ 'BLOCK_SIZE_M': 128,
29
+ 'BLOCK_SIZE_N': 128,
30
+ 'BLOCK_SIZE_K': 32,
31
+ 'GROUP_SIZE_M': 8
32
+ }, num_stages=4, num_warps=4),
33
+ triton.Config({
34
+ 'BLOCK_SIZE_M': 128,
35
+ 'BLOCK_SIZE_N': 64,
36
+ 'BLOCK_SIZE_K': 32,
37
+ 'GROUP_SIZE_M': 8
38
+ }, num_stages=4, num_warps=4),
39
+ triton.Config({
40
+ 'BLOCK_SIZE_M': 64,
41
+ 'BLOCK_SIZE_N': 128,
42
+ 'BLOCK_SIZE_K': 32,
43
+ 'GROUP_SIZE_M': 8
44
+ }, num_stages=4, num_warps=4),
45
+ triton.Config({
46
+ 'BLOCK_SIZE_M': 128,
47
+ 'BLOCK_SIZE_N': 32,
48
+ 'BLOCK_SIZE_K': 32,
49
+ 'GROUP_SIZE_M': 8
50
+ }, num_stages=4, num_warps=4), # 3090
51
+ triton.Config({
52
+ 'BLOCK_SIZE_M': 128,
53
+ 'BLOCK_SIZE_N': 16,
54
+ 'BLOCK_SIZE_K': 32,
55
+ 'GROUP_SIZE_M': 8
56
+ }, num_stages=4, num_warps=4), # 3090
57
+ triton.Config({
58
+ 'BLOCK_SIZE_M': 32,
59
+ 'BLOCK_SIZE_N': 32,
60
+ 'BLOCK_SIZE_K': 128,
61
+ 'GROUP_SIZE_M': 8
62
+ }, num_stages=2, num_warps=4), # 3090
63
+ triton.Config({
64
+ 'BLOCK_SIZE_M': 64,
65
+ 'BLOCK_SIZE_N': 16,
66
+ 'BLOCK_SIZE_K': 64,
67
+ 'GROUP_SIZE_M': 8
68
+ }, num_stages=4, num_warps=4), # 3090
69
+ triton.Config({
70
+ 'BLOCK_SIZE_M': 64,
71
+ 'BLOCK_SIZE_N': 32,
72
+ 'BLOCK_SIZE_K': 64,
73
+ 'GROUP_SIZE_M': 8
74
+ }, num_stages=4, num_warps=4), # 3090
75
+ ],
76
+ key=['M', 'N', 'K'],
77
+ nearest_power_of_two=True,
78
+ prune_configs_by={
79
+ 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
80
+ 'perf_model': None,
81
+ 'top_k': None,
82
+ },
83
+ )
84
+ @triton.jit
85
+ def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn,
86
+ stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
87
+ """
88
+ Computes: C = silu(A * B1) * (A * B2)
89
+ A is of shape (M, K) float16
90
+ B is of shape (K//8, N) int32
91
+ C is of shape (M, N) float16
92
+ scales is of shape (1, N) float16
93
+ zeros is of shape (1, N//8) int32
94
+ """
95
+ infearure_per_bits = 32 // bits
96
+
97
+ pid = tl.program_id(axis=0)
98
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
99
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
100
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
101
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
102
+ group_id = pid // num_pid_in_group
103
+ first_pid_m = group_id * GROUP_SIZE_M
104
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
105
+ pid_m = first_pid_m + (pid % group_size_m)
106
+ pid_n = (pid % num_pid_in_group) // group_size_m
107
+
108
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
109
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
110
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
111
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
112
+ a_mask = (offs_am[:, None] < M)
113
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
114
+ b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
115
+ b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn)
116
+ g1_ptrs = g1_ptr + offs_k
117
+ g2_ptrs = g2_ptr + offs_k
118
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
119
+ scales1_ptrs = scales1_ptr + offs_bn[None, :]
120
+ scales2_ptrs = scales2_ptr + offs_bn[None, :]
121
+ zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits)
122
+ zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits)
123
+
124
+ shifter = (offs_k % infearure_per_bits) * bits
125
+ zeros_shifter = (offs_bn % infearure_per_bits) * bits
126
+ accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
127
+ accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
128
+ for k in range(0, num_pid_k):
129
+ g1_idx = tl.load(g1_ptrs)
130
+ g2_idx = tl.load(g2_ptrs)
131
+
132
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
133
+ scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
134
+ scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales)
135
+
136
+ zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
137
+ zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq
138
+ zeros1 = (zeros1 + 1)
139
+
140
+ zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
141
+ zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq
142
+ zeros2 = (zeros2 + 1)
143
+
144
+ a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
145
+ b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
146
+ b2 = tl.load(b2_ptrs)
147
+
148
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
149
+ b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values
150
+ b1 = (b1 - zeros1) * scales1 # Scale and shift
151
+ accumulator1 += tl.dot(a, b1)
152
+
153
+ b2 = (b2 >> shifter[:, None]) & maxq
154
+ b2 = (b2 - zeros2) * scales2
155
+ accumulator2 += tl.dot(a, b2)
156
+
157
+ a_ptrs += BLOCK_SIZE_K
158
+ b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
159
+ b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
160
+ g1_ptrs += BLOCK_SIZE_K
161
+ g2_ptrs += BLOCK_SIZE_K
162
+
163
+ accumulator1 = silu(accumulator1)
164
+ c = accumulator1 * accumulator2
165
+ c = c.to(tl.float16)
166
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
167
+ c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
168
+ tl.store(c_ptrs, c, mask=c_mask)
169
+
170
+ @triton.jit
171
+ def silu(x):
172
+ return x * tl.sigmoid(x)
173
+ except ImportError:
174
+ print('triton not installed.')
175
+
176
+
177
+ class QuantLlamaMLP(nn.Module):
178
+
179
+ def __init__(
180
+ self,
181
+ gate_proj,
182
+ down_proj,
183
+ up_proj,
184
+ ):
185
+ super().__init__()
186
+ self.register_buffer('gate_proj_qweight', gate_proj.qweight)
187
+ self.register_buffer('gate_proj_scales', gate_proj.scales)
188
+ self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
189
+ self.register_buffer('gate_proj_g_idx', gate_proj.g_idx)
190
+ self.register_buffer('up_proj_qweight', up_proj.qweight)
191
+ self.register_buffer('up_proj_scales', up_proj.scales)
192
+ self.register_buffer('up_proj_qzeros', up_proj.qzeros)
193
+ self.register_buffer('up_proj_g_idx', up_proj.g_idx)
194
+
195
+ self.infeatures = gate_proj.infeatures
196
+ self.intermediate_size = gate_proj.outfeatures
197
+ self.outfeatures = down_proj.outfeatures
198
+ self.bits = gate_proj.bits
199
+ self.maxq = gate_proj.maxq
200
+
201
+ self.down_proj = down_proj
202
+
203
+ def forward(self, x):
204
+ return self.down_proj(self.triton_llama_mlp(x))
205
+
206
+ def triton_llama_mlp(self, x):
207
+ with torch.cuda.device(x.device):
208
+ out_shape = x.shape[:-1] + (self.intermediate_size, )
209
+ x = x.reshape(-1, x.shape[-1])
210
+ M, K = x.shape
211
+ N = self.intermediate_size
212
+ c = torch.empty((M, N), device=x.device, dtype=torch.float16)
213
+ grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
214
+ fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales,
215
+ self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0),
216
+ self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0))
217
+ c = c.reshape(out_shape)
218
+ return c
219
+
220
+ def fused2cuda(self):
221
+ self.gate_proj_qweight = self.gate_proj_qweight.cuda()
222
+ self.gate_proj_scales = self.gate_proj_scales.cuda()
223
+ self.gate_proj_qzeros = self.gate_proj_qzeros.cuda()
224
+ self.gate_proj_g_idx = self.gate_proj_g_idx.cuda()
225
+ self.up_proj_qweight = self.up_proj_qweight.cuda()
226
+ self.up_proj_scales = self.up_proj_scales.cuda()
227
+ self.up_proj_qzeros = self.up_proj_qzeros.cuda()
228
+ self.up_proj_g_idx = self.up_proj_g_idx.cuda()
229
+
230
+ def fused2cpu(self):
231
+ self.gate_proj_qweight = self.gate_proj_qweight.cpu()
232
+ self.gate_proj_scales = self.gate_proj_scales.cpu()
233
+ self.gate_proj_qzeros = self.gate_proj_qzeros.cpu()
234
+ self.gate_proj_g_idx = self.gate_proj_g_idx.cpu()
235
+ self.up_proj_qweight = self.up_proj_qweight.cpu()
236
+ self.up_proj_scales = self.up_proj_scales.cpu()
237
+ self.up_proj_qzeros = self.up_proj_qzeros.cpu()
238
+ self.up_proj_g_idx = self.up_proj_g_idx.cpu()
239
+
240
+
241
+ def make_fused_mlp(m, parent_name=''):
242
+ """
243
+ Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
244
+ """
245
+ if isinstance(m, LlamaMLP):
246
+ return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
247
+
248
+ for name, child in m.named_children():
249
+ child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
250
+
251
+ if isinstance(child, QuantLlamaMLP):
252
+ setattr(m, name, child)
253
+ return m
254
+
255
+
256
+ def autotune_warmup_fused(model):
257
+ """
258
+ Pre-tunes the quantized kernel
259
+ """
260
+ from tqdm import tqdm
261
+
262
+ kn_values = {}
263
+
264
+ for _, m in model.named_modules():
265
+ if not isinstance(m, QuantLlamaMLP):
266
+ continue
267
+
268
+ k = m.infeatures
269
+ n = m.intermediate_size
270
+
271
+ m.fused2cuda()
272
+ if (k, n) not in kn_values:
273
+ kn_values[(k, n)] = m
274
+
275
+ print(f'Found {len(kn_values)} unique fused mlp KN values.')
276
+
277
+ print('Warming up autotune cache ...')
278
+ with torch.no_grad():
279
+ for m in tqdm(range(0, 12)):
280
+ m = 2**m # [1, 2048]
281
+ for (k, n), (modules) in kn_values.items():
282
+ a = torch.randn(m, k, dtype=torch.float16, device='cuda')
283
+ modules.triton_llama_mlp(a)
284
+
285
+ for (k, n), (modules) in kn_values.items():
287
+ modules.fused2cpu()
288
+ del kn_values
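
A minimal sketch (not part of the upload) of how this fused MLP pass might be wired up once the packed weights are loaded. Note that make_fused_mlp checks isinstance(m, LlamaMLP), so on the Qwen3 modules used elsewhere in this upload it is effectively a no-op (consistent with the "Found 0 unique fused mlp KN values." line in the load-test log); the sketch assumes a Llama-architecture model whose MLP projections are already QuantLinear layers and that the repository root is on sys.path. The helper name fuse_and_warm is hypothetical.

import torch
from quant.fused_mlp import make_fused_mlp, autotune_warmup_fused

def fuse_and_warm(model: torch.nn.Module) -> torch.nn.Module:
    # Swap every LlamaMLP for a QuantLlamaMLP that calls fusedmatmul_248_kernel.
    model = make_fused_mlp(model)
    model = model.cuda()
    # Pre-tune the Triton autotune cache for every unique (infeatures, intermediate_size) pair.
    autotune_warmup_fused(model)
    return model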
GPTQ-for-Qwen_hf/quant/quant_linear.py ADDED
@@ -0,0 +1,423 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ from torch.cuda.amp import custom_bwd, custom_fwd
6
+
7
+ try:
8
+ import triton
9
+ import triton.language as tl
10
+ from . import custom_autotune
11
+
12
+ # code based https://github.com/fpgaminer/GPTQ-triton
13
+ @custom_autotune.autotune(
14
+ configs=[
15
+ triton.Config({
16
+ 'BLOCK_SIZE_M': 64,
17
+ 'BLOCK_SIZE_N': 256,
18
+ 'BLOCK_SIZE_K': 32,
19
+ 'GROUP_SIZE_M': 8
20
+ }, num_stages=4, num_warps=4),
21
+ triton.Config({
22
+ 'BLOCK_SIZE_M': 128,
23
+ 'BLOCK_SIZE_N': 128,
24
+ 'BLOCK_SIZE_K': 32,
25
+ 'GROUP_SIZE_M': 8
26
+ }, num_stages=4, num_warps=4),
27
+ triton.Config({
28
+ 'BLOCK_SIZE_M': 64,
29
+ 'BLOCK_SIZE_N': 128,
30
+ 'BLOCK_SIZE_K': 32,
31
+ 'GROUP_SIZE_M': 8
32
+ }, num_stages=4, num_warps=4),
33
+ triton.Config({
34
+ 'BLOCK_SIZE_M': 128,
35
+ 'BLOCK_SIZE_N': 32,
36
+ 'BLOCK_SIZE_K': 32,
37
+ 'GROUP_SIZE_M': 8
38
+ }, num_stages=4, num_warps=4),
39
+ triton.Config({
40
+ 'BLOCK_SIZE_M': 64,
41
+ 'BLOCK_SIZE_N': 64,
42
+ 'BLOCK_SIZE_K': 32,
43
+ 'GROUP_SIZE_M': 8
44
+ }, num_stages=4, num_warps=4),
45
+ triton.Config({
46
+ 'BLOCK_SIZE_M': 64,
47
+ 'BLOCK_SIZE_N': 128,
48
+ 'BLOCK_SIZE_K': 32,
49
+ 'GROUP_SIZE_M': 8
50
+ }, num_stages=2, num_warps=8),
51
+ triton.Config({
52
+ 'BLOCK_SIZE_M': 64,
53
+ 'BLOCK_SIZE_N': 64,
54
+ 'BLOCK_SIZE_K': 64,
55
+ 'GROUP_SIZE_M': 8
56
+ }, num_stages=3, num_warps=8),
57
+ triton.Config({
58
+ 'BLOCK_SIZE_M': 32,
59
+ 'BLOCK_SIZE_N': 32,
60
+ 'BLOCK_SIZE_K': 128,
61
+ 'GROUP_SIZE_M': 8
62
+ }, num_stages=2, num_warps=4),
63
+ ],
64
+ key=['M', 'N', 'K'],
65
+ nearest_power_of_two=True,
66
+ prune_configs_by={
67
+ 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner,
68
+ 'perf_model': None,
69
+ 'top_k': None,
70
+ },
71
+ )
72
+ @triton.jit
73
+ def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros,
74
+ BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
75
+ """
76
+ Compute the matrix multiplication C = A x B.
77
+ A is of shape (M, K) float16
78
+ B is of shape (K // (32 // bits), N) int32, bit-packed
79
+ C is of shape (M, N) float16
80
+ scales is of shape (G, N) float16
81
+ zeros is of shape (G, N // (32 // bits)) int32, bit-packed
82
+ g_ptr is of shape (K) int32
83
+ """
84
+ infearure_per_bits = 32 // bits
85
+
86
+ pid = tl.program_id(axis=0)
87
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
88
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
89
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
90
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
91
+ group_id = pid // num_pid_in_group
92
+ first_pid_m = group_id * GROUP_SIZE_M
93
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
94
+ pid_m = first_pid_m + (pid % group_size_m)
95
+ pid_n = (pid % num_pid_in_group) // group_size_m
96
+
97
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
98
+ offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
99
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
100
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
101
+ a_mask = (offs_am[:, None] < M)
102
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
103
+ b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
104
+ g_ptrs = g_ptr + offs_k
105
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
106
+ scales_ptrs = scales_ptr + offs_bn[None, :]
107
+ zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
108
+
109
+ shifter = (offs_k % infearure_per_bits) * bits
110
+ zeros_shifter = (offs_bn % infearure_per_bits) * bits
111
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
112
+
113
+ for k in range(0, num_pid_k):
114
+ g_idx = tl.load(g_ptrs)
115
+
116
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
117
+ scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
118
+ zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
119
+
120
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
121
+ zeros = (zeros + 1)
122
+
123
+ a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
124
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
125
+
126
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
127
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
128
+ b = (b - zeros) * scales # Scale and shift
129
+
130
+ accumulator += tl.dot(a, b)
131
+ a_ptrs += BLOCK_SIZE_K
132
+ b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
133
+ g_ptrs += BLOCK_SIZE_K
134
+
135
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
136
+ c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
137
+ tl.store(c_ptrs, accumulator, mask=c_mask)
138
+
139
+ @custom_autotune.autotune(configs=[
140
+ triton.Config({
141
+ 'BLOCK_SIZE_M': 64,
142
+ 'BLOCK_SIZE_N': 32,
143
+ 'BLOCK_SIZE_K': 256,
144
+ 'GROUP_SIZE_M': 8
145
+ }, num_stages=4, num_warps=4),
146
+ triton.Config({
147
+ 'BLOCK_SIZE_M': 128,
148
+ 'BLOCK_SIZE_N': 32,
149
+ 'BLOCK_SIZE_K': 128,
150
+ 'GROUP_SIZE_M': 8
151
+ }, num_stages=4, num_warps=4),
152
+ triton.Config({
153
+ 'BLOCK_SIZE_M': 64,
154
+ 'BLOCK_SIZE_N': 32,
155
+ 'BLOCK_SIZE_K': 128,
156
+ 'GROUP_SIZE_M': 8
157
+ }, num_stages=4, num_warps=4),
158
+ triton.Config({
159
+ 'BLOCK_SIZE_M': 128,
160
+ 'BLOCK_SIZE_N': 32,
161
+ 'BLOCK_SIZE_K': 32,
162
+ 'GROUP_SIZE_M': 8
163
+ }, num_stages=4, num_warps=4),
164
+ triton.Config({
165
+ 'BLOCK_SIZE_M': 64,
166
+ 'BLOCK_SIZE_N': 32,
167
+ 'BLOCK_SIZE_K': 64,
168
+ 'GROUP_SIZE_M': 8
169
+ }, num_stages=4, num_warps=4),
170
+ triton.Config({
171
+ 'BLOCK_SIZE_M': 64,
172
+ 'BLOCK_SIZE_N': 32,
173
+ 'BLOCK_SIZE_K': 128,
174
+ 'GROUP_SIZE_M': 8
175
+ }, num_stages=2, num_warps=8),
176
+ triton.Config({
177
+ 'BLOCK_SIZE_M': 64,
178
+ 'BLOCK_SIZE_N': 64,
179
+ 'BLOCK_SIZE_K': 64,
180
+ 'GROUP_SIZE_M': 8
181
+ }, num_stages=3, num_warps=8),
182
+ triton.Config({
183
+ 'BLOCK_SIZE_M': 32,
184
+ 'BLOCK_SIZE_N': 128,
185
+ 'BLOCK_SIZE_K': 32,
186
+ 'GROUP_SIZE_M': 8
187
+ }, num_stages=2, num_warps=4),
188
+ ],
189
+ key=['M', 'N', 'K'],
190
+ nearest_power_of_two=True)
191
+ @triton.jit
192
+ def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales,
193
+ stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
194
+ """
195
+ Compute the matrix multiplication C = A x B^T, where B is dequantized on the fly.
196
+ A is of shape (M, N) float16
197
+ B is of shape (K // (32 // bits), N) int32, bit-packed
198
+ C is of shape (M, K) float16
199
+ scales is of shape (G, N) float16
200
+ zeros is of shape (G, N // (32 // bits)) int32, bit-packed
201
+ g_ptr is of shape (K) int32
202
+ """
203
+ infearure_per_bits = 32 // bits
204
+
205
+ pid = tl.program_id(axis=0)
206
+ num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
207
+ num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
208
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
209
+ num_pid_in_group = GROUP_SIZE_M * num_pid_k
210
+ group_id = pid // num_pid_in_group
211
+ first_pid_m = group_id * GROUP_SIZE_M
212
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
213
+ pid_m = first_pid_m + (pid % group_size_m)
214
+ pid_k = (pid % num_pid_in_group) // group_size_m
215
+
216
+ offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
217
+ offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
218
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
219
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
220
+ a_mask = (offs_am[:, None] < M)
221
+ # b_ptrs is set up such that it repeats elements along the K axis 8 times
222
+ b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
223
+ g_ptrs = g_ptr + offs_bk
224
+ g_idx = tl.load(g_ptrs)
225
+
226
+ # shifter is used to extract the N bits of each element in the 32-bit word from B
227
+ scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
228
+ zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
229
+
230
+ shifter = (offs_bk % infearure_per_bits) * bits
231
+ zeros_shifter = (offs_n % infearure_per_bits) * bits
232
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
233
+
234
+ for n in range(0, num_pid_n):
235
+ # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
236
+ scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
237
+ zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
238
+
239
+ zeros = (zeros >> zeros_shifter[None, :]) & maxq
240
+ zeros = (zeros + 1)
241
+
242
+ a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
243
+ b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
244
+
245
+ # Now we need to unpack b (which is N-bit values) into 32-bit values
246
+ b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
247
+ b = (b - zeros) * scales # Scale and shift
248
+ b = tl.trans(b)
249
+
250
+ accumulator += tl.dot(a, b)
251
+ a_ptrs += BLOCK_SIZE_N
252
+ b_ptrs += BLOCK_SIZE_N
253
+ scales_ptrs += BLOCK_SIZE_N
254
+ zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits)
255
+
256
+ c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
257
+ c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
258
+ tl.store(c_ptrs, accumulator, mask=c_mask)
259
+ except Exception:
260
+ print('triton not installed.')
261
+
262
+
263
+ def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
264
+ with torch.cuda.device(input.device):
265
+ output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
266
+ grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), )
267
+ matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
268
+ qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
269
+ return output
270
+
271
+
272
+ def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq):
273
+ with torch.cuda.device(input.device):
274
+ output_dim = (qweight.shape[0] * 32) // bits
275
+ output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16)
276
+ grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), )
277
+ transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0),
278
+ qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0))
279
+ return output
280
+
281
+
282
+ class QuantLinearFunction(torch.autograd.Function):
283
+
284
+ @staticmethod
285
+ @custom_fwd(cast_inputs=torch.float16)
286
+ def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
287
+ output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq)
288
+ ctx.save_for_backward(qweight, scales, qzeros, g_idx)
289
+ ctx.bits, ctx.maxq = bits, maxq
290
+ return output
291
+
292
+ @staticmethod
293
+ @custom_bwd
294
+ def backward(ctx, grad_output):
295
+ qweight, scales, qzeros, g_idx = ctx.saved_tensors
296
+ bits, maxq = ctx.bits, ctx.maxq
297
+ grad_input = None
298
+
299
+ if ctx.needs_input_grad[0]:
300
+ grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
301
+ return grad_input, None, None, None, None, None, None
302
+
303
+
304
+ class QuantLinear(nn.Module):
305
+
306
+ def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
307
+ super().__init__()
308
+ if bits not in [2, 4, 8]:
309
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
310
+ self.infeatures = infeatures
311
+ self.outfeatures = outfeatures
312
+ self.bits = bits
313
+ self.maxq = 2**self.bits - 1
314
+ self.groupsize = groupsize if groupsize != -1 else infeatures
315
+
316
+ self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32))
317
+ self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32))
318
+ self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
319
+ self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))
320
+ if bias:
321
+ self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
322
+ else:
323
+ self.bias = None
324
+
325
+ def pack(self, linear, scales, zeros, g_idx=None):
326
+ self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
327
+
328
+ scales = scales.t().contiguous()
329
+ zeros = zeros.t().contiguous()
330
+ scale_zeros = zeros * scales
331
+ self.scales = scales.clone().half()
332
+ if linear.bias is not None:
333
+ self.bias = linear.bias.clone().half()
334
+
335
+ intweight = []
336
+ for idx in range(self.infeatures):
337
+ intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
338
+ intweight = torch.cat(intweight, dim=1)
339
+ intweight = intweight.t().contiguous()
340
+ intweight = intweight.numpy().astype(np.uint32)
341
+ qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
342
+ i = 0
343
+ row = 0
344
+ while row < qweight.shape[0]:
345
+ if self.bits in [2, 4, 8]:
346
+ for j in range(i, i + (32 // self.bits)):
347
+ qweight[row] |= intweight[j] << (self.bits * (j - i))
348
+ i += 32 // self.bits
349
+ row += 1
350
+ else:
351
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
352
+
353
+ qweight = qweight.astype(np.int32)
354
+ self.qweight = torch.from_numpy(qweight)
355
+
356
+ zeros -= 1
357
+ zeros = zeros.numpy().astype(np.uint32)
358
+ qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
359
+ i = 0
360
+ col = 0
361
+ while col < qzeros.shape[1]:
362
+ if self.bits in [2, 4, 8]:
363
+ for j in range(i, i + (32 // self.bits)):
364
+ qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
365
+ i += 32 // self.bits
366
+ col += 1
367
+ else:
368
+ raise NotImplementedError("Only 2,4,8 bits are supported.")
369
+
370
+ qzeros = qzeros.astype(np.int32)
371
+ self.qzeros = torch.from_numpy(qzeros)
372
+
373
+ def forward(self, x):
374
+ out_shape = x.shape[:-1] + (self.outfeatures, )
375
+ out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq)
376
+ out = out + self.bias if self.bias is not None else out
377
+ return out.reshape(out_shape)
378
+
379
+
380
+ def make_quant_linear(module, names, bits, groupsize, name=''):
381
+ if isinstance(module, QuantLinear):
382
+ return
383
+ for attr in dir(module):
384
+ tmp = getattr(module, attr)
385
+ name1 = name + '.' + attr if name != '' else attr
386
+ if name1 in names:
387
+ delattr(module, attr)
388
+ setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
389
+ for name1, child in module.named_children():
390
+ make_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
391
+
392
+
393
+ def autotune_warmup_linear(model, transpose=False):
394
+ """
395
+ Pre-tunes the quantized kernel
396
+ """
397
+ from tqdm import tqdm
398
+
399
+ kn_values = {}
400
+
401
+ for _, m in model.named_modules():
402
+ if not isinstance(m, QuantLinear):
403
+ continue
404
+
405
+ k = m.infeatures
406
+ n = m.outfeatures
407
+
408
+ if (k, n) not in kn_values:
409
+ kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq)
410
+
411
+ print(f'Found {len(kn_values)} unique KN Linear values.')
412
+
413
+ print('Warming up autotune cache ...')
414
+ with torch.no_grad():
415
+ for m in tqdm(range(0, 12)):
416
+ m = 2**m # [1, 2048]
417
+ for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items():
418
+ a = torch.randn(m, k, dtype=torch.float16, device='cuda')
419
+ matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
420
+ if transpose:
421
+ a = torch.randn(m, n, dtype=torch.float16, device='cuda')
422
+ transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq)
423
+ del kn_values
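
A small usage sketch (not part of the upload) showing how a plain nn.Linear could be packed into a QuantLinear and run through the Triton kernel. It assumes the repository root is on sys.path, a CUDA device with Triton installed, and symmetric per-channel parameters produced by the Quantizer from quant/quantizer.py; as in the GPTQ flow, the weights are fake-quantized before pack() so that the integer codes round back exactly.

import torch
import torch.nn as nn
from quant.quant_linear import QuantLinear
from quant.quantizer import Quantizer

linear = nn.Linear(128, 256, bias=False)

quantizer = Quantizer()
quantizer.configure(bits=4, perchannel=True, sym=True, mse=False)
quantizer.find_params(linear.weight.data, weight=True)        # per-output-channel scale/zero
linear.weight.data = quantizer.quantize(linear.weight.data)   # fake-quantize first, as the GPTQ loop does before pack()

qlinear = QuantLinear(4, -1, 128, 256, bias=False)            # groupsize=-1 -> one group per row
qlinear.pack(linear, quantizer.scale, quantizer.zero)         # packs the int4 codes into int32 words

x = torch.randn(1, 128, dtype=torch.float16, device='cuda')
with torch.no_grad():
    y = qlinear.cuda()(x)                                     # dispatches matmul_248_kernel
print(y.shape)                                                # torch.Size([1, 256])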
GPTQ-for-Qwen_hf/quant/quantizer.py ADDED
@@ -0,0 +1,127 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import math
5
+
6
+
7
+ class Quantizer(nn.Module):
8
+
9
+ def __init__(self, shape=1):
10
+ super(Quantizer, self).__init__()
11
+ self.register_buffer('maxq', torch.tensor(0))
12
+ self.register_buffer('scale', torch.zeros(shape))
13
+ self.register_buffer('zero', torch.zeros(shape))
14
+
15
+ def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False):
16
+
17
+ self.maxq = torch.tensor(2**bits - 1)
18
+ self.perchannel = perchannel
19
+ self.sym = sym
20
+ self.mse = mse
21
+ self.norm = norm
22
+ self.grid = grid
23
+ self.maxshrink = maxshrink
24
+ if trits:
25
+ self.maxq = torch.tensor(-1)
26
+ self.scale = torch.zeros_like(self.scale)
27
+
28
+ def _quantize(self, x, scale, zero, maxq):
29
+ if maxq < 0:
30
+ return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
31
+ q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
32
+ return scale * (q - zero)
33
+
34
+ def find_params(self, x, weight=False):
35
+ dev = x.device
36
+ self.maxq = self.maxq.to(dev)
37
+
38
+ shape = x.shape
39
+ if self.perchannel:
40
+ if weight:
41
+ x = x.flatten(1)
42
+ else:
43
+ if len(shape) == 4:
44
+ x = x.permute([1, 0, 2, 3])
45
+ x = x.flatten(1)
46
+ if len(shape) == 3:
47
+ x = x.reshape((-1, shape[-1])).t()
48
+ if len(shape) == 2:
49
+ x = x.t()
50
+ else:
51
+ x = x.flatten().unsqueeze(0)
52
+
53
+ tmp = torch.zeros(x.shape[0], device=dev)
54
+ xmin = torch.minimum(x.min(1)[0], tmp)
55
+ xmax = torch.maximum(x.max(1)[0], tmp)
56
+
57
+ if self.sym:
58
+ xmax = torch.maximum(torch.abs(xmin), xmax)
59
+ tmp = xmin < 0
60
+ if torch.any(tmp):
61
+ xmin[tmp] = -xmax[tmp]
62
+ tmp = (xmin == 0) & (xmax == 0)
63
+ xmin[tmp] = -1
64
+ xmax[tmp] = +1
65
+
66
+ if self.maxq < 0:
67
+ self.scale = xmax
68
+ self.zero = xmin
69
+ else:
70
+ self.scale = (xmax - xmin) / self.maxq
71
+ if self.sym:
72
+ self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
73
+ else:
74
+ self.zero = torch.round(-xmin / self.scale)
75
+
76
+ if self.mse:
77
+ best = torch.full([x.shape[0]], float('inf'), device=dev)
78
+ for i in range(int(self.maxshrink * self.grid)):
79
+ p = 1 - i / self.grid
80
+ xmin1 = p * xmin
81
+ xmax1 = p * xmax
82
+ scale1 = (xmax1 - xmin1) / self.maxq
83
+ zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
84
+ q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
85
+ q -= x
86
+ q.abs_()
87
+ q.pow_(self.norm)
88
+ err = torch.sum(q, 1)
89
+ tmp = err < best
90
+ if torch.any(tmp):
91
+ best[tmp] = err[tmp]
92
+ self.scale[tmp] = scale1[tmp]
93
+ self.zero[tmp] = zero1[tmp]
94
+ if not self.perchannel:
95
+ if weight:
96
+ tmp = shape[0]
97
+ else:
98
+ tmp = shape[1] if len(shape) != 3 else shape[2]
99
+ self.scale = self.scale.repeat(tmp)
100
+ self.zero = self.zero.repeat(tmp)
101
+
102
+ if weight:
103
+ shape = [-1] + [1] * (len(shape) - 1)
104
+ self.scale = self.scale.reshape(shape)
105
+ self.zero = self.zero.reshape(shape)
106
+ return
107
+ if len(shape) == 4:
108
+ self.scale = self.scale.reshape((1, -1, 1, 1))
109
+ self.zero = self.zero.reshape((1, -1, 1, 1))
110
+ if len(shape) == 3:
111
+ self.scale = self.scale.reshape((1, 1, -1))
112
+ self.zero = self.zero.reshape((1, 1, -1))
113
+ if len(shape) == 2:
114
+ self.scale = self.scale.unsqueeze(0)
115
+ self.zero = self.zero.unsqueeze(0)
116
+
117
+ def quantize(self, x):
118
+ if self.ready():
119
+ return self._quantize(x, self.scale, self.zero, self.maxq)
120
+
121
+ return x
122
+
123
+ def enabled(self):
124
+ return self.maxq > 0
125
+
126
+ def ready(self):
127
+ return torch.all(self.scale != 0)
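
For reference, a tiny round-trip sketch of the scheme implemented above, i.e. q = clamp(round(x / scale) + zero, 0, maxq) and x_hat = scale * (q - zero); it assumes the repository root is on sys.path and uses toy shapes.

import torch
from quant.quantizer import Quantizer

w = torch.randn(256, 128)                 # toy weight matrix (out_features x in_features)

q = Quantizer()
q.configure(bits=4, perchannel=True, sym=True, mse=False)
q.find_params(w, weight=True)             # one scale/zero per output channel

w_hat = q.quantize(w)                     # fake-quantized copy of w
print((w - w_hat).abs().max().item())     # worst-case rounding error, about half a step of the largest per-channel scale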
GPTQ-for-Qwen_hf/quant/triton_norm.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from torch import nn
3
+ import triton
4
+ import triton.language as tl
5
+ from transformers.models.llama.modeling_llama import LlamaRMSNorm
6
+
7
+ @triton.jit
8
+ def rms_norm_fwd_fused(
9
+ X, # pointer to the input
10
+ Y, # pointer to the output
11
+ W, # pointer to the weights
12
+ stride, # how much to increase the pointer when moving by 1 row
13
+ N, # number of columns in X
14
+ eps, # epsilon to avoid division by zero
15
+ BLOCK_SIZE: tl.constexpr,
16
+ ):
17
+ # Map the program id to the row of X and Y it should compute.
18
+ row = tl.program_id(0)
19
+ Y += row * stride
20
+ X += row * stride
21
+ # Compute variance
22
+ _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
23
+ for off in range(0, N, BLOCK_SIZE):
24
+ cols = off + tl.arange(0, BLOCK_SIZE)
25
+ x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
26
+ x = tl.where(cols < N, x, 0.)
27
+ _var += x * x
28
+ var = tl.sum(_var, axis=0) / N
29
+ rstd = 1 / tl.sqrt(var + eps)
30
+ # Normalize and apply linear transformation
31
+ for off in range(0, N, BLOCK_SIZE):
32
+ cols = off + tl.arange(0, BLOCK_SIZE)
33
+ mask = cols < N
34
+ w = tl.load(W + cols, mask=mask)
35
+ x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
36
+ x_hat = x * rstd
37
+ y = x_hat * w
38
+ # Write output
39
+ tl.store(Y + cols, y, mask=mask)
40
+
41
+ class TritonLlamaRMSNorm(nn.Module):
42
+ def __init__(self, weight, eps=1e-6):
43
+ """
44
+ LlamaRMSNorm is equivalent to T5LayerNorm
45
+ """
46
+ super().__init__()
47
+ self.weight = weight
48
+ self.variance_epsilon = eps
49
+
50
+ def forward(self, x):
51
+ with torch.cuda.device(x.device):
52
+ y = torch.empty_like(x)
53
+ # reshape input data into 2D tensor
54
+ x_arg = x.reshape(-1, x.shape[-1])
55
+ M, N = x_arg.shape
56
+ # Less than 64KB per feature: enqueue fused kernel
57
+ MAX_FUSED_SIZE = 65536 // x.element_size()
58
+ BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
59
+ if N > BLOCK_SIZE:
60
+ raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
61
+ # heuristics for number of warps
62
+ num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
63
+ # enqueue kernel
64
+ rms_norm_fwd_fused[(M,)](x_arg, y, self.weight,
65
+ x_arg.stride(0), N, self.variance_epsilon,
66
+ BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)
67
+ return y
68
+
69
+
70
+ def make_quant_norm(model):
71
+ """
72
+ Replace all LlamaRMSNorm modules with TritonLlamaRMSNorm modules
73
+ """
74
+
75
+ for name, m in model.named_modules():
76
+ if not isinstance(m, LlamaRMSNorm):
77
+ continue
78
+
79
+ norm = TritonLlamaRMSNorm(m.weight, m.variance_epsilon)
80
+
81
+ if '.' in name:
82
+ parent_name = name.rsplit('.', 1)[0]
83
+ child_name = name[len(parent_name) + 1:]
84
+ parent = model.get_submodule(parent_name)
85
+ else:
86
+ parent_name = ''
87
+ parent = model
88
+ child_name = name
89
+
90
+ #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}")
91
+
92
+ setattr(parent, child_name, norm)
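
A quick parity-check sketch for the fused RMSNorm against the Hugging Face reference module (assumes a CUDA device with Triton installed and the repository root on sys.path).

import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from quant.triton_norm import TritonLlamaRMSNorm

hidden = 1024
ref = LlamaRMSNorm(hidden, eps=1e-6).cuda().half()
fused = TritonLlamaRMSNorm(ref.weight, eps=1e-6)

x = torch.randn(4, 16, hidden, dtype=torch.float16, device='cuda')
with torch.no_grad():
    print((ref(x) - fused(x)).abs().max().item())   # should be on the order of fp16 rounding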
GPTQ-for-Qwen_hf/qwen.py ADDED
@@ -0,0 +1,292 @@
1
+ import argparse
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import quant
7
+ import sys
8
+
9
+ from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders, export_quant_table, gen_conditions
10
+ from texttable import Texttable
11
+
12
+ import tqdm
13
+
14
+ class Evaluator:
15
+ def __init__(self, dataset, tokenizer, device, n_samples=40):
16
+ self.dataset = dataset
17
+ self.tokenizer = tokenizer
18
+ self.device = device
19
+
20
+ self.dataset = tokenizer(
21
+ "\n\n".join(dataset["text"]), return_tensors="pt"
22
+ ).input_ids.to(device)
23
+
24
+ self.n_samples = n_samples
25
+
26
+ @torch.no_grad()
27
+ def evaluate(self, model):
28
+ model.eval()
29
+ nlls = []
30
+ for i in tqdm.tqdm(range(self.n_samples), desc="Evaluating..."):
31
+ batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device)
32
+ with torch.no_grad():
33
+ lm_logits = model(batch).logits
34
+ shift_logits = lm_logits[:, :-1, :].contiguous().float()
35
+ shift_labels = self.dataset[:, (i * 2048) : ((i + 1) * 2048)][:, 1:]
36
+ loss_fct = nn.CrossEntropyLoss()
37
+ loss = loss_fct(
38
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
39
+ )
40
+ neg_log_likelihood = loss.float() * 2048
41
+ nlls.append(neg_log_likelihood)
42
+
43
+ return torch.exp(torch.stack(nlls).sum() / (self.n_samples * 2048))
44
+
45
+
46
+ def get_qwen(model):
47
+
48
+ def skip(*args, **kwargs):
49
+ pass
50
+
51
+ torch.nn.init.kaiming_uniform_ = skip
52
+ torch.nn.init.uniform_ = skip
53
+ torch.nn.init.normal_ = skip
54
+ from transformers import Qwen3ForCausalLM
55
+ model = Qwen3ForCausalLM.from_pretrained(model, torch_dtype=torch.bfloat16,device_map="auto")
56
+ model.seqlen = 2048
57
+ return model
58
+
59
+
60
+
61
+ def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True):
62
+ from transformers import Qwen3Config, Qwen3ForCausalLM, modeling_utils
63
+ config = Qwen3Config.from_pretrained(model)
64
+ # model = Qwen3Config.from_pretrained(model, torch_dtype=torch.bfloat16, device_map="auto")
65
+
66
+ def noop(*args, **kwargs):
67
+ pass
68
+
69
+ torch.nn.init.kaiming_uniform_ = noop
70
+ torch.nn.init.uniform_ = noop
71
+ torch.nn.init.normal_ = noop
72
+
73
+ torch.set_default_dtype(torch.half)
74
+ modeling_utils._init_weights = False
76
+ model = Qwen3ForCausalLM(config)
77
+
78
+ torch.set_default_dtype(torch.float)
79
+ if eval:
80
+ model = model.eval()
81
+ layers = find_layers(model)
82
+ for name in ['lm_head']:
83
+ if name in layers:
84
+ del layers[name]
85
+ quant.make_quant_linear(model, layers, wbits, groupsize)
86
+
87
+ del layers
88
+
89
+ print('Loading model ...')
90
+ if checkpoint.endswith('.safetensors'):
91
+ from safetensors.torch import load_file as safe_load
92
+ model.load_state_dict(safe_load(checkpoint))
93
+ else:
94
+ model.load_state_dict(torch.load(checkpoint))
95
+
96
+ if eval:
97
+ quant.make_quant_attn(model)
98
+ quant.make_quant_norm(model)
99
+ if fused_mlp:
100
+ quant.make_fused_mlp(model)
101
+
102
+ if warmup_autotune:
103
+ quant.autotune_warmup_linear(model, transpose=not (eval))
104
+ if eval and fused_mlp:
105
+ quant.autotune_warmup_fused(model)
106
+ model.seqlen = 2048
107
+ print('Done.')
108
+
109
+ return model
110
+
111
+
112
+ def Qwen_multigpu(model, gpus, gpu_dist):
113
+ model.model.embed_tokens = model.model.embed_tokens.to(gpus[0])
114
+ if hasattr(model.model, 'norm') and model.model.norm:
115
+ model.model.norm = model.model.norm.to(gpus[0])
116
+ import copy
117
+ model.lm_head = copy.deepcopy(model.lm_head).to(gpus[0])
118
+
119
+ cache = {'mask': None, 'position_ids': None}
120
+
121
+ class MoveModule(nn.Module):
122
+
123
+ def __init__(self, module, invalidate_cache):
124
+ super().__init__()
125
+ self.module = module
126
+ self.dev = next(iter(self.module.parameters())).device
127
+ self.invalidate_cache=invalidate_cache
128
+
129
+ def forward(self, *inp, **kwargs):
130
+ inp = list(inp)
131
+ if inp[0].device != self.dev:
132
+ inp[0] = inp[0].to(self.dev)
133
+
134
+ if cache['mask'] is None or cache['mask'].device != self.dev or self.invalidate_cache:
135
+ cache['mask'] = kwargs['attention_mask'].to(self.dev)
136
+ kwargs['attention_mask'] = cache['mask']
137
+
138
+ if cache['position_ids'] is None or cache['position_ids'].device != self.dev or self.invalidate_cache:
139
+ cache['position_ids'] = kwargs['position_ids'].to(self.dev)
140
+ kwargs['position_ids'] = cache['position_ids']
141
+
142
+ tmp = self.module(*inp, **kwargs)
143
+ return tmp
144
+
145
+ layers = model.model.layers
146
+ from math import ceil
147
+ if not gpu_dist:
148
+ pergpu = ceil(len(layers) / len(gpus))
149
+ for i in range(len(layers)):
150
+ layers[i] = MoveModule(layers[i].to(0 if i == 0 or i == len(layers) -1 else gpus[(i-1) // pergpu]), i==0)
151
+ else:
152
+ assert gpu_dist[0] >= 2, "At least two layers must be on GPU 0."
153
+ assigned_gpus = [0] * (gpu_dist[0]-1)
154
+ for i in range(1, len(gpu_dist)):
155
+ assigned_gpus = assigned_gpus + [i] * gpu_dist[i]
156
+
157
+ remaining_assignments = len(layers)-len(assigned_gpus) - 1
158
+ if remaining_assignments > 0:
159
+ assigned_gpus = assigned_gpus + [-1] * remaining_assignments
160
+
161
+ assigned_gpus = assigned_gpus + [0]
162
+
163
+ for i in range(len(layers)):
164
+ layers[i] = MoveModule(layers[i].to(gpus[assigned_gpus[i]]), i==0)
165
+
166
+ model.gpus = gpus
167
+
168
+
169
+ def benchmark(model, input_ids, check=False):
170
+ input_ids = input_ids.to(model.gpus[0] if hasattr(model, 'gpus') else DEV)
171
+ torch.cuda.synchronize()
172
+
173
+ cache = {'past': None}
174
+
175
+ def clear_past(i):
176
+
177
+ def tmp(layer, inp, out):
178
+ if cache['past']:
179
+ cache['past'][i] = None
180
+
181
+ return tmp
182
+
183
+ for i, layer in enumerate(model.model.layers):
184
+ layer.register_forward_hook(clear_past(i))
185
+
186
+ print('Benchmarking ...')
187
+
188
+ if check:
189
+ loss = nn.CrossEntropyLoss()
190
+ tot = 0.
191
+
192
+ def sync():
193
+ if hasattr(model, 'gpus'):
194
+ for gpu in model.gpus:
195
+ torch.cuda.synchronize(gpu)
196
+ else:
197
+ torch.cuda.synchronize()
198
+
199
+ max_memory = 0
200
+ with torch.no_grad():
201
+ attention_mask = torch.ones((1, input_ids.numel()), device=DEV)
202
+ times = []
203
+ for i in range(input_ids.numel()):
204
+ tick = time.time()
205
+ out = model(input_ids[:, i:i + 1], past_key_values=cache['past'], attention_mask=attention_mask[:, :(i + 1)].reshape((1, -1)))
206
+ sync()
207
+ times.append(time.time() - tick)
208
+ print(i, times[-1])
209
+ if hasattr(model, 'gpus'):
210
+ mem_allocated = sum(torch.cuda.memory_allocated(gpu) for gpu in model.gpus) / 1024 / 1024
211
+ else:
212
+ mem_allocated = torch.cuda.memory_allocated() / 1024 / 1024
213
+ max_memory = max(max_memory, mem_allocated)
214
+ if check and i != input_ids.numel() - 1:
215
+ tot += loss(out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV)).float()
216
+ cache['past'] = list(out.past_key_values)
217
+ del out
218
+ sync()
219
+ print('Median:', np.median(times))
220
+ if check:
221
+ print('PPL:', torch.exp(tot / (input_ids.numel() - 1)).item())
222
+ print('max memory(MiB):', max_memory)
223
+
224
+
225
+ if __name__ == '__main__':
226
+
227
+ parser = argparse.ArgumentParser()
228
+
229
+ parser.add_argument('model', type=str, help='qwen model to load')
230
+ parser.add_argument('--eval', action='store_true', help='evaluate quantized model.')
231
+ parser.add_argument('--test-generation', action='store_true', help='test generation.')
232
+ parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.')
233
+ parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.')
234
+ parser.add_argument('--load', type=str, default='', help='Load quantized model.')
235
+ parser.add_argument('--check', action='store_true', help='Whether to compute perplexity during benchmarking for verification.')
236
+ parser.add_argument('--true-sequential', action='store_true', help='Whether to run in true sequential mode.')
237
+ parser.add_argument('--new-eval', action='store_true', help='Whether to use the new PTB and C4 eval')
238
+ parser.add_argument('--layers-dist', type=str, default='', help='Distribution of layers across GPUs. e.g. 2:1:1 for 2 layers on GPU 0, 1 layer on GPU 1, and 1 layer on GPU 2. Any remaining layers will be assigned to your last GPU.')
239
+ parser.add_argument('--observe',
240
+ action='store_true',
241
+ help='Automatically upgrade layer precision to a higher precision, for example int2 to int4 or groupsize 128 to 64. \
242
+ When this feature is enabled, `--save` or `--save_safetensors` will be disabled.')
243
+ parser.add_argument('--quant-directory', type=str, default=None, help='Directory to export quantization parameters to in TOML format; `None` (the default) disables export.')
244
+
245
+
246
+ args = parser.parse_args()
247
+
248
+ if args.layers_dist:
249
+ gpu_dist = [int(x) for x in args.layers_dist.split(':')]
250
+ else:
251
+ gpu_dist = []
252
+
253
+ if type(args.load) is not str:
254
+ args.load = args.load.as_posix()
255
+
256
+ if args.load:
257
+ model = load_quant(args.model, args.load, args.wbits, args.groupsize)
258
+ else:
259
+ model = get_qwen(args.model)
260
+ model.eval()
261
+
262
+ from transformers import AutoTokenizer
263
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
264
+
265
+ layers = model.model.layers
266
+
267
+
268
+ if args.test_generation:
269
+ gpus = [torch.device('cuda:%d' % i) for i in range(torch.cuda.device_count())]
270
+ if len(gpus) > 1:
271
+ Qwen_multigpu(model, gpus, gpu_dist)
272
+ else:
273
+ model = model.to(DEV)
274
+
275
+ from transformers import AutoTokenizer, TextStreamer
276
+ tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
277
+ input_ids = tokenizer(["The capital of New Mexico is"], return_tensors="pt").input_ids.to(gpus[0])
278
+ streamer = TextStreamer(tokenizer)
279
+ with torch.no_grad():
280
+ generated_ids = model.generate(input_ids, streamer=streamer)
281
+
282
+ model = model.to(DEV)
283
+
284
+ if args.eval:
285
+ # sys.path.append("../eval_my")
286
+ # print(sys.path)
287
+ from eval_my.evaluate_ import eval_ours
288
+ eval_ours(model, tokenizer)
289
+
290
+
291
+
292
+
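
The same load path can be driven from Python rather than via test.sh; a sketch with placeholder paths for the local Qwen3 checkpoint directory and the packed GPTQ state dict (both names below are hypothetical, mirroring the --load invocation):

from transformers import AutoTokenizer
from qwen import load_quant

MODEL_DIR = "/path/to/Qwen3-0.6B"                 # placeholder: local HF model directory
CKPT = "/path/to/gptq_0.6B_w4_perchannel.pth"     # placeholder: packed 4-bit checkpoint

model = load_quant(MODEL_DIR, CKPT, wbits=4, groupsize=-1).cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

prompt = tokenizer("The capital of New Mexico is", return_tensors="pt").input_ids.cuda()
print(tokenizer.decode(model.generate(prompt, max_new_tokens=16)[0]))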
GPTQ-for-Qwen_hf/qwen_gptq_0.6B_loadtest.log ADDED
@@ -0,0 +1,37 @@
1
+ /mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/quant/quant_linear.py:285: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
2
+ @custom_fwd(cast_inputs=torch.float16)
3
+ /mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/quant/quant_linear.py:294: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
4
+ def backward(ctx, grad_output):
5
+ `loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
6
+ /home/beihang/lin/pretrained_models/GPTQ-for-Qwen_hf/qwen.py:94: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
7
+ model.load_state_dict(torch.load(checkpoint))
8
+ Loading model ...
9
+ Found 5 unique KN Linear values.
10
+ Warming up autotune cache ...
11
+
12
  0%| | 0/12 [00:00<?, ?it/s]
13
  8%|▊ | 1/12 [00:02<00:23, 2.18s/it]
14
  17%|█▋ | 2/12 [00:03<00:18, 1.86s/it]
15
  25%|██▌ | 3/12 [00:05<00:15, 1.76s/it]
16
  33%|███▎ | 4/12 [00:07<00:13, 1.71s/it]
17
  42%|████▏ | 5/12 [00:08<00:11, 1.68s/it]
18
  50%|█████ | 6/12 [00:10<00:10, 1.67s/it]
19
  58%|█████▊ | 7/12 [00:12<00:08, 1.67s/it]
20
  67%|██████▋ | 8/12 [00:14<00:07, 1.76s/it]
21
  75%|███████▌ | 9/12 [00:16<00:05, 1.84s/it]
22
  83%|████████▎ | 10/12 [00:18<00:03, 1.92s/it]
23
  92%|█████████▏| 11/12 [00:20<00:02, 2.01s/it]
24
+ Found 0 unique fused mlp KN values.
25
+ Warming up autotune cache ...
26
+
27
  0%| | 0/12 [00:00<?, ?it/s]
28
+ 2025-05-10:22:35:04,025 WARNING [huggingface.py:98] `pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way.
29
+ 2025-05-10:22:35:04,048 WARNING [huggingface.py:279] Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration
30
+ Token indices sequence length is longer than the specified maximum sequence length for this model (299078 > 131072). Running this sequence through the model will result in indexing errors
31
+ Done.
32
+ bos/eos tokens updated: tokenizer.bos_token_id=1, tokenizer.eos_token_id=2
33
+
34
  0%| | 0/146 [00:00<?, ?it/s]
35
  1%| | 1/146 [00:00<01:01, 2.36it/s]
36
  1%|▏ | 2/146 [00:00<00:37, 3.89it/s]
37
  2%|▏ | 3/146 [00:00<00:28, 4.98it/s]
38
  3%|▎ | 4/146 [00:00<00:24, 5.72it/s]
39
  3%|▎ | 5/146 [00:00<00:22, 6.24it/s]
40
  4%|▍ | 6/146 [00:01<00:21, 6.61it/s]
41
  5%|▍ | 7/146 [00:01<00:20, 6.86it/s]
42
  5%|▌ | 8/146 [00:01<00:19, 7.04it/s]
43
  6%|▌ | 9/146 [00:01<00:19, 7.16it/s]
44
  7%|▋ | 10/146 [00:01<00:18, 7.24it/s]
45
  8%|▊ | 11/146 [00:01<00:18, 7.31it/s]
46
  8%|▊ | 12/146 [00:01<00:18, 7.35it/s]
47
  9%|▉ | 13/146 [00:02<00:18, 7.37it/s]
48
  10%|▉ | 14/146 [00:02<00:17, 7.39it/s]
49
  10%|█ | 15/146 [00:02<00:17, 7.41it/s]
50
  11%|█ | 16/146 [00:02<00:17, 7.34it/s]
51
  12%|█▏ | 17/146 [00:02<00:17, 7.37it/s]
52
  12%|█▏ | 18/146 [00:02<00:17, 7.39it/s]
53
  13%|█▎ | 19/146 [00:02<00:17, 7.40it/s]
54
  14%|█▎ | 20/146 [00:02<00:17, 7.40it/s]
55
  14%|█▍ | 21/146 [00:03<00:16, 7.40it/s]
56
  15%|█▌ | 22/146 [00:03<00:16, 7.42it/s]
57
  16%|█▌ | 23/146 [00:03<00:16, 7.43it/s]
58
  16%|█▋ | 24/146 [00:03<00:16, 7.43it/s]
59
  17%|█▋ | 25/146 [00:03<00:16, 7.43it/s]
60
  18%|█▊ | 26/146 [00:03<00:16, 7.44it/s]
61
  18%|█▊ | 27/146 [00:03<00:15, 7.44it/s]
62
  19%|█▉ | 28/146 [00:04<00:15, 7.45it/s]
63
  20%|█▉ | 29/146 [00:04<00:15, 7.45it/s]
64
  21%|██ | 30/146 [00:04<00:15, 7.45it/s]
65
  21%|██ | 31/146 [00:04<00:15, 7.45it/s]
66
  22%|██▏ | 32/146 [00:04<00:15, 7.45it/s]
67
  23%|██▎ | 33/146 [00:04<00:15, 7.43it/s]
68
  23%|██▎ | 34/146 [00:04<00:15, 7.44it/s]
69
  24%|██▍ | 35/146 [00:05<00:14, 7.44it/s]
70
  25%|██▍ | 36/146 [00:05<00:14, 7.44it/s]
71
  25%|██▌ | 37/146 [00:05<00:14, 7.44it/s]
72
  26%|██▌ | 38/146 [00:05<00:14, 7.44it/s]
73
  27%|██▋ | 39/146 [00:05<00:14, 7.44it/s]
74
  27%|██▋ | 40/146 [00:05<00:14, 7.44it/s]
75
  28%|██▊ | 41/146 [00:05<00:14, 7.45it/s]
76
  29%|██▉ | 42/146 [00:05<00:13, 7.44it/s]
77
  29%|██▉ | 43/146 [00:06<00:13, 7.45it/s]
78
  30%|███ | 44/146 [00:06<00:13, 7.45it/s]
79
  31%|███ | 45/146 [00:06<00:13, 7.45it/s]
80
  32%|███▏ | 46/146 [00:06<00:13, 7.45it/s]
81
  32%|███▏ | 47/146 [00:06<00:13, 7.45it/s]
82
  33%|███▎ | 48/146 [00:06<00:13, 7.44it/s]
83
  34%|███▎ | 49/146 [00:06<00:13, 7.45it/s]
84
  34%|███▍ | 50/146 [00:07<00:12, 7.45it/s]
85
  35%|███▍ | 51/146 [00:07<00:12, 7.45it/s]
86
  36%|███▌ | 52/146 [00:07<00:12, 7.45it/s]
87
  36%|███▋ | 53/146 [00:07<00:12, 7.45it/s]
88
  37%|███▋ | 54/146 [00:07<00:12, 7.45it/s]
89
  38%|███▊ | 55/146 [00:07<00:12, 7.45it/s]
90
  38%|███▊ | 56/146 [00:07<00:12, 7.45it/s]
91
  39%|███▉ | 57/146 [00:07<00:11, 7.45it/s]
92
  40%|███▉ | 58/146 [00:08<00:11, 7.46it/s]
93
  40%|████ | 59/146 [00:08<00:11, 7.45it/s]
94
  41%|████ | 60/146 [00:08<00:11, 7.45it/s]
95
  42%|████▏ | 61/146 [00:08<00:11, 7.46it/s]
96
  42%|████▏ | 62/146 [00:08<00:11, 7.45it/s]
97
  43%|████▎ | 63/146 [00:08<00:11, 7.45it/s]
98
  44%|████▍ | 64/146 [00:08<00:10, 7.46it/s]
99
  45%|████▍ | 65/146 [00:09<00:10, 7.46it/s]
100
  45%|████▌ | 66/146 [00:09<00:10, 7.46it/s]
101
  46%|████▌ | 67/146 [00:09<00:10, 7.45it/s]
102
  47%|████▋ | 68/146 [00:09<00:10, 7.45it/s]
103
  47%|████▋ | 69/146 [00:09<00:10, 7.46it/s]
104
  48%|████▊ | 70/146 [00:09<00:10, 7.46it/s]
105
  49%|████▊ | 71/146 [00:09<00:10, 7.45it/s]
106
  49%|████▉ | 72/146 [00:09<00:09, 7.45it/s]
107
  50%|█████ | 73/146 [00:10<00:09, 7.45it/s]
108
  51%|█████ | 74/146 [00:10<00:09, 7.45it/s]
109
  51%|█████▏ | 75/146 [00:10<00:09, 7.45it/s]
110
  52%|█████▏ | 76/146 [00:10<00:09, 7.45it/s]
111
  53%|█████▎ | 77/146 [00:10<00:09, 7.45it/s]
112
  53%|█████▎ | 78/146 [00:10<00:09, 7.45it/s]
113
  54%|█████▍ | 79/146 [00:10<00:08, 7.46it/s]
114
  55%|█████▍ | 80/146 [00:11<00:08, 7.46it/s]
115
  55%|█████▌ | 81/146 [00:11<00:08, 7.46it/s]
116
  56%|█████▌ | 82/146 [00:11<00:08, 7.46it/s]
117
  57%|█████▋ | 83/146 [00:11<00:08, 7.46it/s]
118
  58%|█████▊ | 84/146 [00:11<00:08, 7.45it/s]
119
  58%|█████▊ | 85/146 [00:11<00:08, 7.45it/s]
120
  59%|█████▉ | 86/146 [00:11<00:08, 7.45it/s]
121
  60%|█████▉ | 87/146 [00:11<00:07, 7.45it/s]
122
  60%|██████ | 88/146 [00:12<00:07, 7.45it/s]
123
  61%|██████ | 89/146 [00:12<00:07, 7.46it/s]
124
  62%|██████▏ | 90/146 [00:12<00:07, 7.46it/s]
125
  62%|██████▏ | 91/146 [00:12<00:07, 7.45it/s]
126
  63%|██████▎ | 92/146 [00:12<00:07, 7.45it/s]
127
  64%|██████▎ | 93/146 [00:12<00:07, 7.45it/s]
128
  64%|██████▍ | 94/146 [00:12<00:06, 7.45it/s]
129
  65%|██████▌ | 95/146 [00:13<00:06, 7.45it/s]
130
  66%|██████▌ | 96/146 [00:13<00:06, 7.46it/s]
131
  66%|██████▋ | 97/146 [00:13<00:06, 7.46it/s]
132
  67%|██████▋ | 98/146 [00:13<00:06, 7.46it/s]
133
  68%|██████▊ | 99/146 [00:13<00:06, 7.45it/s]
134
  68%|██████▊ | 100/146 [00:13<00:06, 7.46it/s]
135
  69%|██████▉ | 101/146 [00:13<00:06, 7.45it/s]
136
  70%|██████▉ | 102/146 [00:13<00:05, 7.45it/s]
137
  71%|███████ | 103/146 [00:14<00:05, 7.45it/s]
138
  71%|███████ | 104/146 [00:14<00:05, 7.45it/s]
139
  72%|███████▏ | 105/146 [00:14<00:05, 7.46it/s]
140
  73%|███████▎ | 106/146 [00:14<00:05, 7.46it/s]
141
  73%|███████▎ | 107/146 [00:14<00:05, 7.46it/s]
142
  74%|███████▍ | 108/146 [00:14<00:05, 7.46it/s]
143
  75%|███████▍ | 109/146 [00:14<00:04, 7.46it/s]
144
  75%|███████▌ | 110/146 [00:15<00:04, 7.46it/s]
145
  76%|███████▌ | 111/146 [00:15<00:04, 7.46it/s]
146
  77%|███████▋ | 112/146 [00:15<00:04, 7.46it/s]
147
  77%|███████▋ | 113/146 [00:15<00:04, 7.45it/s]
148
  78%|███████▊ | 114/146 [00:15<00:04, 7.45it/s]
149
  79%|███████▉ | 115/146 [00:15<00:04, 7.45it/s]
150
  79%|███████▉ | 116/146 [00:15<00:04, 7.46it/s]
151
  80%|████████ | 117/146 [00:16<00:03, 7.46it/s]
152
  81%|████████ | 118/146 [00:16<00:03, 7.46it/s]
153
  82%|████████▏ | 119/146 [00:16<00:03, 7.46it/s]
154
  82%|████████▏ | 120/146 [00:16<00:03, 7.46it/s]
155
  83%|████████▎ | 121/146 [00:16<00:03, 7.46it/s]
156
  84%|████████▎ | 122/146 [00:16<00:03, 7.46it/s]
157
  84%|████████▍ | 123/146 [00:16<00:03, 7.46it/s]
158
  85%|████████▍ | 124/146 [00:16<00:02, 7.46it/s]
159
  86%|████████▌ | 125/146 [00:17<00:02, 7.46it/s]
160
  86%|████████▋ | 126/146 [00:17<00:02, 7.46it/s]
161
  87%|████████▋ | 127/146 [00:17<00:02, 7.46it/s]
162
  88%|████████▊ | 128/146 [00:17<00:02, 7.45it/s]
163
  88%|████████▊ | 129/146 [00:17<00:02, 7.45it/s]
164
  89%|████████▉ | 130/146 [00:17<00:02, 7.46it/s]
165
  90%|████████▉ | 131/146 [00:17<00:02, 7.46it/s]
166
  90%|█████████ | 132/146 [00:18<00:01, 7.46it/s]
167
  91%|█████████ | 133/146 [00:18<00:01, 7.46it/s]
168
  92%|█████████▏| 134/146 [00:18<00:01, 7.46it/s]
169
  92%|█████████▏| 135/146 [00:18<00:01, 7.46it/s]
170
  93%|█████████▎| 136/146 [00:18<00:01, 7.46it/s]
171
  94%|█████████▍| 137/146 [00:18<00:01, 7.46it/s]
172
  95%|█████████▍| 138/146 [00:18<00:01, 7.45it/s]
173
  95%|█████████▌| 139/146 [00:18<00:00, 7.46it/s]
174
  96%|█████████▌| 140/146 [00:19<00:00, 7.45it/s]
175
  97%|█████████▋| 141/146 [00:19<00:00, 7.45it/s]
176
  97%|█████████▋| 142/146 [00:19<00:00, 7.45it/s]
177
  98%|█████████▊| 143/146 [00:19<00:00, 7.45it/s]
178
  99%|█████████▊| 144/146 [00:19<00:00, 7.45it/s]
179
  99%|█████████▉| 145/146 [00:19<00:00, 7.46it/s]
180
+ wikitext2 18.208261489868164
181
+ Traceback (most recent call last):
182
+ File "/home/beihang/lin/pretrained_models/GPTQ-for-Qwen_hf/qwen.py", line 288, in <module>
183
+ eval_ours(model, tokenizer)
184
+ File "/mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/eval_my/evaluate_.py", line 163, in eval_ours
185
+ results = evaluate_model(
186
+ File "/home/beihang/anaconda3/envs/qwen3/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
187
+ return func(*args, **kwargs)
188
+ File "/mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/eval_my/evaluate_.py", line 108, in evaluate_model
189
+ testloader = get_eval_loaders(dataset, tokenizer)
190
+ File "/mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/eval_my/datautils.py", line 67, in get_eval_loaders
191
+ return get_c4(tokenizer)
192
+ File "/mnt/data/lin/pretrained_models/GPTQ-for-Qwen_hf/eval_my/datautils.py", line 690, in get_c4
193
+ valdata = load_from_disk('eval_my/ppl_datasets/allenai/c4/allenai--c4/validation')
194
+ File "/home/beihang/anaconda3/envs/qwen3/lib/python3.10/site-packages/datasets/load.py", line 2694, in load_from_disk
195
+ raise FileNotFoundError(f"Directory {dataset_path} not found")
196
+ FileNotFoundError: Directory eval_my/ppl_datasets/allenai/c4/allenai--c4/validation not found
GPTQ-for-Qwen_hf/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ safetensors==0.3.1
2
+ datasets==2.10.1
3
+ sentencepiece
4
+ git+https://github.com/huggingface/transformers
5
+ accelerate==0.20.3
6
+ triton==2.0.0
7
+ texttable
8
+ toml
9
+ numpy
10
+ protobuf==3.20.2
11
+
GPTQ-for-Qwen_hf/test.sh ADDED
@@ -0,0 +1 @@
1
+ CUDA_VISIBLE_DEVICES=6 python /home/beihang/lin/pretrained_models/GPTQ-for-Qwen_hf/qwen.py /home/beihang/lin/pretrained_models/Qwen3/Qwen3-0.6B --wbits 4 --groupsize -1 --eval --load /home/beihang/lin/pretrained_models/eval_my/GPTQ-for-LLaMa-triton/gptq_0.6B_w4_perchannel.pth 2>&1 | tee /home/beihang/lin/pretrained_models/GPTQ-for-Qwen_hf/qwen_gptq_0.6B_loadtest.log
GPTQ-for-Qwen_hf/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error
+ from .datautils import set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders
+ from .export import export_quant_table
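For orientation, a tiny illustrative sketch of what these package-level re-exports make available to the top-level scripts (the calls below are examples, not part of the upload):

# Sketch: names re-exported by utils/__init__.py.
from utils import DEV, set_seed

set_seed(0)
print(DEV)  # cuda:0, as defined in utils/modelutils.py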
GPTQ-for-Qwen_hf/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (485 Bytes).
GPTQ-for-Qwen_hf/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (488 Bytes).
GPTQ-for-Qwen_hf/utils/__pycache__/datautils.cpython-310.pyc ADDED
Binary file (4.31 kB).
GPTQ-for-Qwen_hf/utils/__pycache__/datautils.cpython-38.pyc ADDED
Binary file (5.15 kB).
GPTQ-for-Qwen_hf/utils/__pycache__/export.cpython-310.pyc ADDED
Binary file (1.21 kB).
GPTQ-for-Qwen_hf/utils/__pycache__/export.cpython-38.pyc ADDED
Binary file (1.19 kB).
GPTQ-for-Qwen_hf/utils/__pycache__/modelutils.cpython-310.pyc ADDED
Binary file (2.31 kB).
GPTQ-for-Qwen_hf/utils/__pycache__/modelutils.cpython-38.pyc ADDED
Binary file (2.3 kB).
GPTQ-for-Qwen_hf/utils/datautils.py ADDED
@@ -0,0 +1,234 @@
+ import numpy as np
+ import torch
+
+
+ def set_seed(seed):
+     np.random.seed(seed)
+     torch.random.manual_seed(seed)
+
+
+ def get_wikitext2(nsamples, seed, seqlen, model):
+     from datasets import load_dataset
+     traindata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train')
+     testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+
+     from transformers import AutoTokenizer
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+     except:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+     trainenc = tokenizer("\n\n".join(traindata['text']), return_tensors='pt')
+     testenc = tokenizer("\n\n".join(testdata['text']), return_tensors='pt')
+
+     import random
+     random.seed(seed)
+     trainloader = []
+     for _ in range(nsamples):
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+     return trainloader, testenc
+
+
+ def get_ptb(nsamples, seed, seqlen, model):
+     from datasets import load_dataset
+     traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+     valdata = load_dataset('ptb_text_only', 'penn_treebank', split='validation')
+
+     from transformers import AutoTokenizer
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+     except:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+     trainenc = tokenizer("\n\n".join(traindata['sentence']), return_tensors='pt')
+     testenc = tokenizer("\n\n".join(valdata['sentence']), return_tensors='pt')
+
+     import random
+     random.seed(seed)
+     trainloader = []
+     for _ in range(nsamples):
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+     return trainloader, testenc
+
+
+ # def get_c4(nsamples, seed, seqlen, model):
+ #     from datasets import load_dataset
+ #     traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train', use_auth_token=False)
+ #     valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation', use_auth_token=False)
+
+ #     from transformers import AutoTokenizer
+ #     try:
+ #         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+ #     except:
+ #         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+
+ #     import random
+ #     random.seed(seed)
+ #     trainloader = []
+ #     for _ in range(nsamples):
+ #         while True:
+ #             i = random.randint(0, len(traindata) - 1)
+ #             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+ #             if trainenc.input_ids.shape[1] >= seqlen:
+ #                 break
+ #         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+ #         j = i + seqlen
+ #         inp = trainenc.input_ids[:, i:j]
+ #         tar = inp.clone()
+ #         tar[:, :-1] = -100
+ #         trainloader.append((inp, tar))
+
+ #     import random
+ #     random.seed(0)
+ #     valenc = []
+ #     for _ in range(256):
+ #         while True:
+ #             i = random.randint(0, len(valdata) - 1)
+ #             tmp = tokenizer(valdata[i]['text'], return_tensors='pt')
+ #             if tmp.input_ids.shape[1] >= seqlen:
+ #                 break
+ #         i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1)
+ #         j = i + seqlen
+ #         valenc.append(tmp.input_ids[:, i:j])
+ #     valenc = torch.hstack(valenc)
+
+ #     class TokenizerWrapper:
+
+ #         def __init__(self, input_ids):
+ #             self.input_ids = input_ids
+
+ #     valenc = TokenizerWrapper(valenc)
+
+ #     return trainloader, valenc
+
+ class TokenizerWrapper:
+     def __init__(self, input_ids):
+         self.input_ids = input_ids
+
+ import numpy as np
+ import torch
+ from datasets import load_dataset, load_from_disk
+ from transformers import AutoTokenizer, LlamaTokenizer
+ import os
+ import random
+
+ def get_c4(nsamples, seed, seqlen, model, tokenizer):
+     # traindata = load_dataset(
+     #     'allenai/c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train'
+     # )
+     # valdata = load_dataset(
+     #     'allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation'
+     # )
+     alldata = load_from_disk('/home/beihang/lin/pretrained_models/c4/allenai--c4')
+     traindata = alldata['train']
+     valdata = alldata['validation']
+
+     random.seed(seed)
+     trainloader = []
+     for _ in range(nsamples):
+         while True:
+             i = random.randint(0, len(traindata) - 1)
+             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+             if trainenc.input_ids.shape[1] > seqlen:
+                 break
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+
+     valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
+     valenc = valenc.input_ids[:, :(256 * seqlen)]
+
+     valenc = TokenizerWrapper(valenc)
+
+     return trainloader, valenc
+
+
+ def get_ptb_new(nsamples, seed, seqlen, model):
+     from datasets import load_dataset
+     traindata = load_dataset('ptb_text_only', 'penn_treebank', split='train')
+     testdata = load_dataset('ptb_text_only', 'penn_treebank', split='test')
+
+     from transformers import AutoTokenizer
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+     except:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+     trainenc = tokenizer(" ".join(traindata['sentence']), return_tensors='pt')
+     testenc = tokenizer(" ".join(testdata['sentence']), return_tensors='pt')
+
+     import random
+     random.seed(seed)
+     trainloader = []
+     for _ in range(nsamples):
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+     return trainloader, testenc
+
+
+ def get_c4_new(nsamples, seed, seqlen, model):
+     from datasets import load_dataset
+     traindata = load_dataset('allenai/c4', 'allenai--c4', data_files={'train': 'en/c4-train.00000-of-01024.json.gz'}, split='train')
+     valdata = load_dataset('allenai/c4', 'allenai--c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')
+
+     from transformers import AutoTokenizer
+     try:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False)
+     except:
+         tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True)
+
+     import random
+     random.seed(seed)
+     trainloader = []
+     for _ in range(nsamples):
+         while True:
+             i = random.randint(0, len(traindata) - 1)
+             trainenc = tokenizer(traindata[i]['text'], return_tensors='pt')
+             if trainenc.input_ids.shape[1] >= seqlen:
+                 break
+         i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1)
+         j = i + seqlen
+         inp = trainenc.input_ids[:, i:j]
+         tar = inp.clone()
+         tar[:, :-1] = -100
+         trainloader.append((inp, tar))
+
+     valenc = tokenizer(' '.join(valdata[:1100]['text']), return_tensors='pt')
+     valenc = valenc.input_ids[:, :(256 * seqlen)]
+
+     class TokenizerWrapper:
+
+         def __init__(self, input_ids):
+             self.input_ids = input_ids
+
+     valenc = TokenizerWrapper(valenc)
+
+     return trainloader, valenc
+
+
+ def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model='', tokenizer=None):
+     if 'wikitext2' in name:
+         return get_wikitext2(nsamples, seed, seqlen, model)
+     if 'ptb' in name:
+         if 'new' in name:
+             return get_ptb_new(nsamples, seed, seqlen, model)
+         return get_ptb(nsamples, seed, seqlen, model)
+     if 'c4' in name:
+         if 'new' in name:
+             return get_c4_new(nsamples, seed, seqlen, model)
+         return get_c4(nsamples, seed, seqlen, model, tokenizer=tokenizer)
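As a usage note, get_loaders above is the single entry point for calibration data; a minimal sketch for wikitext2 (the model path is the Qwen3-0.6B checkpoint referenced in test.sh and is an assumption about the local layout; the call also assumes the Hugging Face hub is reachable for the dataset download):

# Sketch: 128 calibration sequences of length 2048 from wikitext2.
from utils.datautils import get_loaders

trainloader, testenc = get_loaders(
    'wikitext2',
    nsamples=128,
    seed=0,
    seqlen=2048,
    model='/home/beihang/lin/pretrained_models/Qwen3/Qwen3-0.6B',
)
print(len(trainloader))         # 128 (inp, tar) pairs
print(trainloader[0][0].shape)  # torch.Size([1, 2048])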
GPTQ-for-Qwen_hf/utils/export.py ADDED
@@ -0,0 +1,37 @@
+ import numpy as np
+ import toml
+ import os
+
+
+ def export_quant_table(quantizers: dict, quant_dir: str, format: str = 'toml'):
+
+     if not os.path.exists(quant_dir):
+         os.mkdir(quant_dir)
+
+     table = {}
+
+     def save_tensor(name: str, tensor):
+         np.save(os.path.join(quant_dir, name), tensor.numpy())
+         return '{}.npy'.format(name)
+
+     for key, value in quantizers.items():
+         quantizer = value[0]
+
+         dump = dict()
+
+         sym = quantizer.sym
+         if not sym:
+             dump['zero'] = save_tensor(name=key + '.zero', tensor=value[2])
+         dump['scale'] = save_tensor(name=key + '.scale', tensor=value[1])
+         dump['wbits'] = value[4]
+         dump['groupsize'] = value[5]
+         if value[5] > 0:
+             dump['group_ids'] = save_tensor(name=key + '.group_ids', tensor=value[3])
+
+         dump['sym'] = sym
+         dump['perchannel'] = quantizer.perchannel
+
+         table[key] = dump
+
+     with open(os.path.join(quant_dir, 'quant.toml'), 'w') as f:
+         toml.dump(table, f)
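The quantizers dict consumed by export_quant_table is keyed by layer name, and each value is laid out as (quantizer, scale, zero, group_ids, wbits, groupsize), which is how the function indexes value above. A hedged, self-contained sketch follows; DummyQuantizer is a hypothetical stand-in for the Quantizer class in quant/quantizer.py, used only to make the example runnable:

# Sketch: minimal quantizers dict accepted by export_quant_table.
import torch
from utils.export import export_quant_table

class DummyQuantizer:
    sym = False        # asymmetric, so a zero-point tensor is exported
    perchannel = True

quantizers = {
    'model.layers.0.self_attn.q_proj': (
        DummyQuantizer(),
        torch.ones(1024),    # scale, one entry per output channel
        torch.zeros(1024),   # zero point
        torch.zeros(0),      # group_ids, unused when groupsize == -1
        4,                   # wbits
        -1,                  # groupsize (-1 means per-channel, no groups)
    ),
}
export_quant_table(quantizers, 'quant_out')  # writes quant_out/quant.toml plus .npy tensors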
GPTQ-for-Qwen_hf/utils/modelutils.py ADDED
@@ -0,0 +1,83 @@
+ import torch
+ import torch.nn as nn
+
+ DEV = torch.device('cuda:0')
+
+
+ def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
+     if type(module) in layers:
+         return {name: module}
+     res = {}
+     for name1, child in module.named_children():
+         res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
+     return res
+
+
+ def gen_conditions(_wbits, _groupsize):
+     wbits = _wbits
+     groupsize = _groupsize
+     conditions = []
+     while True:
+         if wbits >= 8:
+             if groupsize == -1 or groupsize == 32:
+                 break
+
+         if groupsize > 32:
+             groupsize /= 2
+         else:
+             wbits *= 2
+             groupsize = _groupsize
+
+         conditions.append((int(wbits), int(groupsize)))
+     return conditions
+
+
+ # copied from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py
+ def torch_snr_error(y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+     """
+     Compute the SNR error between y_pred (tensor) and y_real (tensor).
+
+     The SNR error is calculated as:
+
+         SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2
+
+     If the inputs are matrices, the SNR error over the matrix is the mean of the
+     element-wise SNR error:
+
+         SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2)
+
+     Args:
+         y_pred (torch.Tensor): predicted (e.g. quantized) tensor.
+         y_real (torch.Tensor): reference tensor.
+         reduction (str, optional): 'mean', 'sum' or 'none'. Defaults to 'mean'.
+     Raises:
+         ValueError: if the two tensors have different shapes.
+         ValueError: if the reduction method is not supported.
+     Returns:
+         torch.Tensor: the (reduced) SNR error.
+     """
+     y_pred = y_pred.type(torch.float32)
+     y_real = y_real.type(torch.float32)
+
+     if y_pred.shape != y_real.shape:
+         raise ValueError(f'Can not compute snr loss for tensors with different shape. '
+                          f'({y_pred.shape} and {y_real.shape})')
+     reduction = str(reduction).lower()
+
+     if y_pred.ndim == 1:
+         y_pred = y_pred.unsqueeze(0)
+         y_real = y_real.unsqueeze(0)
+
+     y_pred = y_pred.flatten(start_dim=1)
+     y_real = y_real.flatten(start_dim=1)
+
+     noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1)
+     signal_power = torch.pow(y_real, 2).sum(dim=-1)
+     snr = (noise_power) / (signal_power + 1e-7)
+
+     if reduction == 'mean':
+         return torch.mean(snr)
+     elif reduction == 'sum':
+         return torch.sum(snr)
+     elif reduction == 'none':
+         return snr
+     else:
+         raise ValueError('Unsupported reduction method.')
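To make the SNR metric and the layer search concrete, a short self-contained sketch with random tensors and a toy module (illustrative only, not tied to any checkpoint):

# Sketch: SNR error between a noisy copy of a weight matrix and the original,
# plus find_layers on a small Sequential model.
import torch
import torch.nn as nn
from utils.modelutils import torch_snr_error, find_layers

w = torch.randn(128, 512)
w_noisy = w + 0.01 * torch.randn_like(w)   # stand-in for quantization error
print(torch_snr_error(w_noisy, w).item())  # about 1e-4; lower means closer

mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))
print(find_layers(mlp))                    # {'0': Linear(...), '2': Linear(...)}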