File size: 1,587 Bytes
5ab06d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#version 450
// Combines k_num partial (O, l, m) results for each row into one normalized
// output row (see main()). NOTE(review): this looks like the split-k combine
// step of a flash-attention kernel — confirm against the host-side dispatch.

// Enables the [[unroll]] loop attributes used in main().
#extension GL_EXT_control_flow_attributes : enable

// Workgroup width in x; main() also uses it as the step of its per-row column loop.
#define BLOCK_SIZE 32

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

// Input: k_num partial O blocks of D*N floats each, followed by per-split
// l and m values (exact offsets are computed in main()).
layout (binding = 0) readonly buffer A {float data_a[];};
// Output: one row of D floats per workgroup; row n starts at D * n.
layout (binding = 1) writeonly buffer D {float data_d[];};

// Sizes supplied by the host through push constants.
layout (push_constant) uniform parameter {
    uint D;     // floats per row (column count)
    uint N;     // row count used for input offsets; presumably equals the number of workgroups dispatched — verify on host side
    uint k_num; // number of partial results (splits) to combine per row
} p;

void main() {
    // Each workgroup combines the k_num partial results for one row n; its
    // BLOCK_SIZE invocations cover the row's D columns in steps of BLOCK_SIZE.
    const uint n = gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    uint D = p.D;
    uint N = p.N;
    uint k_num = p.k_num;

    // data_a layout, as implied by the indexing in this function:
    //   O of split k, row n, column d : D * N * k + D * n + d
    //   l of split k, row n           : D * N * k_num + 2 * N * k + n
    //   m of split k, row n           : D * N * k_num + 2 * N * k + N + n
    // NOTE(review): all offsets are 32-bit uint; D * N * k_num must fit in 32 bits.
    uint l_offset = D * N * k_num + n;
    uint m_offset = D * N * k_num + N + n;
    uint lm_stride = N * 2;

    const float NEG_INF = -1.0/0.0;

    // Compute the max m value for the row
    float m_max = NEG_INF;
    [[unroll]] for (uint k = 0; k < k_num; ++k) {
        float m = data_a[m_offset + k * lm_stride];
        m_max = max(m_max, m);
    }

    // Compute L = sum_k exp(m_k - m_max) * l_k.
    // If every m is -inf (or k_num == 0), no split contributed anything, and
    // exp(m - m_max) would evaluate exp(-inf - -inf) = exp(NaN); leave L at 0.
    float L = 0.0;
    if (m_max != NEG_INF) {
        [[unroll]] for (uint k = 0; k < k_num; ++k) {
            float l = data_a[l_offset + k * lm_stride];
            float m = data_a[m_offset + k * lm_stride];
            L += exp(m - m_max) * l;
        }
    }

    // A row with no contributions would otherwise be scaled by 1/0 = inf,
    // producing inf/NaN. Write zeros instead. L depends only on n and the push
    // constants, so this branch is uniform across the workgroup.
    if (L == 0.0) {
        for (uint d = tid; d < D; d += BLOCK_SIZE) {
            data_d[D * n + d] = 0.0;
        }
        return;
    }

    const float L_inv = 1.0 / L;

    // Scale and sum the O contributions based on m_max and store the result to memory
    for (uint d = tid; d < D; d += BLOCK_SIZE) {
        float O = 0.0;
        [[unroll]] for (uint k = 0; k < k_num; ++k) {
            uint o_offset = D * N * k + D * n + d;
            float m = data_a[m_offset + k * lm_stride];
            O += exp(m - m_max) * data_a[o_offset];
        }
        O *= L_inv;
        data_d[D * n + d] = O;
    }
}