medmekk (HF Staff) committed
Commit 4bdedae · verified · 1 Parent(s): 60ea4cc

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,6 @@
1
+ ---
2
+ tags:
3
+ - kernel
4
+ ---
5
+
6
+ OpenAI Triton flash-attention with attention sinks
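A minimal usage sketch (not part of the commit): it assumes the package is importable as triton_flash_attn_sink (for example by putting build/torch-universal on PYTHONPATH), and the argument order, shapes and dtypes are assumptions drawn from the _attention.forward signature and the block sizes and asserts in attention.py below.

# Hedged usage sketch: shapes/dtypes are assumptions based on the kernel's
# asserts (HEAD_DIM in {16, 32, 64, 128, 256}, context length a multiple of 128).
import torch
from triton_flash_attn_sink import attention

Z, H, N_CTX, HEAD_DIM = 2, 8, 1024, 64
q = torch.randn(Z, H, N_CTX, HEAD_DIM, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)
sinks = torch.zeros(H, device="cuda", dtype=torch.float32, requires_grad=True)  # one learned sink logit per head

causal = True
sm_scale = HEAD_DIM ** -0.5
bandwidth = 0  # 0 = full causal attention; > 0 = banded (sliding-window) attention

out = attention(q, k, v, sinks, causal, sm_scale, bandwidth)
out.sum().backward()  # the custom backward also fills sinks.grad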
build.toml ADDED
@@ -0,0 +1,3 @@
1
+ [general]
2
+ name = "triton_flash_attn_sink"
3
+ universal = true
build/torch-universal/triton_flash_attn_sink/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .attention import attention
2
+
3
+ __all__ = ["attention"]
build/torch-universal/triton_flash_attn_sink/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (252 Bytes).

build/torch-universal/triton_flash_attn_sink/__pycache__/attention.cpython-313.pyc ADDED
Binary file (24.1 kB).
 
build/torch-universal/triton_flash_attn_sink/_ops.py ADDED
@@ -0,0 +1,8 @@
1
+ import torch
2
+ ops = torch.ops._triton_flash_attn_sink_a266b56
3
+
4
+ def add_op_namespace_prefix(op_name: str):
5
+ """
6
+ Prefix op by namespace.
7
+ """
8
+ return f"_triton_flash_attn_sink_a266b56::{op_name}"
build/torch-universal/triton_flash_attn_sink/attention.py ADDED
@@ -0,0 +1,803 @@
1
+ """FlashAttention w/support for learned sinks and banded attention.
2
+
3
+ This is an expanded version of the Flash Attention v2 implementation (see https://tridao.me/publications/flash2/flash2.pdf)
4
+ which can be found at https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html.
5
+
6
+ This version has been extended to support banded attention and learned attention sinks.
7
+ """
8
+
9
+ import torch
10
+
11
+ import triton
12
+ import triton.language as tl
13
+
14
+
15
+ # ──────────────────────────────────────────────────────────────────────────────
16
+ # _attn_fwd_inner
17
+ # (thanks o3 for the help + kind comment strings....)
18
+ # ──────────────────────────────────────────────────────────────────────────────
19
+ @triton.jit
20
+ def _attn_fwd_inner(
21
+ acc,
22
+ l_i,
23
+ m_i,
24
+ q,
25
+ K_block_ptr,
26
+ V_block_ptr,
27
+ start_m,
28
+ qk_scale,
29
+ BLOCK_M: tl.constexpr,
30
+ HEAD_DIM: tl.constexpr,
31
+ BLOCK_N: tl.constexpr,
32
+ STAGE: tl.constexpr,
33
+ offs_m: tl.constexpr,
34
+ offs_n: tl.constexpr,
35
+ N_CTX: tl.constexpr,
36
+ BANDWIDTH: tl.constexpr,
37
+ ):
38
+ # ---------------- range of kv indices for this stage ---------------------
39
+ if STAGE == 1:
40
+ # off-band (used only when BANDWIDTH == 0)
41
+ lo, hi = 0, start_m * BLOCK_M
42
+ elif STAGE == 2:
43
+ # on-band **plus** the preceding tokens that fall inside `BANDWIDTH`
44
+ if BANDWIDTH == 0: # full context → current block only
45
+ lo = start_m * BLOCK_M
46
+ else: # local context
47
+ lo = tl.maximum(0, start_m * BLOCK_M - BANDWIDTH)
48
+ hi = (start_m + 1) * BLOCK_M
49
+ # make the compiler aware that `lo` is a multiple of BLOCK_N so that
50
+ # the first `tl.load` is aligned (matches what the large kernel does)
51
+ lo = tl.multiple_of(lo, BLOCK_N)
52
+ else: # STAGE == 3 (non-causal)
53
+ lo, hi = 0, N_CTX
54
+
55
+ # advance the KV block-pointers so they point at `lo`
56
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
57
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
58
+
59
+ # ---------------- main loop over K/V tiles -------------------------------
60
+ for start_n in range(lo, hi, BLOCK_N):
61
+ start_n = tl.multiple_of(start_n, BLOCK_N)
62
+
63
+ # ---- Q·Kᵀ ------------------------------------------------------------
64
+ k = tl.load(K_block_ptr)
65
+ qk = tl.dot(q, k)
66
+
67
+ # ------------- causal + bandwidth masking (STAGE == 2) ----------------
68
+ if STAGE == 2:
69
+ # causal mask (j ≤ i)
70
+ causal_ok = offs_m[:, None] >= (start_n + offs_n[None, :])
71
+
72
+ if BANDWIDTH == 0: # full causal attention
73
+ mask = causal_ok
74
+ else: # local causal attention
75
+ # j ≥ i − BANDWIDTH + 1 ⟺ i < j + BANDWIDTH
76
+ within_bw = offs_m[:, None] < (start_n + offs_n[None, :] + BANDWIDTH)
77
+ mask = causal_ok & within_bw
78
+
79
+ qk = qk * qk_scale + tl.where(mask, 0.0, -1.0e30)
80
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
81
+ qk -= m_ij[:, None]
82
+ else:
83
+ # STAGE 1 (when BANDWIDTH == 0) or STAGE 3 (non-causal)
84
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
85
+ qk = qk * qk_scale - m_ij[:, None]
86
+
87
+ # ---- softmax ---------------------------------------------------------
88
+ p = tl.math.exp2(qk)
89
+ l_ij = tl.sum(p, 1)
90
+
91
+ # ---- running numerically-stable accumulators -------------------------
92
+ alpha = tl.math.exp2(m_i - m_ij)
93
+ l_i = l_i * alpha + l_ij
94
+ acc = acc * alpha[:, None]
95
+
96
+ v = tl.load(V_block_ptr)
97
+ p = p.to(tl.float16)
98
+ acc = tl.dot(p, v, acc)
99
+
100
+ m_i = m_ij
101
+
102
+ # ---- advance pointers ------------------------------------------------
103
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
104
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
105
+
106
+ return acc, l_i, m_i
107
+
108
+
109
+ @triton.jit
110
+ def _attn_fwd(
111
+ Q,
112
+ K,
113
+ V,
114
+ Sinks,
115
+ sm_scale,
116
+ M,
117
+ Out, #
118
+ stride_qz,
119
+ stride_qh,
120
+ stride_qm,
121
+ stride_qk, #
122
+ stride_kz,
123
+ stride_kh,
124
+ stride_kn,
125
+ stride_kk, #
126
+ stride_vz,
127
+ stride_vh,
128
+ stride_vk,
129
+ stride_vn, #
130
+ stride_oz,
131
+ stride_oh,
132
+ stride_om,
133
+ stride_on, #
134
+ Z,
135
+ H,
136
+ N_CTX, #
137
+ HEAD_DIM: tl.constexpr, #
138
+ BLOCK_M: tl.constexpr, #
139
+ BLOCK_N: tl.constexpr, #
140
+ STAGE: tl.constexpr, #
141
+ BANDWIDTH: tl.constexpr,
142
+ ):
143
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
144
+ start_m = tl.program_id(0)
145
+ off_hz = tl.program_id(1)
146
+ off_z = off_hz // H
147
+ off_h = off_hz % H
148
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
149
+
150
+ # block pointers
151
+ Q_block_ptr = tl.make_block_ptr(
152
+ base=Q + qvk_offset,
153
+ shape=(N_CTX, HEAD_DIM),
154
+ strides=(stride_qm, stride_qk),
155
+ offsets=(start_m * BLOCK_M, 0),
156
+ block_shape=(BLOCK_M, HEAD_DIM),
157
+ order=(1, 0),
158
+ )
159
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
160
+ V_block_ptr = tl.make_block_ptr(
161
+ base=V + qvk_offset,
162
+ shape=(N_CTX, HEAD_DIM),
163
+ strides=(stride_vk, stride_vn),
164
+ offsets=(0, 0),
165
+ block_shape=(BLOCK_N, HEAD_DIM),
166
+ order=v_order,
167
+ )
168
+ K_block_ptr = tl.make_block_ptr(
169
+ base=K + qvk_offset,
170
+ shape=(HEAD_DIM, N_CTX),
171
+ strides=(stride_kk, stride_kn),
172
+ offsets=(0, 0),
173
+ block_shape=(HEAD_DIM, BLOCK_N),
174
+ order=(0, 1),
175
+ )
176
+ O_block_ptr = tl.make_block_ptr(
177
+ base=Out + qvk_offset,
178
+ shape=(N_CTX, HEAD_DIM),
179
+ strides=(stride_om, stride_on),
180
+ offsets=(start_m * BLOCK_M, 0),
181
+ block_shape=(BLOCK_M, HEAD_DIM),
182
+ order=(1, 0),
183
+ )
184
+
185
+ # load attention sinks
186
+ if Sinks is not None:
187
+ sink = tl.load(Sinks + off_h).to(tl.float32)
188
+ else:
189
+ sink = -1.0e30
190
+
191
+ # initialize offsets
192
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
193
+ offs_n = tl.arange(0, BLOCK_N)
194
+ # initialize pointer to m and l
195
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
196
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
197
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
198
+ # load scales
199
+ qk_scale = sm_scale
200
+ qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
201
+ # load q: it will stay in SRAM throughout
202
+ q = tl.load(Q_block_ptr)
203
+ # stage 1: off-band
204
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
205
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
206
+ if STAGE & 1 and (BANDWIDTH == 0):
207
+ acc, l_i, m_i = _attn_fwd_inner(
208
+ acc,
209
+ l_i,
210
+ m_i,
211
+ q,
212
+ K_block_ptr,
213
+ V_block_ptr, #
214
+ start_m,
215
+ qk_scale, #
216
+ BLOCK_M,
217
+ HEAD_DIM,
218
+ BLOCK_N, #
219
+ 4 - STAGE,
220
+ offs_m,
221
+ offs_n,
222
+ N_CTX,
223
+ BANDWIDTH,
224
+ )
225
+ # stage 2: on-band
226
+ if STAGE & 2:
227
+ # running the two stages as separate inner-loop calls makes it easier for
228
+ # the compiler to schedule the two loops independently
229
+ acc, l_i, m_i = _attn_fwd_inner(
230
+ acc,
231
+ l_i,
232
+ m_i,
233
+ q,
234
+ K_block_ptr,
235
+ V_block_ptr, #
236
+ start_m,
237
+ qk_scale, #
238
+ BLOCK_M,
239
+ HEAD_DIM,
240
+ BLOCK_N, #
241
+ 2,
242
+ offs_m,
243
+ offs_n,
244
+ N_CTX,
245
+ BANDWIDTH,
246
+ )
247
+ # epilogue
248
+ m_i += tl.math.log2(l_i)
249
+ acc = acc / l_i[:, None]
250
+ m_ptrs = M + off_hz * N_CTX + offs_m
251
+ tl.store(m_ptrs, m_i)
252
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
253
+
254
+
255
+ @triton.jit
256
+ def _attn_bwd_preprocess(
257
+ O,
258
+ DO, #
259
+ Sinks,
260
+ DSinks,
261
+ DSinkstemp,
262
+ Atomic_counters,
263
+ M,
264
+ Delta, #
265
+ Z,
266
+ H,
267
+ N_CTX, #
268
+ BLOCK_M: tl.constexpr,
269
+ HEAD_DIM: tl.constexpr, #
270
+ ):
271
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
272
+ off_hz = tl.program_id(1)
273
+ off_n = tl.arange(0, HEAD_DIM)
274
+ off_z = off_hz // H
275
+ off_h = off_hz % H
276
+ # load
277
+ o = tl.load(
278
+ O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
279
+ )
280
+ do = tl.load(
281
+ DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
282
+ ).to(tl.float32)
283
+ delta = tl.sum(o * do, axis=1)
284
+ # write-back
285
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
286
+
287
+ if Sinks is not None:
288
+ m = tl.load(M + off_z * H * N_CTX + off_h * N_CTX + off_m)
289
+ sink = tl.load(Sinks + off_h)
290
+ dl = tl.sum(tl.math.exp2(sink - m) * delta, axis=0)
291
+
292
+ depth = Z * (N_CTX // BLOCK_M)
293
+
294
+ tl.store(
295
+ DSinkstemp + (off_h * Z + off_z) * (N_CTX // BLOCK_M) + tl.program_id(0), dl
296
+ )
297
+
298
+ if tl.atomic_add(Atomic_counters + off_h, 1) == depth - 1:
299
+ dl_acc = 0.0
300
+
301
+ for i in range(0, depth, BLOCK_M):
302
+ idxs = i + tl.arange(0, BLOCK_M)
303
+ temps = tl.load(
304
+ DSinkstemp + off_h * depth + idxs, mask=(idxs < depth), other=0.0
305
+ )
306
+ dl_acc += tl.sum(temps, axis=0)
307
+
308
+ tl.store(DSinks + off_h, (-0.69314718) * dl_acc)
309
+
310
+
311
+ # The main inner-loop logic for computing dK and dV.
312
+ @triton.jit
313
+ def _attn_bwd_dkdv(
314
+ dk,
315
+ dv, #
316
+ Q,
317
+ k,
318
+ v,
319
+ sm_scale, #
320
+ DO, #
321
+ M,
322
+ D, #
323
+ # shared by Q/K/V/DO.
324
+ stride_tok,
325
+ stride_d, #
326
+ H,
327
+ N_CTX,
328
+ BLOCK_M1: tl.constexpr, #
329
+ BLOCK_N1: tl.constexpr, #
330
+ HEAD_DIM: tl.constexpr, #
331
+ # Filled in by the wrapper.
332
+ start_n,
333
+ start_m,
334
+ num_steps, #
335
+ MASK: tl.constexpr,
336
+ BANDWIDTH: tl.constexpr,
337
+ ):
338
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
339
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
340
+ offs_k = tl.arange(0, HEAD_DIM)
341
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
342
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
343
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
344
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
345
+ curr_m = start_m
346
+ step_m = BLOCK_M1
347
+ for blk_idx in range(num_steps):
348
+ qT = tl.load(qT_ptrs)
349
+ # Load m before computing qk to reduce pipeline stall.
350
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
351
+ m = tl.load(M + offs_m)
352
+ qkT = tl.dot(k, qT)
353
+ pT = tl.math.exp2(qkT - m[None, :])
354
+ # Autoregressive masking.
355
+ if MASK:
356
+ if BANDWIDTH == 0: # full causal
357
+ mask = offs_m[None, :] >= offs_n[:, None]
358
+ else: # local causal
359
+ mask = (offs_m[None, :] >= offs_n[:, None]) & (
360
+ offs_m[None, :] < offs_n[:, None] + BANDWIDTH
361
+ )
362
+ pT = tl.where(mask, pT, 0.0)
363
+ do = tl.load(do_ptrs)
364
+ # Compute dV.
365
+ ppT = pT
366
+ ppT = ppT.to(tl.float16)
367
+ dv += tl.dot(ppT, do)
368
+ # D (= delta) is pre-divided by ds_scale.
369
+ Di = tl.load(D + offs_m)
370
+ # Compute dP and dS.
371
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
372
+ dsT = pT * (dpT - Di[None, :])
373
+ dsT = dsT.to(tl.float16)
374
+ dk += tl.dot(dsT, tl.trans(qT))
375
+ # Increment pointers.
376
+ curr_m += step_m
377
+ qT_ptrs += step_m * stride_tok
378
+ do_ptrs += step_m * stride_tok
379
+ return dk, dv
380
+
381
+
382
+ # the main inner-loop logic for computing dQ
383
+ @triton.jit
384
+ def _attn_bwd_dq(
385
+ dq,
386
+ q,
387
+ K,
388
+ V, #
389
+ do,
390
+ m,
391
+ D,
392
+ # shared by Q/K/V/DO.
393
+ stride_tok,
394
+ stride_d, #
395
+ H,
396
+ N_CTX, #
397
+ BLOCK_M2: tl.constexpr, #
398
+ BLOCK_N2: tl.constexpr, #
399
+ HEAD_DIM: tl.constexpr,
400
+ BANDWIDTH: tl.constexpr,
401
+ # Filled in by the wrapper.
402
+ start_m,
403
+ start_n,
404
+ num_steps, #
405
+ MASK: tl.constexpr,
406
+ ):
407
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
408
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
409
+ offs_k = tl.arange(0, HEAD_DIM)
410
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
411
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
412
+ # D (= delta) is pre-divided by ds_scale.
413
+ Di = tl.load(D + offs_m)
414
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
415
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
416
+ curr_n = start_n
417
+ step_n = BLOCK_N2
418
+ for blk_idx in range(num_steps):
419
+ kT = tl.load(kT_ptrs)
420
+ vT = tl.load(vT_ptrs)
421
+ qk = tl.dot(q, kT)
422
+ p = tl.math.exp2(qk - m)
423
+ # Autoregressive masking.
424
+ if MASK:
425
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
426
+ if BANDWIDTH == 0: # full causal
427
+ mask = offs_m[:, None] >= offs_n[None, :]
428
+ else: # local causal
429
+ mask = (offs_m[:, None] >= offs_n[None, :]) & (
430
+ offs_m[:, None] < offs_n[None, :] + BANDWIDTH
431
+ )
432
+ p = tl.where(mask, p, 0.0)
433
+ # Compute dP and dS.
434
+ dp = tl.dot(do, vT).to(tl.float32)
435
+ ds = p * (dp - Di[:, None])
436
+ ds = ds.to(tl.float16)
437
+ # Compute dQ.
438
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
439
+ dq += tl.dot(ds, tl.trans(kT))
440
+ # Increment pointers.
441
+ curr_n += step_n
442
+ kT_ptrs += step_n * stride_tok
443
+ vT_ptrs += step_n * stride_tok
444
+ return dq
445
+
446
+
447
+ @triton.jit
448
+ def _attn_bwd(
449
+ Q,
450
+ K,
451
+ V,
452
+ sm_scale, #
453
+ DO, #
454
+ DQ,
455
+ DK,
456
+ DV, #
457
+ M,
458
+ D,
459
+ # shared by Q/K/V/DO.
460
+ stride_z,
461
+ stride_h,
462
+ stride_tok,
463
+ stride_d, #
464
+ H,
465
+ N_CTX, #
466
+ BANDWIDTH: tl.constexpr,
467
+ BLOCK_M1: tl.constexpr, #
468
+ BLOCK_N1: tl.constexpr, #
469
+ BLOCK_M2: tl.constexpr, #
470
+ BLOCK_N2: tl.constexpr, #
471
+ BLK_SLICE_FACTOR: tl.constexpr, #
472
+ HEAD_DIM: tl.constexpr,
473
+ ):
474
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
475
+
476
+ bhid = tl.program_id(2)
477
+ off_chz = (bhid * N_CTX).to(tl.int64)
478
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
479
+ pid = tl.program_id(0)
480
+
481
+ # offset pointers for batch/head
482
+ Q += adj
483
+ K += adj
484
+ V += adj
485
+ DO += adj
486
+ DQ += adj
487
+ DK += adj
488
+ DV += adj
489
+ M += off_chz
490
+ D += off_chz
491
+
492
+ # load scales
493
+ offs_k = tl.arange(0, HEAD_DIM)
494
+
495
+ start_n = pid * BLOCK_N1
496
+ start_m = start_n
497
+
498
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
499
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
500
+
501
+ dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
502
+ dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
503
+
504
+ # load K and V: they stay in SRAM throughout the inner loop.
505
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
506
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
507
+
508
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
509
+
510
+ dk, dv = _attn_bwd_dkdv(
511
+ dk,
512
+ dv, #
513
+ Q,
514
+ k,
515
+ v,
516
+ sm_scale, #
517
+ DO, #
518
+ M,
519
+ D, #
520
+ stride_tok,
521
+ stride_d, #
522
+ H,
523
+ N_CTX, #
524
+ MASK_BLOCK_M1,
525
+ BLOCK_N1,
526
+ HEAD_DIM, #
527
+ start_n,
528
+ start_m,
529
+ num_steps, #
530
+ MASK=True, #
531
+ BANDWIDTH=BANDWIDTH,
532
+ )
533
+
534
+ start_m += num_steps * MASK_BLOCK_M1
535
+ # how many *additional* rows may still attend to the current key block?
536
+ if BANDWIDTH == 0:
537
+ rows_left = N_CTX - start_m
538
+ else:
539
+ rows_left = min(N_CTX - start_m, BLOCK_N1)
540
+ num_steps = rows_left // BLOCK_M1
541
+
542
+ # Compute dK and dV for non-masked blocks.
543
+ dk, dv = _attn_bwd_dkdv( #
544
+ dk,
545
+ dv, #
546
+ Q,
547
+ k,
548
+ v,
549
+ sm_scale, #
550
+ DO, #
551
+ M,
552
+ D, #
553
+ stride_tok,
554
+ stride_d, #
555
+ H,
556
+ N_CTX, #
557
+ BLOCK_M1,
558
+ BLOCK_N1,
559
+ HEAD_DIM, #
560
+ start_n,
561
+ start_m,
562
+ num_steps, #
563
+ MASK=BANDWIDTH != 0, #
564
+ BANDWIDTH=BANDWIDTH,
565
+ )
566
+
567
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
568
+ tl.store(dv_ptrs, dv)
569
+
570
+ # Write back dK.
571
+ dk *= sm_scale
572
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
573
+ tl.store(dk_ptrs, dk)
574
+
575
+ # THIS BLOCK DOES DQ:
576
+ start_m = pid * BLOCK_M2
577
+ end_n = start_m + BLOCK_M2
578
+
579
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
580
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
581
+
582
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
583
+ dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
584
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
585
+
586
+ m = tl.load(M + offs_m)
587
+ m = m[:, None]
588
+
589
+ # Compute dQ for masked (diagonal) blocks.
590
+ # NOTE: This code scans each row of QK^T backward (from right to left,
591
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
592
+ # not due to anything important. I just wanted to reuse the loop
593
+ # structure for dK & dV above as much as possible.
594
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
595
+ dq = _attn_bwd_dq(
596
+ dq,
597
+ q,
598
+ K,
599
+ V, #
600
+ do,
601
+ m,
602
+ D, #
603
+ stride_tok,
604
+ stride_d, #
605
+ H,
606
+ N_CTX, #
607
+ BLOCK_M2,
608
+ MASK_BLOCK_N2,
609
+ HEAD_DIM, #
610
+ BANDWIDTH,
611
+ start_m,
612
+ end_n - num_steps * MASK_BLOCK_N2,
613
+ num_steps, #
614
+ MASK=True, #
615
+ )
616
+ end_n -= num_steps * MASK_BLOCK_N2
617
+
618
+ # stage-1 (rows that still fall inside the window)
619
+ if BANDWIDTH == 0:
620
+ cols_left = end_n
621
+ else:
622
+ cols_left = min(end_n, BLOCK_M2)
623
+ num_steps = cols_left // BLOCK_N2
624
+ dq = _attn_bwd_dq(
625
+ dq,
626
+ q,
627
+ K,
628
+ V, #
629
+ do,
630
+ m,
631
+ D, #
632
+ stride_tok,
633
+ stride_d, #
634
+ H,
635
+ N_CTX, #
636
+ BLOCK_M2,
637
+ BLOCK_N2,
638
+ HEAD_DIM, #
639
+ BANDWIDTH,
640
+ start_m,
641
+ end_n - num_steps * BLOCK_N2,
642
+ num_steps, #
643
+ MASK=BANDWIDTH != 0, #
644
+ )
645
+ # Write back dQ.
646
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
647
+ dq *= LN2
648
+ tl.store(dq_ptrs, dq)
649
+
650
+
651
+ class _attention(torch.autograd.Function):
652
+ @staticmethod
653
+ def forward(
654
+ ctx,
655
+ q,
656
+ k,
657
+ v,
658
+ sinks,
659
+ causal,
660
+ sm_scale,
661
+ bandwidth,
662
+ warp_specialize=True,
663
+ USE_TMA=True,
664
+ ):
665
+ # shape constraints
666
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
667
+ # when v is in float8_e5m2 it is transposed.
668
+ HEAD_DIM_V = v.shape[-1]
669
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
670
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
671
+ o = torch.empty_like(q)
672
+ stage = 3 if causal else 1
673
+ extra_kern_args = {}
674
+ M = torch.empty(
675
+ (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
676
+ )
677
+ BLOCK_M = 128
678
+ grid = (
679
+ triton.cdiv(q.shape[2], BLOCK_M),
680
+ q.shape[0] * q.shape[1],
681
+ 1,
682
+ )
683
+ _attn_fwd[grid](
684
+ q,
685
+ k,
686
+ v,
687
+ sinks,
688
+ sm_scale,
689
+ M,
690
+ o, #
691
+ q.stride(0),
692
+ q.stride(1),
693
+ q.stride(2),
694
+ q.stride(3), #
695
+ k.stride(0),
696
+ k.stride(1),
697
+ k.stride(2),
698
+ k.stride(3), #
699
+ v.stride(0),
700
+ v.stride(1),
701
+ v.stride(2),
702
+ v.stride(3), #
703
+ o.stride(0),
704
+ o.stride(1),
705
+ o.stride(2),
706
+ o.stride(3), #
707
+ q.shape[0],
708
+ q.shape[1], #
709
+ N_CTX=q.shape[2], #
710
+ HEAD_DIM=HEAD_DIM_K, #
711
+ STAGE=stage, #
712
+ BANDWIDTH=bandwidth,
713
+ BLOCK_M=BLOCK_M,
714
+ BLOCK_N=64,
715
+ **extra_kern_args,
716
+ )
717
+
718
+ ctx.save_for_backward(q, k, v, sinks, o, M)
719
+ ctx.sm_scale = sm_scale
720
+ ctx.HEAD_DIM = HEAD_DIM_K
721
+ ctx.causal = causal
722
+ ctx.bandwidth = bandwidth
723
+ return o
724
+
725
+ @staticmethod
726
+ def backward(ctx, do):
727
+ q, k, v, sinks, o, M = ctx.saved_tensors
728
+ do = do.contiguous()
729
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
730
+ dq = torch.empty_like(q)
731
+ dk = torch.empty_like(k)
732
+ dv = torch.empty_like(v)
733
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
734
+ PRE_BLOCK = 128
735
+ NUM_WARPS, NUM_STAGES = 4, 5
736
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
737
+ BLK_SLICE_FACTOR = 2
738
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
739
+ arg_k = k
740
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
741
+ PRE_BLOCK = 128
742
+ assert N_CTX % PRE_BLOCK == 0
743
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
744
+ delta = torch.empty_like(M)
745
+ if sinks is not None:
746
+ dsinks = torch.empty_like(sinks)
747
+ dsinkstemp = torch.empty(pre_grid, dtype=torch.float32, device=sinks.device)
748
+ atomic_counters = torch.zeros(
749
+ N_HEAD, dtype=torch.int32, device=sinks.device
750
+ )
751
+ else:
752
+ dsinks, dsinkstemp, atomic_counters = None, None, None
753
+ _attn_bwd_preprocess[pre_grid](
754
+ o,
755
+ do, #
756
+ # Info for attention sinks.
757
+ sinks,
758
+ dsinks,
759
+ dsinkstemp,
760
+ atomic_counters,
761
+ M,
762
+ ######
763
+ delta, #
764
+ BATCH,
765
+ N_HEAD,
766
+ N_CTX, #
767
+ BLOCK_M=PRE_BLOCK,
768
+ HEAD_DIM=ctx.HEAD_DIM, #
769
+ )
770
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
771
+ _attn_bwd[grid](
772
+ q,
773
+ arg_k,
774
+ v,
775
+ ctx.sm_scale,
776
+ do,
777
+ dq,
778
+ dk,
779
+ dv, #
780
+ M,
781
+ delta, #
782
+ q.stride(0),
783
+ q.stride(1),
784
+ q.stride(2),
785
+ q.stride(3), #
786
+ N_HEAD,
787
+ N_CTX, #
788
+ BANDWIDTH=ctx.bandwidth,
789
+ BLOCK_M1=BLOCK_M1,
790
+ BLOCK_N1=BLOCK_N1, #
791
+ BLOCK_M2=BLOCK_M2,
792
+ BLOCK_N2=BLOCK_N2, #
793
+ BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
794
+ HEAD_DIM=ctx.HEAD_DIM, #
795
+ num_warps=NUM_WARPS, #
796
+ num_stages=NUM_STAGES, #
797
+ )
798
+
799
+ return dq, dk, dv, dsinks, None, None, None, None, None
800
+
801
+
802
+ attention = _attention.apply
803
+
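Reading the forward kernel above (m_i is initialised to the per-head sink, l_i to 1, and the epilogue divides by l_i), the learned sink appears to act as an extra logit that enters the softmax normalisation but contributes no value vector; with BANDWIDTH > 0 the key set is further restricted to a causal sliding window. Up to the internal base-2 rescaling, the output for query i of head h, with allowed key set A(i), seems to be

$$O_i = \frac{\sum_{j \in A(i)} \exp(\text{sm\_scale}\; q_i^\top k_j)\, v_j}{\exp(\text{sink}_h) + \sum_{j \in A(i)} \exp(\text{sm\_scale}\; q_i^\top k_j)}.$$

A hedged correctness-check sketch against a plain PyTorch implementation of that formula (not part of the commit; reference_sink_attention is a hypothetical helper, full causal case only):

# Hedged reference check (not part of the commit). Assumes the sink acts as an
# extra softmax logit with no value vector, and bandwidth == 0 (full causal).
import torch
from triton_flash_attn_sink import attention

def reference_sink_attention(q, k, v, sinks, sm_scale):
    Z, H, N, _ = q.shape
    logits = torch.einsum("zhid,zhjd->zhij", q.float(), k.float()) * sm_scale
    causal = torch.tril(torch.ones(N, N, dtype=torch.bool, device=q.device))
    logits = logits.masked_fill(~causal, float("-inf"))
    # prepend the per-head sink as an extra logit column that carries no value
    sink_col = sinks.float().view(1, H, 1, 1).expand(Z, H, N, 1)
    probs = torch.softmax(torch.cat([sink_col, logits], dim=-1), dim=-1)[..., 1:]
    return torch.einsum("zhij,zhjd->zhid", probs, v.float())

Z, H, N, D = 1, 4, 256, 64
q, k, v = (torch.randn(Z, H, N, D, device="cuda", dtype=torch.float16) for _ in range(3))
sinks = torch.randn(H, device="cuda", dtype=torch.float32)
sm_scale = D ** -0.5

out = attention(q, k, v, sinks, True, sm_scale, 0)
ref = reference_sink_attention(q, k, v, sinks, sm_scale)
print((out.float() - ref).abs().max())  # expect a small fp16-level difference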
flake.lock ADDED
@@ -0,0 +1,169 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1753354560,
77
+ "narHash": "sha256-vmOfRmr0Qm/IbZTWB2sBn+UFrABSTTA/cTg+m27Yt/E=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "7f2aceda2a2e72cd573bdb25e5c0667fd75f89d3",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1753354632,
102
+ "narHash": "sha256-31SX3Raiyx0qCuY9JSlx9ZZgxljeUxvW+JdujjxbofQ=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "524b628fd8e58525dbd28455bffb0628092c5265",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "ref": "torch-2.8",
111
+ "repo": "kernel-builder",
112
+ "type": "github"
113
+ }
114
+ },
115
+ "nixpkgs": {
116
+ "locked": {
117
+ "lastModified": 1752785354,
118
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
119
+ "owner": "nixos",
120
+ "repo": "nixpkgs",
121
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
122
+ "type": "github"
123
+ },
124
+ "original": {
125
+ "owner": "nixos",
126
+ "repo": "nixpkgs",
127
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
128
+ "type": "github"
129
+ }
130
+ },
131
+ "root": {
132
+ "inputs": {
133
+ "kernel-builder": "kernel-builder"
134
+ }
135
+ },
136
+ "systems": {
137
+ "locked": {
138
+ "lastModified": 1681028828,
139
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
140
+ "owner": "nix-systems",
141
+ "repo": "default",
142
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
143
+ "type": "github"
144
+ },
145
+ "original": {
146
+ "owner": "nix-systems",
147
+ "repo": "default",
148
+ "type": "github"
149
+ }
150
+ },
151
+ "systems_2": {
152
+ "locked": {
153
+ "lastModified": 1681028828,
154
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
155
+ "owner": "nix-systems",
156
+ "repo": "default",
157
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
158
+ "type": "github"
159
+ },
160
+ "original": {
161
+ "owner": "nix-systems",
162
+ "repo": "default",
163
+ "type": "github"
164
+ }
165
+ }
166
+ },
167
+ "root": "root",
168
+ "version": 7
169
+ }
flake.nix ADDED
@@ -0,0 +1,17 @@
1
+ {
2
+ description = "Flake for Triton flash attention with attention sinks";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder/torch-2.8";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ };
17
+ }
torch-ext/triton_flash_attn_sink/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .attention import attention
2
+
3
+ __all__ = ["attention"]
torch-ext/triton_flash_attn_sink/attention.py ADDED
@@ -0,0 +1,802 @@
1
+ """FlashAttention w/support for learned sinks and banded attention.
2
+
3
+ This is an expanded version of the Flash Attention v2 implementation (see https://tridao.me/publications/flash2/flash2.pdf)
4
+ which can be found at https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html.
5
+
6
+ This version has been extended to support banded attention and learned attention sinks.
7
+ """
8
+
9
+ import torch
10
+
11
+ import triton
12
+ import triton.language as tl
13
+
14
+
15
+ # ──────────────────────────────────────────────────────────────────────────────
16
+ # _attn_fwd_inner
17
+ # (thanks o3 for the help + kind comment strings....)
18
+ # ──────────────────────────────────────────────────────────────────────────────
19
+ @triton.jit
20
+ def _attn_fwd_inner(
21
+ acc,
22
+ l_i,
23
+ m_i,
24
+ q,
25
+ K_block_ptr,
26
+ V_block_ptr,
27
+ start_m,
28
+ qk_scale,
29
+ BLOCK_M: tl.constexpr,
30
+ HEAD_DIM: tl.constexpr,
31
+ BLOCK_N: tl.constexpr,
32
+ STAGE: tl.constexpr,
33
+ offs_m: tl.constexpr,
34
+ offs_n: tl.constexpr,
35
+ N_CTX: tl.constexpr,
36
+ BANDWIDTH: tl.constexpr,
37
+ ):
38
+ # ---------------- range of kv indices for this stage ---------------------
39
+ if STAGE == 1:
40
+ # off-band (used only when BANDWIDTH == 0)
41
+ lo, hi = 0, start_m * BLOCK_M
42
+ elif STAGE == 2:
43
+ # on-band **plus** the preceding tokens that fall inside `BANDWIDTH`
44
+ if BANDWIDTH == 0: # full context → current block only
45
+ lo = start_m * BLOCK_M
46
+ else: # local context
47
+ lo = tl.maximum(0, start_m * BLOCK_M - BANDWIDTH)
48
+ hi = (start_m + 1) * BLOCK_M
49
+ # make the compiler aware that `lo` is a multiple of BLOCK_N so that
50
+ # the first `tl.load` is aligned (matches what the large kernel does)
51
+ lo = tl.multiple_of(lo, BLOCK_N)
52
+ else: # STAGE == 3 (non-causal)
53
+ lo, hi = 0, N_CTX
54
+
55
+ # advance the KV block-pointers so they point at `lo`
56
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
57
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
58
+
59
+ # ---------------- main loop over K/V tiles -------------------------------
60
+ for start_n in range(lo, hi, BLOCK_N):
61
+ start_n = tl.multiple_of(start_n, BLOCK_N)
62
+
63
+ # ---- Q·Kᵀ ------------------------------------------------------------
64
+ k = tl.load(K_block_ptr)
65
+ qk = tl.dot(q, k)
66
+
67
+ # ------------- causal + bandwidth masking (STAGE == 2) ----------------
68
+ if STAGE == 2:
69
+ # causal mask (j ≤ i)
70
+ causal_ok = offs_m[:, None] >= (start_n + offs_n[None, :])
71
+
72
+ if BANDWIDTH == 0: # full causal attention
73
+ mask = causal_ok
74
+ else: # local causal attention
75
+ # j ≥ i − BANDWIDTH + 1 ⟺ i < j + BANDWIDTH
76
+ within_bw = offs_m[:, None] < (start_n + offs_n[None, :] + BANDWIDTH)
77
+ mask = causal_ok & within_bw
78
+
79
+ qk = qk * qk_scale + tl.where(mask, 0.0, -1.0e30)
80
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
81
+ qk -= m_ij[:, None]
82
+ else:
83
+ # STAGE 1 (when BANDWIDTH == 0) or STAGE 3 (non-causal)
84
+ m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
85
+ qk = qk * qk_scale - m_ij[:, None]
86
+
87
+ # ---- softmax ---------------------------------------------------------
88
+ p = tl.math.exp2(qk)
89
+ l_ij = tl.sum(p, 1)
90
+
91
+ # ---- running numerically-stable accumulators -------------------------
92
+ alpha = tl.math.exp2(m_i - m_ij)
93
+ l_i = l_i * alpha + l_ij
94
+ acc = acc * alpha[:, None]
95
+
96
+ v = tl.load(V_block_ptr)
97
+ acc = tl.dot(p, v, acc)
98
+
99
+ m_i = m_ij
100
+
101
+ # ---- advance pointers ------------------------------------------------
102
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
103
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
104
+
105
+ return acc, l_i, m_i
106
+
107
+
108
+ @triton.jit
109
+ def _attn_fwd(
110
+ Q,
111
+ K,
112
+ V,
113
+ Sinks,
114
+ sm_scale,
115
+ M,
116
+ Out, #
117
+ stride_qz,
118
+ stride_qh,
119
+ stride_qm,
120
+ stride_qk, #
121
+ stride_kz,
122
+ stride_kh,
123
+ stride_kn,
124
+ stride_kk, #
125
+ stride_vz,
126
+ stride_vh,
127
+ stride_vk,
128
+ stride_vn, #
129
+ stride_oz,
130
+ stride_oh,
131
+ stride_om,
132
+ stride_on, #
133
+ Z,
134
+ H,
135
+ N_CTX, #
136
+ HEAD_DIM: tl.constexpr, #
137
+ BLOCK_M: tl.constexpr, #
138
+ BLOCK_N: tl.constexpr, #
139
+ STAGE: tl.constexpr, #
140
+ BANDWIDTH: tl.constexpr,
141
+ ):
142
+ tl.static_assert(BLOCK_N <= HEAD_DIM)
143
+ start_m = tl.program_id(0)
144
+ off_hz = tl.program_id(1)
145
+ off_z = off_hz // H
146
+ off_h = off_hz % H
147
+ qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
148
+
149
+ # block pointers
150
+ Q_block_ptr = tl.make_block_ptr(
151
+ base=Q + qvk_offset,
152
+ shape=(N_CTX, HEAD_DIM),
153
+ strides=(stride_qm, stride_qk),
154
+ offsets=(start_m * BLOCK_M, 0),
155
+ block_shape=(BLOCK_M, HEAD_DIM),
156
+ order=(1, 0),
157
+ )
158
+ v_order: tl.constexpr = (0, 1) if V.dtype.element_ty == tl.float8e5 else (1, 0)
159
+ V_block_ptr = tl.make_block_ptr(
160
+ base=V + qvk_offset,
161
+ shape=(N_CTX, HEAD_DIM),
162
+ strides=(stride_vk, stride_vn),
163
+ offsets=(0, 0),
164
+ block_shape=(BLOCK_N, HEAD_DIM),
165
+ order=v_order,
166
+ )
167
+ K_block_ptr = tl.make_block_ptr(
168
+ base=K + qvk_offset,
169
+ shape=(HEAD_DIM, N_CTX),
170
+ strides=(stride_kk, stride_kn),
171
+ offsets=(0, 0),
172
+ block_shape=(HEAD_DIM, BLOCK_N),
173
+ order=(0, 1),
174
+ )
175
+ O_block_ptr = tl.make_block_ptr(
176
+ base=Out + qvk_offset,
177
+ shape=(N_CTX, HEAD_DIM),
178
+ strides=(stride_om, stride_on),
179
+ offsets=(start_m * BLOCK_M, 0),
180
+ block_shape=(BLOCK_M, HEAD_DIM),
181
+ order=(1, 0),
182
+ )
183
+
184
+ # load attention sinks
185
+ if Sinks is not None:
186
+ sink = tl.load(Sinks + off_h).to(tl.float32)
187
+ else:
188
+ sink = -1.0e30
189
+
190
+ # initialize offsets
191
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
192
+ offs_n = tl.arange(0, BLOCK_N)
193
+ # initialize pointer to m and l
194
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) + sink
195
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
196
+ acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
197
+ # load scales
198
+ qk_scale = sm_scale
199
+ qk_scale *= 1.44269504 # log2(e) = 1/ln(2)
200
+ # load q: it will stay in SRAM throughout
201
+ q = tl.load(Q_block_ptr)
202
+ # stage 1: off-band
203
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
204
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
205
+ if STAGE & 1 and (BANDWIDTH == 0):
206
+ acc, l_i, m_i = _attn_fwd_inner(
207
+ acc,
208
+ l_i,
209
+ m_i,
210
+ q,
211
+ K_block_ptr,
212
+ V_block_ptr, #
213
+ start_m,
214
+ qk_scale, #
215
+ BLOCK_M,
216
+ HEAD_DIM,
217
+ BLOCK_N, #
218
+ 4 - STAGE,
219
+ offs_m,
220
+ offs_n,
221
+ N_CTX,
222
+ BANDWIDTH,
223
+ )
224
+ # stage 2: on-band
225
+ if STAGE & 2:
226
+ # running the two stages as separate inner-loop calls makes it easier for
227
+ # the compiler to schedule the two loops independently
228
+ acc, l_i, m_i = _attn_fwd_inner(
229
+ acc,
230
+ l_i,
231
+ m_i,
232
+ q,
233
+ K_block_ptr,
234
+ V_block_ptr, #
235
+ start_m,
236
+ qk_scale, #
237
+ BLOCK_M,
238
+ HEAD_DIM,
239
+ BLOCK_N, #
240
+ 2,
241
+ offs_m,
242
+ offs_n,
243
+ N_CTX,
244
+ BANDWIDTH,
245
+ )
246
+ # epilogue
247
+ m_i += tl.math.log2(l_i)
248
+ acc = acc / l_i[:, None]
249
+ m_ptrs = M + off_hz * N_CTX + offs_m
250
+ tl.store(m_ptrs, m_i)
251
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
252
+
253
+
254
+ @triton.jit
255
+ def _attn_bwd_preprocess(
256
+ O,
257
+ DO, #
258
+ Sinks,
259
+ DSinks,
260
+ DSinkstemp,
261
+ Atomic_counters,
262
+ M,
263
+ Delta, #
264
+ Z,
265
+ H,
266
+ N_CTX, #
267
+ BLOCK_M: tl.constexpr,
268
+ HEAD_DIM: tl.constexpr, #
269
+ ):
270
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
271
+ off_hz = tl.program_id(1)
272
+ off_n = tl.arange(0, HEAD_DIM)
273
+ off_z = off_hz // H
274
+ off_h = off_hz % H
275
+ # load
276
+ o = tl.load(
277
+ O + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
278
+ )
279
+ do = tl.load(
280
+ DO + off_hz * HEAD_DIM * N_CTX + off_m[:, None] * HEAD_DIM + off_n[None, :]
281
+ ).to(tl.float32)
282
+ delta = tl.sum(o * do, axis=1)
283
+ # write-back
284
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
285
+
286
+ if Sinks is not None:
287
+ m = tl.load(M + off_z * H * N_CTX + off_h * N_CTX + off_m)
288
+ sink = tl.load(Sinks + off_h)
289
+ dl = tl.sum(tl.math.exp2(sink - m) * delta, axis=0)
290
+
291
+ depth = Z * (N_CTX // BLOCK_M)
292
+
293
+ tl.store(
294
+ DSinkstemp + (off_h * Z + off_z) * (N_CTX // BLOCK_M) + tl.program_id(0), dl
295
+ )
296
+
297
+ if tl.atomic_add(Atomic_counters + off_h, 1) == depth - 1:
298
+ dl_acc = 0.0
299
+
300
+ for i in range(0, depth, BLOCK_M):
301
+ idxs = i + tl.arange(0, BLOCK_M)
302
+ temps = tl.load(
303
+ DSinkstemp + off_h * depth + idxs, mask=(idxs < depth), other=0.0
304
+ )
305
+ dl_acc += tl.sum(temps, axis=0)
306
+
307
+ tl.store(DSinks + off_h, (-0.69314718) * dl_acc)
308
+
309
+
310
+ # The main inner-loop logic for computing dK and dV.
311
+ @triton.jit
312
+ def _attn_bwd_dkdv(
313
+ dk,
314
+ dv, #
315
+ Q,
316
+ k,
317
+ v,
318
+ sm_scale, #
319
+ DO, #
320
+ M,
321
+ D, #
322
+ # shared by Q/K/V/DO.
323
+ stride_tok,
324
+ stride_d, #
325
+ H,
326
+ N_CTX,
327
+ BLOCK_M1: tl.constexpr, #
328
+ BLOCK_N1: tl.constexpr, #
329
+ HEAD_DIM: tl.constexpr, #
330
+ # Filled in by the wrapper.
331
+ start_n,
332
+ start_m,
333
+ num_steps, #
334
+ MASK: tl.constexpr,
335
+ BANDWIDTH: tl.constexpr,
336
+ ):
337
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
338
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
339
+ offs_k = tl.arange(0, HEAD_DIM)
340
+ qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d
341
+ do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
342
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
343
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
344
+ curr_m = start_m
345
+ step_m = BLOCK_M1
346
+ for blk_idx in range(num_steps):
347
+ qT = tl.load(qT_ptrs)
348
+ # Load m before computing qk to reduce pipeline stall.
349
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
350
+ m = tl.load(M + offs_m)
351
+ qkT = tl.dot(k, qT)
352
+ pT = tl.math.exp2(qkT - m[None, :])
353
+ # Autoregressive masking.
354
+ if MASK:
355
+ if BANDWIDTH == 0: # full causal
356
+ mask = offs_m[None, :] >= offs_n[:, None]
357
+ else: # local causal
358
+ mask = (offs_m[None, :] >= offs_n[:, None]) & (
359
+ offs_m[None, :] < offs_n[:, None] + BANDWIDTH
360
+ )
361
+ pT = tl.where(mask, pT, 0.0)
362
+ do = tl.load(do_ptrs)
363
+ # Compute dV.
364
+ ppT = pT
365
+ ppT = ppT.to(tl.float16)
366
+ dv += tl.dot(ppT, do)
367
+ # D (= delta) is pre-divided by ds_scale.
368
+ Di = tl.load(D + offs_m)
369
+ # Compute dP and dS.
370
+ dpT = tl.dot(v, tl.trans(do)).to(tl.float32)
371
+ dsT = pT * (dpT - Di[None, :])
372
+ dsT = dsT.to(tl.float16)
373
+ dk += tl.dot(dsT, tl.trans(qT))
374
+ # Increment pointers.
375
+ curr_m += step_m
376
+ qT_ptrs += step_m * stride_tok
377
+ do_ptrs += step_m * stride_tok
378
+ return dk, dv
379
+
380
+
381
+ # the main inner-loop logic for computing dQ
382
+ @triton.jit
383
+ def _attn_bwd_dq(
384
+ dq,
385
+ q,
386
+ K,
387
+ V, #
388
+ do,
389
+ m,
390
+ D,
391
+ # shared by Q/K/V/DO.
392
+ stride_tok,
393
+ stride_d, #
394
+ H,
395
+ N_CTX, #
396
+ BLOCK_M2: tl.constexpr, #
397
+ BLOCK_N2: tl.constexpr, #
398
+ HEAD_DIM: tl.constexpr,
399
+ BANDWIDTH: tl.constexpr,
400
+ # Filled in by the wrapper.
401
+ start_m,
402
+ start_n,
403
+ num_steps, #
404
+ MASK: tl.constexpr,
405
+ ):
406
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
407
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
408
+ offs_k = tl.arange(0, HEAD_DIM)
409
+ kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
410
+ vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d
411
+ # D (= delta) is pre-divided by ds_scale.
412
+ Di = tl.load(D + offs_m)
413
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
414
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
415
+ curr_n = start_n
416
+ step_n = BLOCK_N2
417
+ for blk_idx in range(num_steps):
418
+ kT = tl.load(kT_ptrs)
419
+ vT = tl.load(vT_ptrs)
420
+ qk = tl.dot(q, kT)
421
+ p = tl.math.exp2(qk - m)
422
+ # Autoregressive masking.
423
+ if MASK:
424
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
425
+ if BANDWIDTH == 0: # full causal
426
+ mask = offs_m[:, None] >= offs_n[None, :]
427
+ else: # local causal
428
+ mask = (offs_m[:, None] >= offs_n[None, :]) & (
429
+ offs_m[:, None] < offs_n[None, :] + BANDWIDTH
430
+ )
431
+ p = tl.where(mask, p, 0.0)
432
+ # Compute dP and dS.
433
+ dp = tl.dot(do, vT).to(tl.float32)
434
+ ds = p * (dp - Di[:, None])
435
+ ds = ds.to(tl.float16)
436
+ # Compute dQ.
437
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
438
+ dq += tl.dot(ds, tl.trans(kT))
439
+ # Increment pointers.
440
+ curr_n += step_n
441
+ kT_ptrs += step_n * stride_tok
442
+ vT_ptrs += step_n * stride_tok
443
+ return dq
444
+
445
+
446
+ @triton.jit
447
+ def _attn_bwd(
448
+ Q,
449
+ K,
450
+ V,
451
+ sm_scale, #
452
+ DO, #
453
+ DQ,
454
+ DK,
455
+ DV, #
456
+ M,
457
+ D,
458
+ # shared by Q/K/V/DO.
459
+ stride_z,
460
+ stride_h,
461
+ stride_tok,
462
+ stride_d, #
463
+ H,
464
+ N_CTX, #
465
+ BANDWIDTH: tl.constexpr,
466
+ BLOCK_M1: tl.constexpr, #
467
+ BLOCK_N1: tl.constexpr, #
468
+ BLOCK_M2: tl.constexpr, #
469
+ BLOCK_N2: tl.constexpr, #
470
+ BLK_SLICE_FACTOR: tl.constexpr, #
471
+ HEAD_DIM: tl.constexpr,
472
+ ):
473
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
474
+
475
+ bhid = tl.program_id(2)
476
+ off_chz = (bhid * N_CTX).to(tl.int64)
477
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
478
+ pid = tl.program_id(0)
479
+
480
+ # offset pointers for batch/head
481
+ Q += adj
482
+ K += adj
483
+ V += adj
484
+ DO += adj
485
+ DQ += adj
486
+ DK += adj
487
+ DV += adj
488
+ M += off_chz
489
+ D += off_chz
490
+
491
+ # load scales
492
+ offs_k = tl.arange(0, HEAD_DIM)
493
+
494
+ start_n = pid * BLOCK_N1
495
+ start_m = start_n
496
+
497
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
498
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
499
+
500
+ dv = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
501
+ dk = tl.zeros([BLOCK_N1, HEAD_DIM], dtype=tl.float32)
502
+
503
+ # load K and V: they stay in SRAM throughout the inner loop.
504
+ k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
505
+ v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
506
+
507
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
508
+
509
+ dk, dv = _attn_bwd_dkdv(
510
+ dk,
511
+ dv, #
512
+ Q,
513
+ k,
514
+ v,
515
+ sm_scale, #
516
+ DO, #
517
+ M,
518
+ D, #
519
+ stride_tok,
520
+ stride_d, #
521
+ H,
522
+ N_CTX, #
523
+ MASK_BLOCK_M1,
524
+ BLOCK_N1,
525
+ HEAD_DIM, #
526
+ start_n,
527
+ start_m,
528
+ num_steps, #
529
+ MASK=True, #
530
+ BANDWIDTH=BANDWIDTH,
531
+ )
532
+
533
+ start_m += num_steps * MASK_BLOCK_M1
534
+ # how many *additional* rows may still attend to the current key block?
535
+ if BANDWIDTH == 0:
536
+ rows_left = N_CTX - start_m
537
+ else:
538
+ rows_left = min(N_CTX - start_m, BLOCK_N1)
539
+ num_steps = rows_left // BLOCK_M1
540
+
541
+ # Compute dK and dV for non-masked blocks.
542
+ dk, dv = _attn_bwd_dkdv( #
543
+ dk,
544
+ dv, #
545
+ Q,
546
+ k,
547
+ v,
548
+ sm_scale, #
549
+ DO, #
550
+ M,
551
+ D, #
552
+ stride_tok,
553
+ stride_d, #
554
+ H,
555
+ N_CTX, #
556
+ BLOCK_M1,
557
+ BLOCK_N1,
558
+ HEAD_DIM, #
559
+ start_n,
560
+ start_m,
561
+ num_steps, #
562
+ MASK=BANDWIDTH != 0, #
563
+ BANDWIDTH=BANDWIDTH,
564
+ )
565
+
566
+ dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
567
+ tl.store(dv_ptrs, dv)
568
+
569
+ # Write back dK.
570
+ dk *= sm_scale
571
+ dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d
572
+ tl.store(dk_ptrs, dk)
573
+
574
+ # THIS BLOCK DOES DQ:
575
+ start_m = pid * BLOCK_M2
576
+ end_n = start_m + BLOCK_M2
577
+
578
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
579
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
580
+
581
+ q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
582
+ dq = tl.zeros([BLOCK_M2, HEAD_DIM], dtype=tl.float32)
583
+ do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d)
584
+
585
+ m = tl.load(M + offs_m)
586
+ m = m[:, None]
587
+
588
+ # Compute dQ for masked (diagonal) blocks.
589
+ # NOTE: This code scans each row of QK^T backward (from right to left,
590
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
591
+ # not due to anything important. I just wanted to reuse the loop
592
+ # structure for dK & dV above as much as possible.
593
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
594
+ dq = _attn_bwd_dq(
595
+ dq,
596
+ q,
597
+ K,
598
+ V, #
599
+ do,
600
+ m,
601
+ D, #
602
+ stride_tok,
603
+ stride_d, #
604
+ H,
605
+ N_CTX, #
606
+ BLOCK_M2,
607
+ MASK_BLOCK_N2,
608
+ HEAD_DIM, #
609
+ BANDWIDTH,
610
+ start_m,
611
+ end_n - num_steps * MASK_BLOCK_N2,
612
+ num_steps, #
613
+ MASK=True, #
614
+ )
615
+ end_n -= num_steps * MASK_BLOCK_N2
616
+
617
+ # stage-1 (rows that still fall inside the window)
618
+ if BANDWIDTH == 0:
619
+ cols_left = end_n
620
+ else:
621
+ cols_left = min(end_n, BLOCK_M2)
622
+ num_steps = cols_left // BLOCK_N2
623
+ dq = _attn_bwd_dq(
624
+ dq,
625
+ q,
626
+ K,
627
+ V, #
628
+ do,
629
+ m,
630
+ D, #
631
+ stride_tok,
632
+ stride_d, #
633
+ H,
634
+ N_CTX, #
635
+ BLOCK_M2,
636
+ BLOCK_N2,
637
+ HEAD_DIM, #
638
+ BANDWIDTH,
639
+ start_m,
640
+ end_n - num_steps * BLOCK_N2,
641
+ num_steps, #
642
+ MASK=BANDWIDTH != 0, #
643
+ )
644
+ # Write back dQ.
645
+ dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d
646
+ dq *= LN2
647
+ tl.store(dq_ptrs, dq)
648
+
649
+
650
+ class _attention(torch.autograd.Function):
651
+ @staticmethod
652
+ def forward(
653
+ ctx,
654
+ q,
655
+ k,
656
+ v,
657
+ sinks,
658
+ causal,
659
+ sm_scale,
660
+ bandwidth,
661
+ warp_specialize=True,
662
+ USE_TMA=True,
663
+ ):
664
+ # shape constraints
665
+ HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
666
+ # when v is in float8_e5m2 it is transposed.
667
+ HEAD_DIM_V = v.shape[-1]
668
+ assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V
669
+ assert HEAD_DIM_K in {16, 32, 64, 128, 256}
670
+ o = torch.empty_like(q)
671
+ stage = 3 if causal else 1
672
+ extra_kern_args = {}
673
+ M = torch.empty(
674
+ (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
675
+ )
676
+ BLOCK_M = 128
677
+ grid = (
678
+ triton.cdiv(q.shape[2], BLOCK_M),
679
+ q.shape[0] * q.shape[1],
680
+ 1,
681
+ )
682
+ _attn_fwd[grid](
683
+ q,
684
+ k,
685
+ v,
686
+ sinks,
687
+ sm_scale,
688
+ M,
689
+ o, #
690
+ q.stride(0),
691
+ q.stride(1),
692
+ q.stride(2),
693
+ q.stride(3), #
694
+ k.stride(0),
695
+ k.stride(1),
696
+ k.stride(2),
697
+ k.stride(3), #
698
+ v.stride(0),
699
+ v.stride(1),
700
+ v.stride(2),
701
+ v.stride(3), #
702
+ o.stride(0),
703
+ o.stride(1),
704
+ o.stride(2),
705
+ o.stride(3), #
706
+ q.shape[0],
707
+ q.shape[1], #
708
+ N_CTX=q.shape[2], #
709
+ HEAD_DIM=HEAD_DIM_K, #
710
+ STAGE=stage, #
711
+ BANDWIDTH=bandwidth,
712
+ BLOCK_M=BLOCK_M,
713
+ BLOCK_N=64,
714
+ **extra_kern_args,
715
+ )
716
+
717
+ ctx.save_for_backward(q, k, v, sinks, o, M)
718
+ ctx.sm_scale = sm_scale
719
+ ctx.HEAD_DIM = HEAD_DIM_K
720
+ ctx.causal = causal
721
+ ctx.bandwidth = bandwidth
722
+ return o
723
+
724
+ @staticmethod
725
+ def backward(ctx, do):
726
+ q, k, v, sinks, o, M = ctx.saved_tensors
727
+ do = do.contiguous()
728
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
729
+ dq = torch.empty_like(q)
730
+ dk = torch.empty_like(k)
731
+ dv = torch.empty_like(v)
732
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
733
+ PRE_BLOCK = 128
734
+ NUM_WARPS, NUM_STAGES = 4, 5
735
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32
736
+ BLK_SLICE_FACTOR = 2
737
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
738
+ arg_k = k
739
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
740
+ PRE_BLOCK = 128
741
+ assert N_CTX % PRE_BLOCK == 0
742
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
743
+ delta = torch.empty_like(M)
744
+ if sinks is not None:
745
+ dsinks = torch.empty_like(sinks)
746
+ dsinkstemp = torch.empty(pre_grid, dtype=torch.float32, device=sinks.device)
747
+ atomic_counters = torch.zeros(
748
+ N_HEAD, dtype=torch.int32, device=sinks.device
749
+ )
750
+ else:
751
+ dsinks, dsinkstemp, atomic_counters = None, None, None
752
+ _attn_bwd_preprocess[pre_grid](
753
+ o,
754
+ do, #
755
+ # Info for attention sinks.
756
+ sinks,
757
+ dsinks,
758
+ dsinkstemp,
759
+ atomic_counters,
760
+ M,
761
+ ######
762
+ delta, #
763
+ BATCH,
764
+ N_HEAD,
765
+ N_CTX, #
766
+ BLOCK_M=PRE_BLOCK,
767
+ HEAD_DIM=ctx.HEAD_DIM, #
768
+ )
769
+ grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD)
770
+ _attn_bwd[grid](
771
+ q,
772
+ arg_k,
773
+ v,
774
+ ctx.sm_scale,
775
+ do,
776
+ dq,
777
+ dk,
778
+ dv, #
779
+ M,
780
+ delta, #
781
+ q.stride(0),
782
+ q.stride(1),
783
+ q.stride(2),
784
+ q.stride(3), #
785
+ N_HEAD,
786
+ N_CTX, #
787
+ BANDWIDTH=ctx.bandwidth,
788
+ BLOCK_M1=BLOCK_M1,
789
+ BLOCK_N1=BLOCK_N1, #
790
+ BLOCK_M2=BLOCK_M2,
791
+ BLOCK_N2=BLOCK_N2, #
792
+ BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
793
+ HEAD_DIM=ctx.HEAD_DIM, #
794
+ num_warps=NUM_WARPS, #
795
+ num_stages=NUM_STAGES, #
796
+ )
797
+
798
+ return dq, dk, dv, dsinks, None, None, None, None, None
799
+
800
+
801
+ attention = _attention.apply
802
+