David Huang committed
Commit a027c1d · 1 Parent(s): dbc0180

HIP: implement FlashAttention via rocWMMA for CDNA and RDNA3+ (llama/12032)


Adds GGML_HIP_ROCWMMA_FATTN and rocwmma header check
Adds rocWMMA support to fattn-wmma-f16
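For reference, the feature is opt-in at configure time. A minimal sketch of a build invocation with the new option enabled (assuming a ROCm/HIP toolchain with the rocWMMA headers installed and a CDNA or RDNA3+ target; any additional GPU-target or compiler settings your environment needs are unchanged by this commit):

    cmake -S . -B build -DGGML_HIP=ON -DGGML_HIP_ROCWMMA_FATTN=ON -DCMAKE_BUILD_TYPE=Release
    cmake --build build --config Release

If rocwmma/rocwmma.hpp cannot be found, configuration fails with "rocwmma has not been found" (see the CMake check added below).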

ggml/CMakeLists.txt CHANGED
@@ -162,6 +162,7 @@ set_property(CACHE GGML_CUDA_COMPRESSION_MODE PROPERTY STRINGS "none;speed;balan
 option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, slow" OFF)
 option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON)
+option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -62,6 +62,7 @@
 #define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
 #define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
 
+#define GGML_CUDA_CC_IS_AMD(cc) (cc >= GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
 #define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
@@ -196,6 +197,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
+#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+#define FP16_MMA_AVAILABLE
+#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3))
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -223,12 +228,18 @@ static bool fast_fp16_hardware_available(const int cc) {
 
 // Any FP16 tensor core instructions are available for ggml code.
 static bool fp16_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
+    return false;
+#else
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(GGML_HIP_ROCWMMA_FATTN)
 }
 
 // To be used for feature selection of external libraries, e.g. cuBLAS.
 static bool fp16_mma_hardware_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -57,12 +57,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -70,7 +71,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b2(K_q4_0[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -78,14 +79,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q4_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2) /* *8/QI8_1 == 1 */);
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (8/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q4_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (8/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -97,12 +98,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
    const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -110,7 +112,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         const int shift = k_KQ & (QI8_1/2);
 
         const int v = (get_int_b4(K_q4_1[ib].qs, iqs4) >> shift) & 0x0F0F0F0F;
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -118,7 +120,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d4d8_m4s8 = K_q4_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid4d8_m4s8scaled = d4d8_m4s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid4d8_m4s8scaled) + __high2half(sumid4d8_m4s8scaled));
         } else
@@ -126,8 +128,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid4d8 = __low2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m4s8scaled = __high2float(K_q4_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid4d8 + m4s8scaled);
         }
@@ -141,12 +143,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -161,7 +164,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -169,14 +172,14 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 sum2 = __half2half2(K_q5_0[ib].d) * Q_ds[k_KQ_0/warp_size];
             sum += (T) (((half) sumi)*__low2half(sum2) - __high2half(sum2)*__float2half(2.0f)) /* *16/QI8_1 == 2 */;
         } else
 #endif // FP16_AVAILABLE
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/WARP_SIZE].x - (16/QI8_1)*Q_ds[k_KQ_0/WARP_SIZE].y));
+            sum += (T) (__half2float(K_q5_0[ib].d) * (sumi*Q_ds[k_KQ_0/warp_size].x - (16/QI8_1)*Q_ds[k_KQ_0/warp_size].y));
         }
     }
 
@@ -188,12 +191,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_1;
@@ -208,7 +212,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         v |= (vh << 18) & 0x00100000; // 2 -> 20
         v |= (vh << 25) & 0x10000000; // 3 -> 28
 
-        const int u = Q_q8[k_KQ_0/WARP_SIZE];
+        const int u = Q_q8[k_KQ_0/warp_size];
 
         const int sumi = ggml_cuda_dp4a(v, u, 0);
 
@@ -216,7 +220,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
 
-            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/WARP_SIZE];
+            const half2 d5d8_m5s8 = K_q5_1[ib].dm * Q_ds[k_KQ_0/warp_size];
             const half2 sumid5d8_m5s8scaled = d5d8_m5s8 * make_half2(sumi, 1.0f/QI8_1);
             sum += (T) (__low2half(sumid5d8_m5s8scaled) + __high2half(sumid5d8_m5s8scaled));
         } else
@@ -224,8 +228,8 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
         {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
 
-            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].x * sumi;
-            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/WARP_SIZE].y / QI8_1;
+            const float sumid5d8 = __low2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].x * sumi;
+            const float m5s8scaled = __high2float(K_q5_1[ib].dm)*Q_ds[k_KQ_0/warp_size].y / QI8_1;
 
             sum += (T) (sumid5d8 + m5s8scaled);
         }
@@ -239,12 +243,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {
 
     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);
 
     T sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/sizeof(int); k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const int ib = k_KQ / QI8_0;
@@ -255,13 +260,13 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
         T Q_d;
         if (std::is_same<T, half>::value) {
             const half2 * Q_ds = (const half2 *) Q_ds_v;
-            Q_d = __low2half(Q_ds[k_KQ_0/WARP_SIZE]);
+            Q_d = __low2half(Q_ds[k_KQ_0/warp_size]);
         } else {
             const float2 * Q_ds = (const float2 *) Q_ds_v;
-            Q_d = Q_ds[k_KQ_0/WARP_SIZE].x;
+            Q_d = Q_ds[k_KQ_0/warp_size].x;
         }
 
-        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/WARP_SIZE], K_q8_0[ib].d, Q_d);
+        sum += vec_dot_q8_0_q8_1_impl<T, 1>(&v, &Q_q8[k_KQ_0/warp_size], K_q8_0[ib].d, Q_d);
     }
 
     return sum;
@@ -272,6 +277,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {
 
     const half2 * K_h2 = (const half2 *) K_c;
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);
 
@@ -282,11 +288,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
         half2 sum2 = make_half2(0.0f, 0.0f);
 
 #pragma unroll
-        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+        for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
            const int k_KQ = k_KQ_0 + threadIdx.x;
 
            const half2 K_ik = K_h2[k_KQ];
-           sum2 += K_ik * Q_h2[k_KQ_0/WARP_SIZE];
+           sum2 += K_ik * Q_h2[k_KQ_0/warp_size];
         }
 
         return __low2half(sum2) + __high2half(sum2);
@@ -298,12 +304,12 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     float sum = 0.0f;
 
 #pragma unroll
-    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
+    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += warp_size) {
         const int k_KQ = k_KQ_0 + threadIdx.x;
 
         const half2 K_ik = K_h2[k_KQ];
-        sum += __low2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].x;
-        sum += __high2float(K_ik) * Q_f2[k_KQ_0/WARP_SIZE].y;
+        sum += __low2float(K_ik) * Q_f2[k_KQ_0/warp_size].x;
+        sum += __high2float(K_ik) * Q_f2[k_KQ_0/warp_size].y;
     }
 
     return sum;
@@ -698,6 +704,8 @@ void launch_fattn(
 
     GGML_ASSERT(Q->ne[3] == 1);
 
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id = ggml_cuda_get_device();
@@ -750,7 +758,7 @@
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    const dim3 block_dim(warp_size, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
@@ -796,6 +804,8 @@
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    GGML_ASSERT(block_dim.x % warp_size == 0);
+    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -7,14 +7,19 @@
 #include "fattn-wmma-f16.cuh"
 
 #ifdef FP16_MMA_AVAILABLE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #include <mma.h>
+namespace wmma = nvcuda::wmma;
+#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
+#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
+#include <rocwmma/rocwmma.hpp>
+namespace wmma = rocwmma;
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 #endif // FP16_MMA_AVAILABLE
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-__launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
 static __global__ void flash_attn_ext_f16(
         const char * __restrict__ Q,
         const char * __restrict__ K,
@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
         const int ne1,
         const int ne2,
         const int ne3) {
-#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
@@ -60,6 +65,8 @@
 
     //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
 
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
     const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
 
@@ -68,11 +75,11 @@
     constexpr int frag_m = ncols == 8 ? 32 : 16;
     constexpr int frag_n = ncols == 8 ? 8 : 16;
     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
-    typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
+    typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
+    typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
+    typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
 
     constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@@ -132,9 +139,9 @@
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
@@ -146,9 +153,9 @@
     for (int j0 = 0; j0 < ncols; j0 += nwarps) {
         const int j = j0 + threadIdx.y;
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
            const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
@@ -162,7 +169,7 @@
     for (int i0 = 0; i0 < D; i0 += 16) {
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
+            wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
         }
     }
 
@@ -176,20 +183,20 @@
         frag_c_KQ KQ_c[ncols/frag_n];
 #pragma unroll
         for (int j = 0; j < ncols/frag_n; ++j) {
-            nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
+            wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
         }
 #pragma unroll
         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
             frag_a_K K_a;
-            nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+            wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
+                wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
             }
         }
 #pragma unroll
         for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-            nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
+            wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
         }
     }
 
@@ -202,27 +209,27 @@
         const int j = j0 + threadIdx.y;
 
         if (std::is_same<KQ_acc_t, float>::value) {
-            float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
+            float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
+                KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
 
                 if (use_logit_softcap) {
-                    KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
+                    KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
                 }
             }
 
             float KQ_max_new = KQ_max_f[j0/nwarps];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
-                KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
+                KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
             }
-            KQ_max_new = warp_reduce_max(KQ_max_new);
+            KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
 
             const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
             KQ_max_scale_f[j0/nwarps] = expf(diff);
@@ -233,48 +240,48 @@
 
             float KQ_rowsum_add = 0.0f;
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
-                KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
+                const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
+                KQ_f_tmp[k0/warp_size] = expf(diff);
                 if (diff <= SOFTMAX_FTZ_THRESHOLD) {
-                    KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
+                    KQ_f_tmp[k0/warp_size] = 0.0f;
                 }
-                KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
-                KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
+                KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
+                KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
             }
-            KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+            KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
         } else {
-            half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
+            half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
+                KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
 
                 if (use_logit_softcap) {
                     // There is no dedicated tangens hyperbolicus function for half2.
-                    KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
-                    KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
-                                           /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
+                    KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
+                    KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
+                                           /(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
 
-                    KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
+                    KQ2_tmp[k0/warp_size] *= logit_softcap_2;
                 }
             }
 
             half2 KQ_max_new = KQ_max_h2[j0/nwarps];
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
-                KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
+                KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
             }
-            KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
+            KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
             const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
             KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
             const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
@@ -283,17 +290,17 @@
 
             half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
 #pragma unroll
-            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
+            for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
                 const int k = k0 + threadIdx.x;
 
-                const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
-                KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
+                const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
+                KQ2_tmp[k0/warp_size] = h2exp(diff);
                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
-                *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
-                KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
-                KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
+                *((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
+                KQ_rowsum_add += KQ2_tmp[k0/warp_size];
+                KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
             }
-            KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
+            KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
 
             // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
             KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
@@ -308,7 +315,7 @@
 #pragma unroll
             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
-                nvcuda::wmma::load_matrix_sync(
+                wmma::load_matrix_sync(
                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
                     KQ + j0*(kqar*kqs_padded) + k,
                     kqar*kqs_padded);
@@ -320,7 +327,7 @@
         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j = 0; j < ncols/frag_n; ++j) {
-                nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
+                wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
             }
 
 #pragma unroll
@@ -328,10 +335,10 @@
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
-                    nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
+                    wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
                 }
             }
         }
@@ -343,10 +350,10 @@
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
 #pragma unroll
             for (int j0 = 0; j0 < ncols; j0 += frag_n) {
-                nvcuda::wmma::store_matrix_sync(
+                wmma::store_matrix_sync(
                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
-                    D_padded, nvcuda::wmma::mem_col_major);
+                    D_padded, wmma::mem_col_major);
             }
         }
 
@@ -364,9 +371,9 @@
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D/2; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D/2 && i >= D/2) {
+            if (i0 + warp_size > D/2 && i >= D/2) {
                 break;
             }
 
@@ -398,9 +405,9 @@
         }
 
 #pragma unroll
-        for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
+        for (int i0 = 0; i0 < D; i0 += warp_size) {
             const int i = i0 + threadIdx.x;
-            if (i0 + WARP_SIZE > D && i >= D) {
+            if (i0 + warp_size > D && i >= D) {
                 break;
             }
             float dst_val = VKQ[j_VKQ*D_padded + i];
@@ -425,7 +432,7 @@
     }
 #else
     NO_DEVICE_CODE;
-#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
+#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
 
 constexpr int get_max_power_of_2(int x) {
@@ -515,6 +522,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
     const ggml_tensor * Q = dst->src[0];
 
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
+    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
 
     if (prec != GGML_PREC_DEFAULT) {
         if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
@@ -571,7 +579,8 @@
         return;
     }
 
-    if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+    if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
             case 64:
@@ -592,6 +601,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     ggml_cuda_set_device(ctx.device);
     const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    // On AMD the tile kernels perform poorly, use the vec kernel instead:
     if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
+#if defined(GGML_HIP_ROCWMMA_FATTN)
+        if (fp16_mma_available(cc)) {
+            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
+            return;
+        }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
+
+        // On AMD the tile kernels perform poorly, use the vec kernel instead:
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
@@ -291,7 +299,7 @@
     const int gqa_ratio = Q->ne[2] / K->ne[2];
     const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
         K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
-    if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
+    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
         if (prec == GGML_PREC_DEFAULT) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
             return;
ggml/src/ggml-hip/CMakeLists.txt CHANGED
@@ -39,6 +39,12 @@ endif()
 find_package(hip REQUIRED)
 find_package(hipblas REQUIRED)
 find_package(rocblas REQUIRED)
+if (GGML_HIP_ROCWMMA_FATTN)
+    CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA)
+    if (NOT ${FOUND_ROCWMMA})
+        message(FATAL_ERROR "rocwmma has not been found")
+    endif()
+endif()
 
 if (${hip_VERSION} VERSION_LESS 5.5)
     message(FATAL_ERROR "At least ROCM/HIP V5.5 is required")
@@ -107,6 +113,10 @@
     add_compile_definitions(GGML_HIP_NO_VMM)
 endif()
 
+if (GGML_HIP_ROCWMMA_FATTN)
+    add_compile_definitions(GGML_HIP_ROCWMMA_FATTN)
+endif()
+
 if (NOT GGML_CUDA_FA)
     add_compile_definitions(GGML_CUDA_NO_FA)
 endif()