#include "../cuda_utils.h"
#include "attention_cuda_kernel.h"
/*
Kernels
*/
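// Shapes, as inferred from the indexing below (not stated explicitly in the
// original file):
//   n points, m index pairs (relations), g attention groups, c channels per group.
//   query/key/value: (n, g, c) contiguous; relation-step weight: (c,);
//   relation output and fusion-step weight: (m, g); fusion output: (n, g, c);
//   index_target/index_refer: (m,) point indices for each pair.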
__global__ void attention_relation_step_forward_cuda_kernel(int m, int g, int c,
                                                            const float *query, const float *key, const float *weight,
                                                            const int *index_target, const int *index_refer,
                                                            float *output)
{
    // One thread per (relation, group, channel) triple: x tiles the m index
    // pairs, y covers the g groups, z the c channels per group.
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    // output[r, g] = sum_c query[target[r], g, c] * key[refer[r], g, c] * weight[c].
    // The channel reduction runs across blockIdx.z, so it is accumulated with
    // atomicAdd; output must be zero-initialized before launch.
    float r = query[q_idx] * key[k_idx] * weight[c_idx];
    atomicAdd(output + r_idx * g + g_idx, r);
}
__global__ void attention_relation_step_backward_cuda_kernel(int m, int g, int c,
                                                             const float *query, float *grad_query,
                                                             const float *key, float *grad_key,
                                                             const float *weight, float *grad_weight,
                                                             const int *index_target, const int *index_refer,
                                                             const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int o_idx = r_idx * g + g_idx;
    // Product rule on r = q * k * w: each factor's gradient is grad_r times
    // the other two. Atomics are needed because a point can appear in many
    // relations and every relation touches grad_weight; all grad buffers
    // must be zero-initialized before launch.
    float grad_r = grad_output[o_idx];
    atomicAdd(grad_query + q_idx, grad_r * key[k_idx] * weight[c_idx]);
    atomicAdd(grad_key + k_idx, grad_r * query[q_idx] * weight[c_idx]);
    atomicAdd(grad_weight + c_idx, grad_r * key[k_idx] * query[q_idx]);
}
__global__ void attention_fusion_step_forward_cuda_kernel(int m, int g, int c,
                                                          const float *weight, const float *value,
                                                          const int *index_target, const int *index_refer,
                                                          float *output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    // Here weight holds the per-relation attention scores (m, g); value and
    // output are (n, g, c). Each relation scatters its weighted value into
    // its target point, so output must be zero-initialized before launch.
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    float f = weight[r_idx * g + g_idx] * value[v_idx];
    atomicAdd(output + o_idx, f);
}
__global__ void attention_fusion_step_backward_cuda_kernel(int m, int g, int c,
                                                           const float *weight, float *grad_weight,
                                                           const float *value, float *grad_value,
                                                           const int *index_target, const int *index_refer,
                                                           const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int w_idx = r_idx * g + g_idx;
    // Product rule on f = w * v: atomics because each weight entry sums
    // gradients over channels and each value point over relations.
    float grad = grad_output[o_idx];
    atomicAdd(grad_weight + w_idx, grad * value[v_idx]);
    atomicAdd(grad_value + v_idx, grad * weight[w_idx]);
}
/*
Launchers
*/
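// The launchers below launch asynchronously on the default stream and do not
// surface launch failures. A minimal error-check sketch one could call right
// after each launch (hypothetical helper, not part of the original file):
#include <cstdio>
static inline void check_cuda_launch(const char *what)
{
    cudaError_t err = cudaGetLastError();  // returns and clears the last launch error
    if (err != cudaSuccess)
        fprintf(stderr, "CUDA error after %s: %s\n", what, cudaGetErrorString(err));
}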
void attention_relation_step_forward_cuda_launcher(int m, int g, int c,
                                                   const float *query, const float *key, const float *weight,
                                                   const int *index_target, const int *index_refer,
                                                   float *output)
{
    // Grid: x tiles the m relations by THREADS_PER_BLOCK; y = groups; z = channels.
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, query, key, weight,
                                                                        index_target, index_refer, output);
}
void attention_relation_step_backward_cuda_launcher(int m, int g, int c,
                                                    const float *query, float *grad_query,
                                                    const float *key, float *grad_key,
                                                    const float *weight, float *grad_weight,
                                                    const int *index_target, const int *index_refer,
                                                    const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                         query, grad_query,
                                                                         key, grad_key,
                                                                         weight, grad_weight,
                                                                         index_target, index_refer,
                                                                         grad_output);
}
void attention_fusion_step_forward_cuda_launcher(int m, int g, int c,
                                                 const float *weight, const float *value,
                                                 const int *index_target, const int *index_refer,
                                                 float *output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, weight, value,
                                                                      index_target, index_refer, output);
}
void attention_fusion_step_backward_cuda_launcher(int m, int g, int c,
                                                  const float *weight, float *grad_weight,
                                                  const float *value, float *grad_value,
                                                  const int *index_target, const int *index_refer,
                                                  const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                       weight, grad_weight,
                                                                       value, grad_value,
                                                                       index_target, index_refer,
                                                                       grad_output);
}
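// Hypothetical usage sketch (not part of the original file): driving the
// relation step with raw device pointers. The caveat it illustrates is that
// every buffer these kernels write is accumulated into with atomicAdd, so it
// must be zero-initialized before launch. Names with a d_ prefix are assumed
// device allocations laid out as described at the top of this file.
void attention_relation_step_forward_example(int m, int g, int c,
                                             const float *d_query, const float *d_key, const float *d_weight,
                                             const int *d_index_target, const int *d_index_refer)
{
    float *d_output;
    cudaMalloc(&d_output, (size_t)m * g * sizeof(float));
    cudaMemset(d_output, 0, (size_t)m * g * sizeof(float));  // kernel accumulates, start from zero
    attention_relation_step_forward_cuda_launcher(m, g, c, d_query, d_key, d_weight,
                                                  d_index_target, d_index_refer, d_output);
    cudaDeviceSynchronize();  // wait for the launch before reading d_output back
    // ... copy or consume the (m, g) relation scores in d_output here ...
    cudaFree(d_output);
}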