| | """ |
| | monoid_scan_cuda.py — Triton CUDA JIT Accelerated Parallel Prefix Scan |
| | monoid_scan_cuda.py — Triton CUDA JIT 加速的并行前缀扫描 |
| | |
| | This module implements the parallel prefix scan for the vector-decay monoid recurrence: |
| | y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:] |
| | 本模块实现向量衰减幺半群递推的并行前缀扫描: |
| | y_t[i,:] = decay_t[i] · y_{t-1}[i,:] + x_t[i,:] |
| | |
| | This is the computational backbone of Monoid Attention's state compression. |
| | 这是幺半群注意力状态压缩的计算骨干。 |
| | |
| | Vector decay: each dimension of the D_k×D_v state matrix has its own |
| | per-dimension decay rate α_t ∈ (0,1)^{D_k} (sigmoid output), enabling |
| | different feature dimensions to have independent memory lifetimes |
| | (fast-decaying for local syntax, slow-decaying for global entity memory). |
| | 向量衰减: D_k×D_v 状态矩阵的每个维度拥有独立的衰减率 α_t ∈ (0,1)^{D_k} (sigmoid 输出), |
| | 使不同特征维度拥有独立的记忆生命周期 (快速衰减用于局部语法, 慢速衰减用于全局实体记忆)。 |
| | |
| | Implementation: |
| | Forward: sequential scan along T, parallelized across B*H*D_k on GPU. |
| | Each program handles one row of the state matrix (D_v elements) |
| | with a scalar decay per row. |
| | Backward: reverse-order adjoint scan for gradient computation. |
| | Per-row reduction for decay gradient (atomic_add only needed across D_v blocks). |
| | Auto-dispatches: CUDA → Triton kernel, CPU/MPS → PyTorch fallback. |
| | |
| | 前向: 沿 T 维顺序扫描, 跨 B*H*D_k 在 GPU 上并行。 |
| | 每个 program 处理状态矩阵的一行 (D_v 个元素), 每行一个标量衰减。 |
| | 反向: 逆序伴随变量扫描计算梯度。 |
| | 逐行归约计算 decay 梯度 (仅在跨多个 D_v block 时需要 atomic_add)。 |
| | 自动分派: CUDA → Triton 核函数, CPU/MPS → PyTorch 回退。 |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import torch |
| | from torch import Tensor |
| | from torch.autograd import Function |
| | from typing import Tuple |
| |
|
| | try: |
| | import triton |
| | import triton.language as tl |
| | HAS_TRITON = True |
| | except ImportError: |
| | HAS_TRITON = False |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | def _sequential_scan(decays: Tensor, values: Tensor) -> Tensor: |
| | """ |
| | Pure PyTorch sequential scan fallback (when no CUDA / Triton available). |
| | 纯 PyTorch 串行扫描回退 (无 CUDA / Triton 时使用)。 |
| | |
| | Implements the vector-decay monoid recurrence step by step: |
| | acc_0 = 0 |
| | acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:] |
| | This is O(T) sequential — correct but slow on GPU. |
| | 逐步实现向量衰减幺半群递推: |
| | acc_0 = 0 |
| | acc_t[i,:] = decay_t[i] · acc_{t-1}[i,:] + values_t[i,:] |
| | 这是 O(T) 串行的 — 结果正确但在 GPU 上较慢。 |
| | |
| | Args: |
| | decays: [B, H, T, D_k] — per-dimension per-step decay gates α_t ∈ (0,1) |
| | 每维度每步衰减门 α_t ∈ (0,1) |
| | values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t to accumulate |
| | 待累积的外积 k_t⊗v_t |
| | Returns: |
| | output: [B, H, T, D_k, D_v] — all prefix states S_1, ..., S_T |
| | 所有前缀状态 S_1, ..., S_T |
| | """ |
| | B, H, T, D_k, D_v = values.shape |
| | out = torch.empty_like(values) |
| | |
| | |
| | acc = torch.zeros(B, H, D_k, D_v, device=values.device, dtype=values.dtype) |
| | for t in range(T): |
| | |
| | |
| | decay_t = decays[:, :, t].unsqueeze(-1) |
| | acc = acc * decay_t + values[:, :, t] |
| | out[:, :, t] = acc |
| | return out |
| |
|
| |
|
| | |
| | |
| | |
| | |
| |
|
| | if HAS_TRITON: |
| |
|
    @triton.jit
    def _scan_fwd_kernel(
        LD_ptr, V_ptr, O_ptr,
        T, D_v,
        s_ld_bhdk, s_ld_t,
        s_v_bhdk, s_v_t, s_v_dv,
        s_o_bhdk, s_o_t, s_o_dv,
        BLOCK_DV: tl.constexpr,
    ):
        """
        Forward scan kernel — computes all prefix states S_1..S_T (vector decay).

        Parallelization strategy:
          - program_id(0) = bhdk: one program per (batch, head, d_k row) triple
          - program_id(1) = dvb: one program per BLOCK_DV-wide slice of D_v
            (typically a single slice)
          - sequential loop over T — the causal recurrence is inherently serial

        Each program owns one row of the D_k×D_v state matrix; the decay for
        that row is a single scalar per timestep, so no row-index arithmetic
        is needed inside the inner loop.

        Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — see _ParallelScanFn.forward.
        """
        bhdk = tl.program_id(0)
        dvb = tl.program_id(1)
        dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
        dv_mask = dv_offs < D_v

        # Running state kept in fp32 regardless of the input dtype,
        # for numerical stability over long T.
        acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

        # Hoist the per-row base pointers out of the time loop.
        ld_base = LD_ptr + bhdk * s_ld_bhdk
        v_base = V_ptr + bhdk * s_v_bhdk
        o_base = O_ptr + bhdk * s_o_bhdk

        for t in range(T):
            # Scalar decay α_t for this row at step t.
            decay = tl.load(ld_base + t * s_ld_t).to(tl.float32)

            # One BLOCK_DV-wide slice of the value (k_t ⊗ v_t) row;
            # masked lanes read 0 and do not perturb the recurrence.
            val = tl.load(
                v_base + t * s_v_t + dv_offs * s_v_dv,
                mask=dv_mask, other=0.0,
            ).to(tl.float32)

            # The monoid recurrence: S_t = α_t · S_{t-1} + kv_t.
            acc = acc * decay + val

            # Emit the prefix state for step t (cast to O's dtype on store).
            tl.store(
                o_base + t * s_o_t + dv_offs * s_o_dv,
                acc, mask=dv_mask,
            )
| |
|
    @triton.jit
    def _scan_bwd_kernel(
        LD_ptr, O_ptr, GO_ptr, GV_ptr, GLD_ptr,
        T, D_v,
        s_ld_bhdk, s_ld_t,
        s_o_bhdk, s_o_t, s_o_dv,
        s_go_bhdk, s_go_t, s_go_dv,
        s_gv_bhdk, s_gv_t, s_gv_dv,
        s_gld_bhdk, s_gld_t,
        BLOCK_DV: tl.constexpr,
    ):
        """
        Backward scan kernel — gradients via the adjoint method (vector decay).

        Reverse-time adjoint recurrence, with λ_t the adjoint of state y_t:
            λ_t        = ∂L/∂y_t + α_{t+1} · λ_{t+1}
            ∂L/∂x_t    = λ_t
            ∂L/∂α_t[i] = Σ_j λ_t[i,j] · y_{t-1}[i,j]   (zero at t = 0)

        Each program handles one row of the state matrix (one d_k dimension).
        The decay for that row is a scalar, so the Σ_j reduction is a local
        tl.sum within the program. atomic_add on GLD is only contended when
        D_v spans more than one BLOCK_DV block (grid axis 1 > 1); with a
        single block it is a plain add into the zero-initialized buffer.

        Grid: (B*H*D_k, ceil(D_v/BLOCK_DV)) — must match the forward launch.
        """
        bhdk = tl.program_id(0)
        dvb = tl.program_id(1)
        dv_offs = dvb * BLOCK_DV + tl.arange(0, BLOCK_DV)
        dv_mask = dv_offs < D_v

        # Adjoint carried backward in fp32; zero beyond t = T-1.
        adj = tl.zeros([BLOCK_DV], dtype=tl.float32)

        for t_rev in range(T):
            t = T - 1 - t_rev

            # Upstream gradient ∂L/∂y_t for this row / D_v slice.
            go = tl.load(
                GO_ptr + bhdk * s_go_bhdk + t * s_go_t + dv_offs * s_go_dv,
                mask=dv_mask, other=0.0,
            ).to(tl.float32)

            # λ_t = ∂L/∂y_t + α_{t+1}·λ_{t+1}  (adj holds the second term).
            lam = go + adj

            # ∂L/∂x_t = λ_t — store the value gradient for step t.
            tl.store(
                GV_ptr + bhdk * s_gv_bhdk + t * s_gv_t + dv_offs * s_gv_dv,
                lam, mask=dv_mask,
            )

            # Scalar decay α_t for this row, reused below for the
            # decay gradient and the adjoint propagation.
            a_t = tl.load(LD_ptr + bhdk * s_ld_bhdk + t * s_ld_t).to(tl.float32)

            # Decay gradient needs y_{t-1}; at t = 0 the initial state is
            # zero, so the contribution vanishes and the add is skipped
            # (GLD was zero-initialized on the host).
            if t > 0:
                y_prev = tl.load(
                    O_ptr + bhdk * s_o_bhdk + (t - 1) * s_o_t + dv_offs * s_o_dv,
                    mask=dv_mask, other=0.0,
                ).to(tl.float32)
                grad_d = tl.sum(lam * y_prev)
                tl.atomic_add(GLD_ptr + bhdk * s_gld_bhdk + t * s_gld_t, grad_d)

            # Propagate the adjoint one step earlier: α_t · λ_t.
            adj = a_t * lam
| |
|
| | |
| | |
| | |
| | |
| |
|
    class _ParallelScanFn(Function):
        """
        Custom autograd Function for the parallel prefix scan (vector decay).

        forward: launches _scan_fwd_kernel over a (B*H*D_k, ceil(D_v/BLOCK_DV))
            grid — one program per state-matrix row per D_v block — to compute
            all prefix states.
        backward: launches _scan_bwd_kernel, which recovers gradients with a
            reverse-order adjoint scan; the per-row D_v reduction limits
            atomic_add traffic to at most one add per (row, t) per D_v block.
        """
        @staticmethod
        def forward(ctx, decays: Tensor, values: Tensor) -> Tensor:
            # decays: [B, H, T, D_k]; values: [B, H, T, D_k, D_v]
            B, H, T, D_k, D_v = values.shape

            # Flatten to row-major kernel layout: one leading index per
            # (batch, head, d_k row) so each program reads a contiguous
            # [T] (decays) / [T, D_v] (values) slab.
            ld_flat = decays.permute(0, 1, 3, 2).contiguous().reshape(B * H * D_k, T)
            v_flat = values.permute(0, 1, 3, 2, 4).contiguous().reshape(B * H * D_k, T, D_v)
            o_flat = torch.empty_like(v_flat)

            BHDK = B * H * D_k
            # A single block usually covers all of D_v (capped at 1024 lanes).
            BLOCK_DV = min(triton.next_power_of_2(D_v), 1024)
            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_fwd_kernel[grid](
                ld_flat, v_flat, o_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                v_flat.stride(0), v_flat.stride(1), v_flat.stride(2),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                BLOCK_DV=BLOCK_DV,
            )

            # Backward needs the decays and the forward outputs
            # (y_{t-1} appears in the decay gradient).
            ctx.save_for_backward(ld_flat, o_flat)
            ctx.shape_info = (B, H, T, D_k, D_v, BHDK, BLOCK_DV)
            # Undo the flattening: back to [B, H, T, D_k, D_v].
            return o_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()

        @staticmethod
        def backward(ctx, grad_output: Tensor):
            ld_flat, o_flat = ctx.saved_tensors
            B, H, T, D_k, D_v, BHDK, BLOCK_DV = ctx.shape_info

            # Same row-major flattening as in forward.
            go_flat = grad_output.permute(0, 1, 3, 2, 4).contiguous().reshape(BHDK, T, D_v)
            gv_flat = torch.empty_like(go_flat)
            # fp32 zero-init: the kernel accumulates into it with atomic_add,
            # and the t = 0 entries are never written (they stay zero).
            gld_flat = torch.zeros(BHDK, T, device=ld_flat.device, dtype=torch.float32)

            # Must match the forward grid so rows map to the same programs.
            grid = (BHDK, triton.cdiv(D_v, BLOCK_DV))

            _scan_bwd_kernel[grid](
                ld_flat, o_flat, go_flat, gv_flat, gld_flat,
                T, D_v,
                ld_flat.stride(0), ld_flat.stride(1),
                o_flat.stride(0), o_flat.stride(1), o_flat.stride(2),
                go_flat.stride(0), go_flat.stride(1), go_flat.stride(2),
                gv_flat.stride(0), gv_flat.stride(1), gv_flat.stride(2),
                gld_flat.stride(0), gld_flat.stride(1),
                BLOCK_DV=BLOCK_DV,
            )

            # Un-flatten gradients to the caller-facing layouts:
            #   grad_decays: [B, H, T, D_k]  (cast back from fp32)
            #   grad_values: [B, H, T, D_k, D_v]
            grad_decays = gld_flat.to(grad_output.dtype).reshape(B, H, D_k, T).permute(0, 1, 3, 2).contiguous()
            grad_values = gv_flat.reshape(B, H, D_k, T, D_v).permute(0, 1, 3, 2, 4).contiguous()
            return grad_decays, grad_values
| |
|
    def _triton_parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
        """Triton-accelerated parallel scan entry point (vector decay).

        Thin wrapper so callers dispatch through the custom autograd Function.
        """
        return _ParallelScanFn.apply(decays, values)
| |
|
| | else: |
| | _triton_parallel_scan = None |
| |
|
| |
|
| | |
| | |
| | |
| |
|
def parallel_scan(decays: Tensor, values: Tensor) -> Tensor:
    """
    Parallel prefix scan — computes all prefix monoid sums (vector decay).

    The training-time workhorse of Monoid Attention: computes
        S_t[i,:] = α_t[i]·S_{t-1}[i,:] + kv_t[i,:]
    for ALL timesteps S_1, S_2, ..., S_T at once.

    Auto-dispatches on device:
        CUDA     → Triton JIT kernel (fast, with custom backward)
        CPU/MPS  → PyTorch sequential scan (correct, slower)

    Args:
        decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1)
                (sigmoid output)
        values: [B, H, T, D_k, D_v] — outer products k_t⊗v_t
    Returns:
        [B, H, T, D_k, D_v] — all prefix states S_1..S_T
    """
    use_triton = values.is_cuda and _triton_parallel_scan is not None
    scan_impl = _triton_parallel_scan if use_triton else _sequential_scan
    return scan_impl(decays, values)
| |
|
| |
|
def parallel_scan_with_state(
    decays: Tensor, values: Tensor,
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """
    Parallel prefix scan + final state, for the prefill → decode handoff.

    Used during prefill: runs the full training-time scan AND extracts the
    last accumulated state S_T plus the total decay product, so subsequent
    tokens can be generated in O(1) RNN mode via monoid_op. This is the
    bridge between training mode (parallel scan) and inference mode
    (sequential monoid_op).

    Args:
        decays: [B, H, T, D_k] — per-dimension decay gates α_t ∈ (0,1)
        values: [B, H, T, D_k, D_v]
    Returns:
        output: [B, H, T, D_k, D_v] — all prefix states S_1..S_T
        (decay_acc, final_state):
            decay_acc:   [B, H, D_k]      — accumulated decay product Π_t α_t
                                            (for future monoid_op steps)
            final_state: [B, H, D_k, D_v] — S_T, the compressed causal summary
    """
    states = parallel_scan(decays, values)
    # Product of all decays computed in log space; the 1e-8 guards log(0).
    log_total = torch.sum(torch.log(decays + 1e-8), dim=2)
    decay_acc = torch.exp(log_total)
    final_state = states[:, :, -1]
    return states, (decay_acc, final_state)
| |
|