/* written by Xin Lai. Email: [email protected] */
#include "../cuda_utils.h"
#include "attention_cuda_kernel.h"
__global__ void attention_step1_forward_cuda_kernel( // M, h, C//h
    int N, int M, int h, int C, const float *q, const float *k,
    const int *index0, const int *index1, float *attn) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    float val = q[idx0*C + h_idx*C/h + c_idx] * k[idx1*C + h_idx*C/h + c_idx];
    atomicAdd(attn + m_idx*h + h_idx, val);
}
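
// attention_step1_backward_cuda_kernel scatters the gradient of attn back into q and k:
//     grad_q[index0[m], h_idx, c] += grad_out[m, h_idx] * k[index1[m], h_idx, c]
//     grad_k[index1[m], h_idx, c] += grad_out[m, h_idx] * q[index0[m], h_idx, c]
// Atomic adds are needed because several pairs m can reference the same q or k row.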
__global__ void attention_step1_backward_cuda_kernel( // M, h, C//h
    int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
    const float *q, const float *k, float *grad_q, float *grad_k) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    int grad_out_idx = m_idx*h + h_idx;
    int q_idx = idx0*C + h_idx*C/h + c_idx;
    int k_idx = idx1*C + h_idx*C/h + c_idx;
    atomicAdd(grad_q + q_idx, grad_out[grad_out_idx] * k[k_idx]);
    atomicAdd(grad_k + k_idx, grad_out[grad_out_idx] * q[q_idx]);
}

void attention_step1_forward_cuda_launcher(int N, int M, int h, int C, const float *q, const float *k,
                                           const int *index0, const int *index1, float *attn) {
    // input: q: (N, h, C/h), k: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: attn: (M, h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step1_forward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, q, k, index0, index1, attn);
}

void attention_step1_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
                                            const float *q, const float *k, float *grad_q, float *grad_k) {
    // input: grad_out: (M, h), q: (N, h, C/h), k: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: grad_q: (N, h, C/h), grad_k: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step1_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, grad_out, index0, index1, q, k, grad_q, grad_k);
}
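
// attention_step2_forward_cuda_kernel aggregates value vectors with the attention
// weights (typically softmax-normalized upstream):
//     output[index0[m], h_idx, c] += attn[m, h_idx] * v[index1[m], h_idx, c]
// One thread handles one (m, h_idx, c) triple; output must be zero-initialized.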
__global__ void attention_step2_forward_cuda_kernel( // M, h, C//h
    int N, int M, int h, int C, const float *attn, const float *v,
    const int *index0, const int *index1, float *output) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    float val = attn[m_idx*h + h_idx] * v[idx1*C + h_idx*C/h + c_idx];
    atomicAdd(output + idx0*C + h_idx*C/h + c_idx, val);
}
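
// attention_step2_backward_cuda_kernel backpropagates through the weighted aggregation:
//     grad_attn[m, h_idx]         += sum_c grad_out[index0[m], h_idx, c] * v[index1[m], h_idx, c]
//     grad_v[index1[m], h_idx, c] += grad_out[index0[m], h_idx, c] * attn[m, h_idx]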
__global__ void attention_step2_backward_cuda_kernel( // M, h, C//h
    int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
    const float *attn, const float *v, float *grad_attn, float *grad_v) {
    int c_idx = blockIdx.z;
    int h_idx = blockIdx.y;
    int m_idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (m_idx >= M || h_idx >= h || c_idx >= C / h) return;

    int idx0 = index0[m_idx];
    int idx1 = index1[m_idx];
    int grad_out_idx = idx0*C + h_idx*C/h + c_idx;
    atomicAdd(grad_attn + m_idx*h + h_idx, grad_out[grad_out_idx] * v[idx1*C + h_idx*C/h + c_idx]);
    atomicAdd(grad_v + idx1*C + h_idx*C/h + c_idx, grad_out[grad_out_idx] * attn[m_idx*h + h_idx]);
}

void attention_step2_forward_cuda_launcher(int N, int M, int h, int C, const float *attn, const float *v,
                                           const int *index0, const int *index1, float *output) {
    // input: attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: output: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step2_forward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, attn, v, index0, index1, output);
}

void attention_step2_backward_cuda_launcher(int N, int M, int h, int C, const float *grad_out, const int *index0, const int *index1,
                                            const float *attn, const float *v, float *grad_attn, float *grad_v) {
    // input: grad_out: (N, h, C/h), attn: (M, h), v: (N, h, C/h), index0: (M, ), index1: (M, )
    // output: grad_attn: (M, h), grad_v: (N, h, C/h)
    dim3 blocks(DIVUP(M, THREADS_PER_BLOCK), h, C/h);
    dim3 threads(THREADS_PER_BLOCK);
    attention_step2_backward_cuda_kernel<<<blocks, threads, 0>>>(N, M, h, C, grad_out, index0, index1, attn, v, grad_attn, grad_v);
}
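
// ---------------------------------------------------------------------------
// Illustrative host-side usage sketch (added here as an example; not part of the
// original file, and the helper below is hypothetical). It shows how the step1
// launcher could be driven from plain CUDA host code with tiny hand-made buffers;
// in the project itself these launchers are reached through its C++/PyTorch
// binding layer instead. Guarded by a macro so it never affects the library build.
// ---------------------------------------------------------------------------
#ifdef ATTENTION_CUDA_KERNEL_EXAMPLE
#include <cstdio>
#include <vector>

int attention_step1_example() {
    const int N = 2, M = 2, h = 1, C = 2;            // 2 points, 2 index pairs, 1 head, 2 channels per head
    std::vector<float> q_h = {1.f, 2.f, 3.f, 4.f};   // q: (N, h, C/h) flattened
    std::vector<float> k_h = {5.f, 6.f, 7.f, 8.f};   // k: (N, h, C/h) flattened
    std::vector<int> i0_h = {0, 1}, i1_h = {1, 0};   // query/key row indices per pair

    float *q, *k, *attn; int *i0, *i1;
    cudaMalloc(&q, q_h.size() * sizeof(float));
    cudaMalloc(&k, k_h.size() * sizeof(float));
    cudaMalloc(&attn, M * h * sizeof(float));
    cudaMalloc(&i0, M * sizeof(int));
    cudaMalloc(&i1, M * sizeof(int));
    cudaMemcpy(q, q_h.data(), q_h.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(k, k_h.data(), k_h.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(i0, i0_h.data(), M * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(i1, i1_h.data(), M * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemset(attn, 0, M * h * sizeof(float));      // kernel accumulates, so start from zero

    attention_step1_forward_cuda_launcher(N, M, h, C, q, k, i0, i1, attn);

    std::vector<float> attn_h(M * h);
    cudaMemcpy(attn_h.data(), attn, M * h * sizeof(float), cudaMemcpyDeviceToHost);
    // Expected: attn[0] = 1*7 + 2*8 = 23, attn[1] = 3*5 + 4*6 = 39
    printf("attn = %.1f %.1f\n", attn_h[0], attn_h[1]);

    cudaFree(q); cudaFree(k); cudaFree(attn); cudaFree(i0); cudaFree(i1);
    return 0;
}
#endif  // ATTENTION_CUDA_KERNEL_EXAMPLE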