Spaces:
Running
Running
// Storage buffers used by the shader. The bracketed shapes are the layouts
// noted by the original author; A_TYPE / B_TYPE / D_TYPE are defined outside
// this view.

// src0 - kernel: [K, Cout, Cin] (read-only)
layout (binding = 0) readonly buffer A
{
    A_TYPE data_a[];
};

// src1 - input: [L, Cin] (read-only)
layout (binding = 1) readonly buffer B
{
    B_TYPE data_b[];
};

// dst - result: [KL, Cout] (write-only)
layout (binding = 2) writeonly buffer D
{
    D_TYPE data_d[];
};
// Workgroup shape: 128 invocations along x, a single one along y and z.
layout (
    local_size_x = 128,
    local_size_y = 1,
    local_size_z = 1
) in;
// Push-constant parameters, accessed through the instance name `p`.
// Member order defines the block layout and must not be changed.
layout (push_constant) uniform parameter
{
    // --- sizes ---
    uint32_t Cout;  // presumably the number of output channels (not read by the code visible here)
    uint32_t Cin;   // upper bound of the input-channel loop
    uint32_t K;     // upper bound of the kernel-tap loop
    uint32_t L;     // number of valid input positions (bounds L_idx)
    uint32_t KL;    // number of valid output positions per channel (bounds writes to data_d)

    // --- strides, used directly as element offsets ---
    uint32_t nb01;  // data_a: multiplied by the output-channel index
    uint32_t nb02;  // data_a: multiplied by the input-channel index
    uint32_t nb11;  // data_b: multiplied by the input-channel index
    uint32_t nb1;   // data_d: multiplied by the output-channel index

    // --- convolution stride ---
    int32_t s0;     // output-position step per input position
} p;
// Output channel processed by this workgroup.
uint32_t Cout_idx = gl_WorkGroupID.x;

// Index of this invocation inside its workgroup.
uint32_t tid = gl_LocalInvocationID.x;

// Block size: number of invocations per workgroup along x.
const uint32_t bs = gl_WorkGroupSize.x;

// Number of entries of `tmp` that are actually used. bs*s0+K is used rather
// than the tighter (bs-1)*s0+K because it keeps the code more straightforward.
uint32_t tmp_len = p.K + bs * p.s0;

// Workgroup-shared accumulation window (fixed capacity of 4096 entries).
shared D_TYPE tmp[4096];
// Returns how many bs-sized chunks are needed to cover workSize items,
// i.e. workSize divided by bs, rounded up.
uint splitWork(uint workSize)
{
    uint padded = workSize + bs - 1;
    return padded / bs;
}
// Transposed 1-D convolution; each workgroup produces output channel Cout_idx.
//
// Input positions are processed in blocks of `bs`. Within a block, invocation
// `tid` owns input position L_idx = L_block_id*bs + tid and accumulates its K
// products into the shared window `tmp` at offsets tid*s0 + K_idx. After every
// block except the last, the first bs*s0 window entries are written to data_d;
// the remaining tail is then shifted to the front of the window and carried
// into the next block. The window of the last block is flushed after the loop.
//
// NOTE(review): assumes tmp_len does not exceed the capacity of `tmp`;
// presumably enforced by the host before dispatch - confirm.
// NOTE(review): if p.L were 0, L_blocks would be 0 and (L_blocks-1) below
// would wrap around; presumably the host never dispatches with L == 0 - confirm.
void main(){
    // p.s0 is declared signed but is only used in unsigned index arithmetic.
    // Convert it once; this yields the same value as the implicit conversion.
    uint32_t s0 = uint32_t(p.s0);
    // Number of window entries emitted (and shifted out) per block.
    uint32_t window_step = bs*s0;
    // Number of passes each invocation needs to cover the whole window.
    uint32_t tmp_chunks = splitWork(tmp_len);

    // Clear the shared window.
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs + tid;
        if(idx < tmp_len){
            tmp[idx] = 0.0;
        }
    }

    uint32_t L_blocks = splitWork(p.L);
    for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){
        if(L_block_id > 0){
            // Wait until every invocation has written out the previous block.
            barrier();
            // Shift values in tmp to the current processing window.
            // Source idx and destination idx - window_step belong to the same
            // invocation (they differ by a multiple of bs) and are visited in
            // increasing order, so the in-place shift needs no extra barrier.
            for(uint32_t i = 0; i < tmp_chunks; i++){
                uint32_t idx = i*bs + tid;
                if(idx >= window_step && idx < tmp_len){
                    tmp[idx - window_step] = tmp[idx];
                    tmp[idx] = 0.0;
                }else if(idx >= p.K && idx < window_step){
                    // Entries that receive no carried-over value start from zero.
                    tmp[idx] = 0.0;
                }
            }
        }
        // Window is fully initialised/shifted before anyone accumulates into it.
        barrier();

        // Save contributions of the block to tmp
        uint32_t L_idx = L_block_id*bs + tid;
        // Invocations past the end of the input contribute zero.
        bool in_range = L_idx < p.L;
        for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){
            D_TYPE dp = 0.0;
            if(in_range){
                // Dot product over input channels for this (position, tap) pair.
                for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){
                    A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02];
                    B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11];
                    dp = fma(elemKrn, elemInp, dp);
                }
            }
            // For a fixed K_idx every invocation targets a distinct entry; the
            // barrier orders the overlapping updates of successive K_idx values.
            // It is executed by all invocations (it sits outside the branch).
            tmp[tid*s0 + K_idx] += dp;
            barrier();
        }

        // Save the computed values except the last block that can have different size
        uint32_t KLb_idx = L_block_id*window_step;
        if(L_block_id < L_blocks-1){
            for(uint32_t s0_idx = 0; s0_idx < s0; s0_idx++){
                uint32_t sh_idx = s0*tid + s0_idx;
                uint32_t KL_idx = KLb_idx + sh_idx;
                if(KL_idx < p.KL){
                    data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx];
                }
            }
        }
    }

    // Flush the window of the last block; the KL bound trims unused entries.
    uint32_t last_base = (L_blocks-1)*window_step;
    for(uint32_t i = 0; i < tmp_chunks; i++){
        uint32_t idx = i*bs + tid;
        uint32_t KL_idx = last_base + idx;
        if(KL_idx < p.KL){
            data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx];
        }
    }
}