#include "../cuda_utils.h"
#include "attention_cuda_kernel.h"
/*
Kernels
*/
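/*
Shape conventions, inferred from the indexing below (descriptive comments
added for clarity; the names are not from the original header):
  m : number of (target, reference) index pairs, i.e. attention relations
  g : number of groups (attention heads)
  c : number of channels per group
query/key/value are [n_points, g, c] flattened row-major; index_target and
index_refer are [m]; the relation-step weight is [c], the fusion-step weight
is [m, g]. All outputs are accumulated with atomicAdd, so callers must
zero-initialize the output buffers before launching.
*/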
__global__ void attention_relation_step_forward_cuda_kernel(int m, int g, int c,
                                                            const float *query, const float *key, const float *weight,
                                                            const int *index_target, const int *index_refer,
                                                            float *output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    float r = query[q_idx] * key[k_idx] * weight[c_idx];
    atomicAdd(output + r_idx * g + g_idx, r);
}
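/*
Backward pass of the relation step. For each relation r, group h, and channel
ch, with t = index_target[r] and s = index_refer[r]:
  d query[t,h,ch] += grad_output[r,h] * key[s,h,ch]   * weight[ch]
  d key[s,h,ch]   += grad_output[r,h] * query[t,h,ch] * weight[ch]
  d weight[ch]    += grad_output[r,h] * query[t,h,ch] * key[s,h,ch]
atomicAdd is required because several relations can share a target or
reference index, and grad_weight is reduced over all (r, h) pairs.
*/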
__global__ void attention_relation_step_backward_cuda_kernel(int m, int g, int c,
                                                             const float *query, float *grad_query,
                                                             const float *key, float *grad_key,
                                                             const float *weight, float *grad_weight,
                                                             const int *index_target, const int *index_refer,
                                                             const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int q_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int k_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int o_idx = r_idx * g + g_idx;
    float grad_r = grad_output[o_idx];
    atomicAdd(grad_query + q_idx, grad_r * key[k_idx] * weight[c_idx]);
    atomicAdd(grad_key + k_idx, grad_r * query[q_idx] * weight[c_idx]);
    atomicAdd(grad_weight + c_idx, grad_r * key[k_idx] * query[q_idx]);
}
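/*
Fusion (aggregation) step: scatter the weighted value features back onto the
target points. For each relation r, group h, and channel ch:
  output[index_target[r], h, ch] += weight[r,h] * value[index_refer[r], h, ch]
Multiple relations can share a target index, hence the atomicAdd; output must
be zero-initialized by the caller.
*/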
__global__ void attention_fusion_step_forward_cuda_kernel(int m, int g, int c,
                                                          const float *weight, const float *value,
                                                          const int *index_target, const int *index_refer,
                                                          float *output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    float f = weight[r_idx * g + g_idx] * value[v_idx];
    atomicAdd(output + o_idx, f);
}
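/*
Backward pass of the fusion step, with t = index_target[r], s = index_refer[r]:
  d weight[r,h]  += grad_output[t,h,ch] * value[s,h,ch]   (summed over ch)
  d value[s,h,ch] += grad_output[t,h,ch] * weight[r,h]
atomicAdd covers both the channel reduction into grad_weight and relations
that share a reference index in grad_value.
*/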
__global__ void attention_fusion_step_backward_cuda_kernel(int m, int g, int c,
                                                           const float *weight, float *grad_weight,
                                                           const float *value, float *grad_value,
                                                           const int *index_target, const int *index_refer,
                                                           const float *grad_output)
{
    int r_idx = blockIdx.x * blockDim.x + threadIdx.x;
    int g_idx = blockIdx.y;
    int c_idx = blockIdx.z;
    if (r_idx >= m || g_idx >= g || c_idx >= c) return;
    int o_idx = index_target[r_idx] * g * c + g_idx * c + c_idx;
    int v_idx = index_refer[r_idx] * g * c + g_idx * c + c_idx;
    int w_idx = r_idx * g + g_idx;
    float grad = grad_output[o_idx];
    atomicAdd(grad_weight + w_idx, grad * value[v_idx]);
    atomicAdd(grad_value + v_idx, grad * weight[w_idx]);
}
/*
Launchers
*/
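/*
Launch configuration: one thread per (relation, group, channel) triple. The
grid is (ceil(m / THREADS_PER_BLOCK), g, c) with THREADS_PER_BLOCK threads
along x; DIVUP and THREADS_PER_BLOCK are presumably defined in ../cuda_utils.h.
Note that g and c map to gridDim.y/gridDim.z, which are limited to 65535,
ample for typical head and channel counts.
*/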
void attention_relation_step_forward_cuda_launcher(int m, int g, int c,
                                                   const float *query, const float *key, const float *weight,
                                                   const int *index_target, const int *index_refer,
                                                   float *output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, query, key, weight,
                                                                        index_target, index_refer, output);
}
void attention_relation_step_backward_cuda_launcher(int m, int g, int c,
                                                    const float *query, float *grad_query,
                                                    const float *key, float *grad_key,
                                                    const float *weight, float *grad_weight,
                                                    const int *index_target, const int *index_refer,
                                                    const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_relation_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                         query, grad_query,
                                                                         key, grad_key,
                                                                         weight, grad_weight,
                                                                         index_target, index_refer,
                                                                         grad_output);
}
void attention_fusion_step_forward_cuda_launcher(int m, int g, int c,
                                                 const float *weight, const float *value,
                                                 const int *index_target, const int *index_refer,
                                                 float *output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_forward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c, weight, value,
                                                                      index_target, index_refer, output);
}
void attention_fusion_step_backward_cuda_launcher(int m, int g, int c,
                                                  const float *weight, float *grad_weight,
                                                  const float *value, float *grad_value,
                                                  const int *index_target, const int *index_refer,
                                                  const float *grad_output)
{
    dim3 blocks(DIVUP(m, THREADS_PER_BLOCK), g, c);
    dim3 threads(THREADS_PER_BLOCK);
    attention_fusion_step_backward_cuda_kernel<<<blocks, threads, 0>>>(m, g, c,
                                                                       weight, grad_weight,
                                                                       value, grad_value,
                                                                       index_target, index_refer,
                                                                       grad_output);
}
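/*
Usage sketch (added for illustration, not part of the original extension):
how a host-side caller might drive the relation-step forward launcher. The
function name, the ATTENTION_CUDA_USAGE_EXAMPLE guard, and the parameter n
(number of feature rows) are hypothetical. The essential point is that the
output buffer must be zeroed before the launch, since the kernel accumulates
with atomicAdd.
*/
#ifdef ATTENTION_CUDA_USAGE_EXAMPLE
#include <cuda_runtime.h>

static void example_relation_forward(int n, int m, int g, int c,
                                     const float *h_query,       // [n, g, c]
                                     const float *h_key,         // [n, g, c]
                                     const float *h_weight,      // [c]
                                     const int *h_index_target,  // [m]
                                     const int *h_index_refer,   // [m]
                                     float *h_output)            // [m, g]
{
    size_t feat_bytes = (size_t)n * g * c * sizeof(float);
    size_t out_bytes = (size_t)m * g * sizeof(float);
    float *d_query, *d_key, *d_weight, *d_output;
    int *d_target, *d_refer;
    // Allocate device buffers and upload the inputs.
    cudaMalloc((void **)&d_query, feat_bytes);
    cudaMalloc((void **)&d_key, feat_bytes);
    cudaMalloc((void **)&d_weight, c * sizeof(float));
    cudaMalloc((void **)&d_output, out_bytes);
    cudaMalloc((void **)&d_target, m * sizeof(int));
    cudaMalloc((void **)&d_refer, m * sizeof(int));
    cudaMemcpy(d_query, h_query, feat_bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_key, h_key, feat_bytes, cudaMemcpyHostToDevice);
    cudaMemcpy(d_weight, h_weight, c * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_target, h_index_target, m * sizeof(int), cudaMemcpyHostToDevice);
    cudaMemcpy(d_refer, h_index_refer, m * sizeof(int), cudaMemcpyHostToDevice);
    // The kernel accumulates into output, so it must start at zero.
    cudaMemset(d_output, 0, out_bytes);
    attention_relation_step_forward_cuda_launcher(m, g, c, d_query, d_key, d_weight,
                                                  d_target, d_refer, d_output);
    cudaMemcpy(h_output, d_output, out_bytes, cudaMemcpyDeviceToHost);
    cudaFree(d_query); cudaFree(d_key); cudaFree(d_weight);
    cudaFree(d_output); cudaFree(d_target); cudaFree(d_refer);
}
#endif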