#include "../cuda_utils.h" | |
#include "attention_cuda_kernel.h" | |
/* | |
Kernels | |
*/ | |
// Relation step forward: for each of the m (target, reference) index pairs,
// accumulate the channel-weighted inner product of query and key within
// group g_idx, producing an [m, g] relation map.
__global__ void attention_relation_step_forward_cuda_kernel(int m, int g, int c,
    const float *query, const float *key, const float *weight,
    const int *index_target, const int *index_refer,
    float *output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;  // relation (pair) index
    int g_idx = blockIdx.y;                             // group index
    int c_idx = blockIdx.z;                             // channel index within the group
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    float r = query[q_idx] * key[k_idx] * weight[c_idx];
    atomicAdd(output + r_idx * g + g_idx, r);  // reduce over the c channel threads
}
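
/*
A minimal host-side reference for the relation step, kept as a hedged sketch
for unit testing against the kernel above. It assumes the layout implied by
the indexing: query/key are [n, g, c] row-major, weight is [c],
index_target/index_refer hold m point indices, and output is a
zero-initialized [m, g] buffer.
*/
static void attention_relation_step_forward_cpu_reference(int m, int g, int c,
    const float *query, const float *key, const float *weight,
    const int *index_target, const int *index_refer,
    float *output)
{
    for (int r_idx = 0; r_idx < m; r_idx++)
        for (int g_idx = 0; g_idx < g; g_idx++)
            for (int c_idx = 0; c_idx < c; c_idx++)
            {
                int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
                int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
                output[r_idx * g + g_idx] += query[q_idx] * key[k_idx] * weight[c_idx];
            }
}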
// Relation step backward: propagate grad_output [m, g] back to query, key,
// and weight by the product rule. All three accumulations use atomicAdd
// because a point (and every channel slot) can appear in many relation pairs.
__global__ void attention_relation_step_backward_cuda_kernel(int m, int g, int c,
    const float *query, float *grad_query,
    const float *key, float *grad_key,
    const float *weight, float *grad_weight,
    const int *index_target, const int *index_refer,
    const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int o_idx = r_idx * g + g_idx;
    float grad_r = grad_output[o_idx];
    atomicAdd(grad_query + q_idx, grad_r * key[k_idx] * weight[c_idx]);
    atomicAdd(grad_key + k_idx, grad_r * query[q_idx] * weight[c_idx]);
    atomicAdd(grad_weight + c_idx, grad_r * key[k_idx] * query[q_idx]);
}
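
/*
Why these three accumulations: the forward pass computes, per pair and group,
r = sum_c q[c] * k[c] * w[c], so by the product rule
    dr/dq[c] = k[c] * w[c],
    dr/dk[c] = q[c] * w[c],
    dr/dw[c] = q[c] * k[c],
and each partial is scaled by the incoming grad_output for that (pair, group).
*/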
// Fusion step forward: scatter the per-pair attention weight [m, g] onto the
// reference point's value vector and accumulate into the target point's
// output slot. atomicAdd handles targets referenced by multiple pairs.
__global__ void attention_fusion_step_forward_cuda_kernel(int m, int g, int c,
    const float *weight, const float *value,
    const int *index_target, const int *index_refer,
    float *output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    float f = weight[r_idx * g + g_idx] * value[v_idx];
    atomicAdd(output + o_idx, f);
}
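
/*
A matching host-side reference for the fusion step, again a hedged sketch
under the same layout assumptions: value and output are [n, g, c] row-major,
weight here is the [m, g] attention map, and output is zero-initialized.
*/
static void attention_fusion_step_forward_cpu_reference(int m, int g, int c,
    const float *weight, const float *value,
    const int *index_target, const int *index_refer,
    float *output)
{
    for (int r_idx = 0; r_idx < m; r_idx++)
        for (int g_idx = 0; g_idx < g; g_idx++)
            for (int c_idx = 0; c_idx < c; c_idx++)
            {
                int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
                int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
                output[o_idx] += weight[r_idx * g + g_idx] * value[v_idx];
            }
}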
// Fusion step backward: route grad_output at the scattered target slot back
// to the attention weight (scaled by value) and to the value (scaled by
// weight), mirroring the forward product.
__global__ void attention_fusion_step_backward_cuda_kernel(int m, int g, int c,
    const float *weight, float *grad_weight,
    const float *value, float *grad_value,
    const int *index_target, const int *index_refer,
    const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int w_idx = r_idx * g + g_idx;
    float grad = grad_output[o_idx];
    atomicAdd(grad_weight + w_idx, grad * value[v_idx]);
    atomicAdd(grad_value + v_idx, grad * weight[w_idx]);
}
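
/*
As with the relation step, the forward here is a plain product
f = w[pair, group] * v[ref, group, channel], so df/dw = v and df/dv = w,
each scaled by grad_output at the scattered target location.
*/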

/*
Launchers
*/
// Each launcher maps the relation/pair dimension onto blockIdx.x (tiled by
// THREADS_PER_BLOCK) and the group and channel dimensions onto blockIdx.y/z,
// matching the bounds checks inside the kernels.
void attention_relation_step_forward_cuda_launcher(int m, int g, int c,
    const float *query, const float *key, const float *weight,
    const int *index_target, const int *index_refer,
    float *output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, query, key, weight,
                                                                        index_target, index_refer, output);
}
void attention_relation_step_backward_cuda_launcher(int m, int g, int c,
    const float *query, float *grad_query,
    const float *key, float *grad_key,
    const float *weight, float *grad_weight,
    const int *index_target, const int *index_refer,
    const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                         query, grad_query,
                                                                         key, grad_key,
                                                                         weight, grad_weight,
                                                                         index_target, index_refer,
                                                                         grad_output);
}
void attention_fusion_step_forward_cuda_launcher(int m, int g, int c,
    const float *weight, const float *value,
    const int *index_target, const int *index_refer,
    float *output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, weight, value,
                                                                      index_target, index_refer, output);
}
void attention_fusion_step_backward_cuda_launcher(int m, int g, int c,
    const float *weight, float *grad_weight,
    const float *value, float *grad_value,
    const int *index_target, const int *index_refer,
    const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                       weight, grad_weight,
                                                                       value, grad_value,
                                                                       index_target, index_refer,
                                                                       grad_output);
}
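
/*
A hedged usage sketch (not part of the original file): how a host caller
might drive the relation-step launcher directly with raw device buffers.
The helper name and its error handling are illustrative assumptions, not
project API; only standard CUDA runtime calls are used.
*/
#include <cstdio>
#include <cuda_runtime.h>

static void example_relation_step(int m, int g, int c,
    const float *d_query, const float *d_key, const float *d_weight,
    const int *d_index_target, const int *d_index_refer)
{
    float *d_output = nullptr;
    // Output is [m, g] and must start at zero because the kernel accumulates.
    cudaMalloc(&d_output, sizeof(float) * m * g);
    cudaMemset(d_output, 0, sizeof(float) * m * g);

    attention_relation_step_forward_cuda_launcher(m, g, c,
        d_query, d_key, d_weight, d_index_target, d_index_refer, d_output);

    // Surface any launch or runtime error before the result is consumed.
    cudaError_t err = cudaDeviceSynchronize();
    if (err != cudaSuccess)
        fprintf(stderr, "relation step failed: %s\n", cudaGetErrorString(err));

    cudaFree(d_output);
}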