/* written by Xin Lai. Email: [email protected] */

#include "../cuda_utils.h"
#include "attention_cuda_kernel.h"
// grid: (DIVUP(M, THREADS_PER_BLOCK), h, C/h); one thread per (pair m, head, channel)
__global__ void attention_step1_forward_cuda_kernel(
    int N, int M, int h, int C, const float *q, const float *k,
    const int *index0, const int *index1, float *attn) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];  // query index of pair m
    int idx1 = index1[m_idx];  // key index of pair m
    float val = q[idx0 * C + h_idx * C / h + c_idx] * k[idx1 * C + h_idx * C / h + c_idx];
    atomicAdd(attn + m_idx * h + h_idx, val);
}
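
/* The kernel above accumulates the per-channel products with atomicAdd, i.e. it
   computes the unnormalized attention score of each sparse (query, key) pair:
       attn[m, hh] = sum_c q[index0[m], hh, c] * k[index1[m], hh, c]
   Because the reduction over c is a read-modify-write accumulation, attn must be
   zero-initialized before launch. */
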
// grid: (DIVUP(M, THREADS_PER_BLOCK), h, C/h)
__global__ void attention_step1_backward_cuda_kernel(
    int N, int M, int h, int C, const float *grad_out,
    const int *index0, const int *index1, const float *q, const float *k,
    float *grad_q, float *grad_k) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    int grad_out_idx = m_idx * h + h_idx;
    int q_idx = idx0 * C + h_idx * C / h + c_idx;
    int k_idx = idx1 * C + h_idx * C / h + c_idx;
    atomicAdd(grad_q + q_idx, grad_out[grad_out_idx] * k[k_idx]);
    atomicAdd(grad_k + k_idx, grad_out[grad_out_idx] * q[q_idx]);
}
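
/* Backward of step1 by the chain rule: with g = grad_out over attn (M, h),
       grad_q[index0[m], hh, c] += g[m, hh] * k[index1[m], hh, c]
       grad_k[index1[m], hh, c] += g[m, hh] * q[index0[m], hh, c]
   atomicAdd is required because several pairs m may share the same index0 or
   index1 entry, so grad_q and grad_k must be zero-initialized as well. */
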
void attention_step1_forward_cuda_launcher(int N, int M, int h, int C, const float *q, const float *k,
                                           const int *index0, const int *index1, float *attn) {
    // input: q: (N, h, C/h), k: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: attn: (M, h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C / h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step1_forward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, q, k, index0, index1, attn);
}
void attention_step1_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
                                            const float *q, const float *k, float *grad_q, float *grad_k) {
    // input: grad_out: (M, h)
    // output: grad_q: (N, h, C/h), grad_k: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C / h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step1_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, grad_out, index0, index1, q, k, grad_q, grad_k);
}
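
/* Grid layout note: the pair index m rides on blockIdx.x * blockDim.x + threadIdx.x,
   the only grid dimension that can grow far beyond the 65535-block limit of the y
   and z dimensions, while the head and channel indices map to blockIdx.y and
   blockIdx.z. DIVUP and THREADS_PER_BLOCK are presumably defined in
   ../cuda_utils.h. */
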
// grid: (DIVUP(M, THREADS_PER_BLOCK), h, C/h)
__global__ void attention_step2_forward_cuda_kernel(
    int N, int M, int h, int C, const float *attn, const float *v,
    const int *index0, const int *index1, float *output) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];  // query index: where the result is scattered
    int idx1 = index1[m_idx];  // value index: which value row is gathered
    float val = attn[m_idx * h + h_idx] * v[idx1 * C + h_idx * C / h + c_idx];
    atomicAdd(output + idx0 * C + h_idx * C / h + c_idx, val);
}
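
/* Step2 forward is the weighted aggregation of values: for each sparse pair m,
       output[index0[m], hh, c] += attn[m, hh] * v[index1[m], hh, c]
   Several pairs typically share the same query index0[m], hence the atomicAdd;
   output must be zero-initialized before launch. */
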
// grid: (DIVUP(M, THREADS_PER_BLOCK), h, C/h)
__global__ void attention_step2_backward_cuda_kernel(
    int N, int M, int h, int C, const float *grad_out,
    const int *index0, const int *index1, const float *attn, const float *v,
    float *grad_attn, float *grad_v) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    int grad_out_idx = idx0 * C + h_idx * C / h + c_idx;
    atomicAdd(grad_attn + m_idx * h + h_idx, grad_out[grad_out_idx] * v[idx1 * C + h_idx * C / h + c_idx]);
    atomicAdd(grad_v + idx1 * C + h_idx * C / h + c_idx, grad_out[grad_out_idx] * attn[m_idx * h + h_idx]);
}
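
/* Backward of step2: with g = grad_out over output (N, h, C/h),
       grad_attn[m, hh] += sum_c g[index0[m], hh, c] * v[index1[m], hh, c]
       grad_v[index1[m], hh, c] += g[index0[m], hh, c] * attn[m, hh]
   where the sum over c is realized by the per-channel atomicAdd into grad_attn. */
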
void attention_step2_forward_cuda_launcher(int N, int M, int h, int C, const float *attn, const float *v,
                                           const int *index0, const int *index1, float *output) {
    // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: output: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C / h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step2_forward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, attn, v, index0, index1, output);
}
void attention_step2_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
                                            const float *attn, const float *v, float *grad_attn, float *grad_v) {
    // input: grad_out: (N, h, C/h)
    // output: grad_attn: (M, h), grad_v: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C / h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step2_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v);
}
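
/* Usage sketch: a minimal, hypothetical host-side driver for the two steps,
   assuming the launchers are declared in attention_cuda_kernel.h and that the
   caller has already filled the device buffers d_q, d_k, d_v (each N*C floats,
   viewed as (N, h, C/h)) and d_index0, d_index1 (M ints each). The accumulation
   buffers are zero-initialized here because every kernel above uses atomicAdd.
   A per-query softmax over d_attn would normally be applied between the two
   steps, outside of these kernels. */
static void sparse_attention_forward_example(
    int N, int M, int h, int C,
    const float *d_q, const float *d_k, const float *d_v,
    const int *d_index0, const int *d_index1,
    float *d_attn /* (M, h) */, float *d_output /* (N, h, C/h) */) {
    // step1: raw attention scores for every sparse (query, key) pair
    cudaMemset(d_attn, 0, sizeof(float) * (size_t)M * h);
    attention_step1_forward_cuda_launcher(N, M, h, C, d_q, d_k, d_index0, d_index1, d_attn);
    // ... per-query softmax over d_attn here (handled elsewhere) ...
    // step2: scatter the attention-weighted values back onto the queries
    cudaMemset(d_output, 0, sizeof(float) * (size_t)N * C);
    attention_step2_forward_cuda_launcher(N, M, h, C, d_attn, d_v, d_index0, d_index1, d_output);
}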