JohannesGaessler committed
Commit 6662d54 · 1 Parent(s): 6cb8158

CUDA: optimize FA for GQA + large batches (llama/12014)

Files changed (27)
  1. ggml/src/ggml-cuda/cp-async.cuh +1 -1
  2. ggml/src/ggml-cuda/fattn-common.cuh +53 -69
  3. ggml/src/ggml-cuda/fattn-mma-f16.cuh +552 -252
  4. ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
  5. ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
  6. ggml/src/ggml-cuda/fattn-vec-f16.cuh +1 -1
  7. ggml/src/ggml-cuda/fattn-vec-f32.cuh +1 -1
  8. ggml/src/ggml-cuda/fattn-wmma-f16.cu +3 -3
  9. ggml/src/ggml-cuda/fattn.cu +56 -17
  10. ggml/src/ggml-cuda/mma.cuh +75 -0
  11. ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb16.cu → fattn-mma-f16-instance-ncols1_1-ncols2_8.cu} +6 -6
  12. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
  13. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
  14. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
  15. ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb32.cu → fattn-mma-f16-instance-ncols1_2-ncols2_4.cu} +6 -6
  16. ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb64.cu → fattn-mma-f16-instance-ncols1_2-ncols2_8.cu} +6 -6
  17. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
  18. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
  19. ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb8.cu → fattn-mma-f16-instance-ncols1_4-ncols2_2.cu} +6 -6
  20. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
  21. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
  22. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
  23. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
  24. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
  25. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
  26. ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
  27. ggml/src/ggml-cuda/template-instances/generate_cu_files.py +13 -7
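The central numerical idea behind the stream-k fixup kernel reworked in fattn-common.cuh below is the standard online-softmax merge: partial attention results computed against different running maxima are rescaled before being summed. A minimal host-side C++ sketch of that merge (illustrative names, not the kernel's API; the flush-to-zero threshold used in the real kernel is omitted):

```cpp
#include <algorithm>
#include <cmath>

// One partial flash-attention result for a single output element:
// an unnormalized accumulator, the running maximum of the logits, and the softmax row sum.
struct Partial {
    float acc;
    float max;
    float rowsum;
};

// Merge two partials that were computed against different maxima; this mirrors the
// scale_val/scale_add logic of flash_attn_stream_k_fixup (flush-to-zero threshold omitted).
static Partial merge(const Partial & a, const Partial & b) {
    const float max_new = std::max(a.max, b.max);
    const float scale_a = std::exp(a.max - max_new);
    const float scale_b = std::exp(b.max - max_new);
    return { scale_a*a.acc + scale_b*b.acc, max_new, scale_a*a.rowsum + scale_b*b.rowsum };
}

// The final output is the merged accumulator divided by the merged row sum,
// matching "*dst = dst_val / rowsum;" at the end of the fixup kernel.
static float finalize(const Partial & p) {
    return p.acc / p.rowsum;
}
```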
ggml/src/ggml-cuda/cp-async.cuh CHANGED
@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
     } else
 #endif // CUDART_VERSION >= 11040
     {
-        asm volatile("cp.async.cg.shared.global.L2 [%0], [%1], 16;"
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
             : : "r"(dst), "l"(src));
     }
 #else
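For context on how a helper like cp_async_cg_16 is driven: callers convert the shared-memory destination with __cvta_generic_to_shared, issue 16-byte copies, and later wait with cp_async_wait_all before reading the tile, as the fattn-mma-f16.cuh changes further down do. A hedged sketch of that calling pattern (the wrapper, its name, and the chunk loop are illustrative; only cp_async_cg_16<preload>, cp_async_wait_all and __cvta_generic_to_shared are taken from this diff):

```cuda
// Illustrative device helper, assuming the cp_async_cg_16<preload>(dst, src)
// and cp_async_wait_all() primitives declared in cp-async.cuh.
template <int preload>
static __device__ void stage_tile_row_async(half2 * tile_dst, const half2 * src, const int n_half2) {
    // cp.async takes the destination as a 32-bit address in the shared window:
    const unsigned int dst_shared = __cvta_generic_to_shared(tile_dst);

    // Each cp_async_cg_16 call copies 16 bytes == 4 half2 values.
    for (int i = 4*threadIdx.x; i < n_half2; i += 4*blockDim.x) {
        cp_async_cg_16<preload>(dst_shared + i*sizeof(half2), src + i);
    }
}

// ... later, before the tile is read:
//     cp_async_wait_all();
//     __syncthreads();
```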
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
-
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
 
     const int bidx0 = blockIdx.x;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    dst += jt*ncols*ne02*D + channel*D;
 
-    // Load the partial result that needs a fixup:
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols] = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
 
-        const float2 tmp = dst_fixup[bidx0*ncols + j];
-        max_val[j] = tmp.x;
-        rowsum[j] = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
-#pragma unroll
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
 
-            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
 
-            // Scale the current and new value accumulators depending on the max. values.
-            const float max_val_new = fmaxf(max_val[j], tmp.x);
 
-            const float diff_val = max_val[j] - max_val_new;
-            const float diff_add = tmp.x - max_val_new;
 
-            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
-            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
-            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
-            rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
 
-            max_val[j] = max_val_new;
-        }
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int tiles_nwaves = (ntiles_total + 2*nsm - 1) / (2*nsm);
-        const int tiles_efficiency_percent = 100 * ntiles_total / (2*nsm*tiles_nwaves);
 
-        const int nblocks_stream_k = 2*nsm;
 
-        const bool use_stream_k = tiles_efficiency_percent < 75 || cc >= GGML_CUDA_CC_ADA_LOVELACE;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
 
-            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
 
         nullptr;
 }
 
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j = blockIdx.y;
+    const int c = blockIdx.z;
+    const int jc = j*ncols2 + c;
+    const int tid = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
 
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
 
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
 
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
             bidx--;
             kbc_stop = kbc;
             continue;
         }
 
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
 
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
 
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
 
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x - max_val_new;
 
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
 
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum = scale_val*rowsum + scale_add*tmp.y;
 
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
     }
 
     // Write back final result:
+    *dst = dst_val / rowsum;
 }
 
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
 }
 
 // parallel_blocks == 0 is stream-k decomposition
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];
 
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
+        const int nblocks_stream_k = max_blocks;
 
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];
 
         }
     }
 
     float scale = 1.0f;
     float max_bias = 0.0f;
     float logit_softcap = 0.0f;
 
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
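To make the stream-k bookkeeping above concrete: each CUDA block owns a contiguous slice of the flattened (head-group, j-tile, k-block) index space, and channel/jt are recovered from the flat index exactly as in the fixup kernel. A small stand-alone C++ sketch of that arithmetic (illustrative host code, not part of the patch):

```cpp
#include <cstdio>

// Mirrors the kbc0 / kbc0_stop arithmetic used by flash_attn_stream_k_fixup
// after this change: the work space has iter_k * iter_j * (ne02/ncols2) units.
struct WorkRange { int begin, end; };

static WorkRange block_range(int bidx, int nblocks, int iter_k, int iter_j, int ne02, int ncols2) {
    const int total = iter_k*iter_j*(ne02/ncols2);
    return { bidx*total / nblocks, (bidx + 1)*total / nblocks };
}

int main() {
    // Example: 2 k-blocks per sequence, 2 j-tiles, 8 heads grouped by ncols2 = 4, 3 CUDA blocks.
    const int iter_k = 2, iter_j = 2, ne02 = 8, ncols2 = 4, nblocks = 3;
    for (int bidx = 0; bidx < nblocks; ++bidx) {
        const WorkRange r = block_range(bidx, nblocks, iter_k, iter_j, ne02, ncols2);
        for (int kbc = r.begin; kbc < r.end; ++kbc) {
            const int channel = kbc / (iter_k*iter_j);                  // K/V head group
            const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // tile index along ne01
            const int kb      = kbc % iter_k;                           // k block within the tile
            std::printf("block %d -> channel %d, jt %d, kb %d\n", bidx, channel, jt, kb);
        }
    }
    return 0;
}
```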
ggml/src/ggml-cuda/fattn-mma-f16.cuh CHANGED
@@ -5,12 +5,15 @@
 
 using namespace ggml_cuda_mma;
 
-typedef tile<16, 8, half2> tile_A;
-typedef tile< 8, 8, half2> tile_B;
-typedef tile<16, 8, float> tile_C_KQ;
-typedef tile<16, 4, half2> tile_C_VKQ;
-
-template<int D, int nwarps, int KQ_stride>
 static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
         const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
     constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
@@ -27,7 +30,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
27
  constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
28
  constexpr int stride_i = WARP_SIZE / chunks_per_row;
29
  #pragma unroll
30
- for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
31
  const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
32
  const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
33
 
@@ -40,7 +43,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
40
 
41
  // If D is not a power of 2, the rest is loaded synchronously.
42
  // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
43
- static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
44
  #pragma unroll
45
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
46
  const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
@@ -52,7 +55,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
52
  }
53
 
54
  #pragma unroll
55
- for (int i0 = 0; i0 < KQ_stride; i0 += nwarps*stride_i) {
56
  const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
57
 
58
  #pragma unroll
@@ -65,12 +68,54 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
65
  }
66
  }
67
 
68
- template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
70
  const float2 * const __restrict__ Q_f2,
71
  const half2 * const __restrict__ K_h2,
72
  const half2 * const __restrict__ V_h2,
73
- const half * const __restrict__ maskh,
74
  float2 * const __restrict__ dstk,
75
  float2 * const __restrict__ dstk_fixup,
76
  const float scale,
@@ -78,42 +123,60 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
78
  const float logit_softcap,
79
  const int ne01,
80
  const int ne02,
81
- const int stride_Q,
82
  const int stride_KV,
83
  const int stride_mask,
84
  const int jt,
85
  half2 * const __restrict__ tile_K,
86
  half2 * const __restrict__ tile_V,
 
87
  const tile_B * const __restrict__ Q_B,
88
  tile_C_VKQ * const __restrict__ VKQ_C,
89
- float2 & KQ_max,
90
- float2 & KQ_rowsum,
91
  const int kb0) {
92
  #ifdef NEW_MMA_AVAILABLE
93
- constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
94
- constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
 
 
 
 
95
 
96
- const int k_VKQ_0 = kb0*KQ_stride;
97
- tile_C_KQ KQ_C[KQ_stride/(np*tile_C_KQ::I)];
 
 
98
 
99
  #ifdef CP_ASYNC_AVAILABLE
100
  cp_async_wait_all();
101
  __syncthreads();
102
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
103
  #else
104
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
 
 
 
105
  __syncthreads();
106
  #endif // CP_ASYNC_AVAILABLE
107
 
108
  // Calculate tile of KQ:
109
  #pragma unroll
110
- for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*tile_A::I) {
111
  const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
112
  #pragma unroll
113
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
114
  tile_A K_A;
115
  load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
116
- mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, ((tile_B *) Q_B)[k_KQ_0/tile_A::J]);
 
 
 
 
 
 
 
 
117
  }
118
  }
119
 
@@ -122,9 +185,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
122
  #endif // CP_ASYNC_AVAILABLE
123
 
124
  if (use_logit_softcap) {
125
- static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
126
  #pragma unroll
127
- for (int i = 0; i < KQ_stride/(np*tile_C_KQ::I); ++i) {
128
  #pragma unroll
129
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
130
  KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
@@ -132,109 +195,209 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
132
  }
133
  }
134
 
135
- if (maskh) {
136
- static_assert(KQ_stride % (np *tile_C_KQ::I) == 0, "bad loop size");
137
- static_assert(ncols % (nwarps/np*tile_C_KQ::J) == 0, "bad loop size");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  #pragma unroll
139
- for (int i00 = 0; i00 < KQ_stride; i00 += np*tile_C_KQ::I) {
140
- const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
 
 
 
 
 
 
 
 
 
141
  #pragma unroll
142
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
143
- const int i = i0 + tile_C_KQ::get_i(l);
144
- const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
145
 
146
- KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
  }
148
  }
149
- }
150
 
151
- // Calculate softmax for each KQ column using the current max. value.
152
- // The divisor is stored in KQ_rowsum and will be applied at the end.
153
- float2 KQ_max_new = KQ_max;
154
- static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
 
155
  #pragma unroll
156
- for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
157
  #pragma unroll
158
- for (int l0 = 0; l0 < tile_C_KQ::ne; l0 += 2) {
159
- KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
160
- KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
 
 
161
  }
162
- }
163
 
164
- // Values per KQ column are spread across 8 threads, does not need full warp reduce:
165
  #pragma unroll
166
- for (int offset = 16; offset > 2; offset >>= 1) {
167
- KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
168
- KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
169
- }
 
 
170
 
171
- float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
172
- static_assert(KQ_stride % (np*tile_C_KQ::I) == 0, "bad loop size");
 
173
  #pragma unroll
174
- for (int k = 0; k < KQ_stride/(np*tile_C_KQ::I); ++k) {
175
  #pragma unroll
176
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
177
- const float KQ_max_l = l % 2 == 0 ? KQ_max_new.x : KQ_max_new.y;
178
- const float diff = KQ_C[k].x[l] - KQ_max_l;
179
- KQ_C[k].x[l] = expf(diff);
180
 
181
- if (l % 2 == 0) {
182
- KQ_rowsum_add.x += KQ_C[k].x[l];
183
- } else {
184
- KQ_rowsum_add.y += KQ_C[k].x[l];
185
  }
186
  }
187
  }
188
 
189
  {
190
- const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
191
- const float2 KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
192
- KQ_max = KQ_max_new;
 
 
193
 
194
- // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
195
- KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
196
- KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
197
 
198
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
 
199
  #pragma unroll
200
- for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
201
  #pragma unroll
202
- for (int l = 0; l < tile_C_VKQ::ne; ++l) {
203
- VKQ_C[i].x[l] *= KQ_max_scale_h2;
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  }
205
  }
206
  }
207
 
208
  // Convert KQ C tiles into B tiles for VKQ calculation:
209
- tile_B B[KQ_stride/(np*2*tile_B::J)];
210
- static_assert(KQ_stride % (np*2*tile_B::J) == 0, "bad loop size");
 
 
211
  #pragma unroll
212
- for (int k = 0; k < KQ_stride/(np*2*tile_B::J); ++k) {
213
- B[k] = get_transposed(get_half2(KQ_C[k]));
 
 
 
 
 
 
 
 
214
  }
215
 
216
  #ifdef CP_ASYNC_AVAILABLE
 
217
  cp_async_wait_all();
218
  __syncthreads();
219
  if (!last_iter) {
220
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + (k_VKQ_0 + KQ_stride)*stride_KV, tile_K, stride_KV);
 
 
 
221
  }
222
  #else
223
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
224
  __syncthreads();
225
  #endif // CP_ASYNC_AVAILABLE
226
 
227
  // Calculate VKQ tile:
228
  #pragma unroll
229
  for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
230
- static_assert((KQ_stride/2) % (np*tile_A::J) == 0, "bad loop size");
231
  #pragma unroll
232
- for (int k00 = 0; k00 < KQ_stride/2; k00 += np*tile_A::J) {
233
  const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
234
 
235
  tile_A A;
236
  load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
237
- mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
 
 
 
 
 
 
 
 
238
  }
239
  }
240
 
@@ -247,12 +410,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
247
  #endif // NEW_MMA_AVAILABLE
248
  }
249
 
250
- template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
251
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
252
  const float2 * const __restrict__ Q_f2,
253
  const half2 * const __restrict__ K_h2,
254
  const half2 * const __restrict__ V_h2,
255
- const half * const __restrict__ maskh,
256
  float2 * const __restrict__ dstk,
257
  float2 * const __restrict__ dstk_fixup,
258
  const float scale,
@@ -260,7 +423,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
260
  const float logit_softcap,
261
  const int ne01,
262
  const int ne02,
263
- const int stride_Q,
 
264
  const int stride_KV,
265
  const int stride_mask,
266
  const int jt,
@@ -269,63 +433,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
269
  #ifdef NEW_MMA_AVAILABLE
270
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
271
 
272
- static_assert(nwarps*tile_B::I % ncols == 0, "bad nwarps");
273
- constexpr int np = nwarps*tile_B::I / ncols; // Number of parallel CUDA warps per Q column.
 
 
 
 
274
 
275
- static_assert(D % nwarps == 0, "bad D");
276
- static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
277
 
278
  constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
279
 
280
- // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements:
281
  extern __shared__ half2 tile_K[];
282
  #ifdef CP_ASYNC_AVAILABLE
283
- half2 * tile_V = tile_K + KQ_stride*D2_padded;
284
  #else
285
- half2 * tile_V = tile_K;
286
  #endif // CP_ASYNC_AVAILABLE
 
287
 
288
- tile_B Q_B[D/(2*tile_B::J)];
289
- tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
290
 
291
- float2 KQ_rowsum = {0.0f, 0.0f};
292
- float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
 
 
 
 
 
 
 
293
 
294
  // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
295
  // The loading is done with decreasing granularity for D for better memory bandwidth.
296
  const half2 scale_h2 = make_half2(scale, scale);
297
  #pragma unroll
298
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
299
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
300
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
301
- const int stride_j = WARP_SIZE / stride_k;
302
 
303
  if (k0_start == k0_stop) {
304
  continue;
305
  }
306
 
307
- if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
308
- break;
309
- }
310
-
311
  #pragma unroll
312
- for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
313
- const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
 
 
 
 
 
 
 
314
 
315
- if (jt*ncols + j < ne01) {
316
  #pragma unroll
317
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
318
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
319
 
320
- const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
321
- tile_K[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
322
  }
323
  } else {
324
  #pragma unroll
325
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
326
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
327
 
328
- tile_K[j*D2_padded + k] = make_half2(0.0f, 0.0f);
329
  }
330
  }
331
  }
@@ -334,128 +513,217 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
334
  __syncthreads();
335
 
336
  {
337
- const int j0 = (threadIdx.y / np) * tile_B::I;
338
 
339
  #pragma unroll
340
  for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
341
- load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
 
 
 
 
 
 
 
 
342
  }
343
  }
344
 
345
  __syncthreads();
346
 
347
- // Preload K data for first iteration when using cp_async:
348
  #ifdef CP_ASYNC_AVAILABLE
349
- flash_attn_ext_f16_load_tile<D, nwarps, KQ_stride>(K_h2 + kb0_start*KQ_stride*stride_KV, tile_K, stride_KV);
 
 
 
350
  #endif // CP_ASYNC_AVAILABLE
351
 
352
  // Iterate over ne11 == previous tokens:
353
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
354
  constexpr bool last_iter = false;
355
- flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
356
- (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
357
- ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
358
  }
359
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
360
  constexpr bool last_iter = true;
361
- flash_attn_ext_f16_iter<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup, last_iter>
362
- (Q_f2, K_h2, V_h2, maskh, dstk, dstk_fixup, scale, slope, logit_softcap,
363
- ne01, ne02, stride_Q, stride_KV, stride_mask, jt, tile_K, tile_V, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
364
  }
365
 
366
  // With cp_async there is no __syncthreads at the end of the iter,
367
  // there can be a race condition on shared memory access for combining/writing back results.
368
  #ifdef CP_ASYNC_AVAILABLE
369
- if (nwarps*tile_B::I > KQ_stride) {
370
  __syncthreads();
371
  }
372
  #endif // CP_ASYNC_AVAILABLE
373
 
374
  // Finally, sum up partial KQ rowsums.
375
- // The partial sums are spread across 8 threads each, does not need full reduce.
 
 
 
 
 
376
  #pragma unroll
377
- for (int offset = 16; offset > 2; offset >>= 1) {
378
- KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
379
- KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
 
380
  }
381
 
382
  // Write VKQ accumulators to shared memory in column-major format.
383
  // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
384
  // Also for np > 1 the combination is done via these values in shared memory.
385
- const int j_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // j combine write data
 
386
  #pragma unroll
387
- for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
388
- const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
389
 
390
  #pragma unroll
391
- for (int l = 0; l < tile_B::ne; ++l) {
392
- const int k = k0 + tile_B::get_j(l);
393
 
394
- tile_K[j_cwd*D2_padded + k] = B.x[l];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  }
396
  }
397
 
398
- const int j_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // j combine write meta offset
399
- const int j_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
400
- const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
 
401
 
402
- if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
403
- // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
404
- ((float2 *) tile_K)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
405
- }
406
 
407
- __syncthreads();
408
 
409
- static_assert(np == 1 || np == 2 || np == 4, "bad np");
410
- if (np == 1) {
411
- // No combination is needed, the meta data can be directly written from registers to VRAM.
412
- if (needs_fixup && threadIdx.x < tile_B::I) {
413
- float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
414
- dstk_fixup_meta[j_cwm] = KQ_cmr;
 
 
 
 
415
  }
416
- if (is_fixup && threadIdx.x < tile_B::I) {
417
- float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
418
- dstk_fixup_meta[j_cwm] = KQ_cmr;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  }
420
- } else if (threadIdx.y % np == 0) {
 
 
 
421
  // Combine the meta data for parallel warps via shared memory.
422
  // Warps with threadIdx.y % np != 0 must NOT return early.
423
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
424
 
425
- float * meta_j = (float *) tile_K + (threadIdx.y*tile_B::I + threadIdx.x)*D2_padded + D/2;
426
 
427
- float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
428
- if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
429
- KQ_cm = meta_j[0];
 
 
 
430
  }
431
 
432
- float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
 
 
 
 
433
  #pragma unroll
434
- for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
 
 
 
435
  KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
436
  }
437
 
438
- const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
439
- float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
440
- if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
441
- KQ_crs = KQ_cms*meta_j[1];
442
  }
 
 
443
  #pragma unroll
444
- for (int offset = np*tile_B::I/2; offset >= tile_B::I; offset >>= 1) {
 
 
 
 
 
 
 
445
  KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
446
  }
447
 
448
  // Write back combined meta data:
449
- if (np*tile_B::I == WARP_SIZE || threadIdx.x < np*tile_B::I) {
450
- *((float2 *) meta_j) = make_float2(KQ_cms, KQ_crs); // Combined KQ max scale + rowsum.
 
 
 
 
451
  }
452
- if (needs_fixup && threadIdx.x < tile_B::I) {
 
 
 
453
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
454
- dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
455
  }
456
- if (is_fixup && threadIdx.x < tile_B::I) {
457
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
458
- dstk_fixup_meta[(threadIdx.y/np)*tile_B::I + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
459
  }
460
  }
461
 
@@ -470,27 +738,32 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
470
 
471
  #pragma unroll
472
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
473
- const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
474
- const int k0_stop = D/2 - (D/2) % (1*stride_k);
475
- const int stride_j = WARP_SIZE / stride_k;
476
 
477
  if (k0_start == k0_stop) {
478
  continue;
479
  }
480
 
481
- if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
482
- break;
483
- }
484
-
485
  #pragma unroll
486
- for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
487
- const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
488
- const int j_tile_K = (j_dst/tile_B::I)*(np*tile_B::I) + j_dst % tile_B::I;
 
 
 
 
 
 
 
 
489
 
490
- if (!is_fixup && jt*ncols + j_dst >= ne01) {
491
  continue;
492
  }
493
- const float * meta_j = (const float *) tile_K + j_tile_K*D2_padded + D/2;
 
494
  #pragma unroll
495
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
496
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -498,8 +771,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
498
  float2 dstk_val = make_float2(0.0f, 0.0f);
499
  #pragma unroll
500
  for (int ip = 0; ip < np; ++ip) {
501
- const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*tile_B::I*D2_padded + 0];
502
- const float2 dstk_val_add = __half22float2(tile_K[(j_tile_K + ip*tile_B::I)*D2_padded + k]);
503
  dstk_val.x += dstk_val_add.x*KQ_crs;
504
  dstk_val.y += dstk_val_add.y*KQ_crs;
505
  }
@@ -511,9 +784,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
511
  }
512
 
513
  if (is_fixup) {
514
- dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
515
  } else {
516
- dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
517
  }
518
  }
519
  }
@@ -528,10 +801,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
528
  #endif // NEW_MMA_AVAILABLE
529
  }
530
 
531
- template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
532
- #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
533
  __launch_bounds__(nwarps*WARP_SIZE, 2)
534
- #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
535
  static __global__ void flash_attn_ext_f16(
536
  const char * __restrict__ Q,
537
  const char * __restrict__ K,
@@ -579,20 +850,23 @@ static __global__ void flash_attn_ext_f16(
579
  return;
580
  }
581
 
582
- static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
583
 
584
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
585
 
586
- const int stride_Q = nb01 / sizeof(float2);
 
587
  const int stride_KV = nb11 / sizeof(half2);
588
- const int stride_mask = nb31 / sizeof(half);
 
 
 
589
 
590
- const int iter_k = ne11 / KQ_stride;
591
- const int iter_j = (ne01 + (ncols - 1)) / ncols;
592
 
593
  // kbc == k block continuous, current index in continuous ijk space.
594
- int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
595
- const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
596
 
597
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
598
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -605,25 +879,28 @@ static __global__ void flash_attn_ext_f16(
605
  const int channel = kbc / (iter_k*iter_j);
606
  const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
607
 
608
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
609
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
610
- const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
611
- const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
612
- float2 * dstk = ((float2 *) dst) + channel*(D/2);
613
 
614
- const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
 
 
 
615
 
616
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
617
  if (kb0_start == 0) {
618
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
619
- flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
620
- (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
621
- ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
622
  } else {
623
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
624
- flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
625
- (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
626
- ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
627
  }
628
 
629
  kbc += iter_k;
@@ -640,39 +917,46 @@ static __global__ void flash_attn_ext_f16(
640
  const int channel = kbc / (iter_k*iter_j);
641
  const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
642
 
643
- const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
644
- const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
645
- const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
646
- const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
647
- float2 * dstk = ((float2 *) dst) + channel*(D/2);
 
 
648
 
649
- const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
 
650
 
651
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
652
  constexpr bool needs_fixup = false;
653
- flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
654
- (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
655
- ne01, ne02, stride_Q, stride_KV, stride_mask, jt, kb0_start, kb0_stop);
656
  }
657
 
658
- template <int D, int cols_per_block>
659
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
660
- typedef tile<16, 8, half2> tile_A;
661
- typedef tile< 8, 8, half2> tile_B;
 
 
 
662
 
663
- static_assert(D % tile_B::J == 0, "bad D");
664
- static_assert(cols_per_block % tile_B::I == 0, "bad cols_per_block");
665
 
666
  const ggml_tensor * KQV = dst;
667
- const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
 
 
668
 
669
- constexpr int KQ_stride = D <= 128 ? 64 : 32;
670
- constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
671
- cols_per_block/tile_B::J * KQ_stride/tile_A::I : (cols_per_block <= 8 ? 4 : 8);
672
 
673
- const int nrows_KQ = cp_async_available(cc) ? 2*KQ_stride : KQ_stride;
674
- const int nrows_combine = nwarps*tile_B::J;
675
- const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
676
 
677
  float logit_softcap;
678
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -680,42 +964,58 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
680
  fattn_kernel_t fattn_kernel;
681
  if (logit_softcap == 0.0f) {
682
  constexpr bool use_logit_softcap = false;
683
- fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
684
  } else {
685
  constexpr bool use_logit_softcap = true;
686
- fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
687
  }
688
- launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
 
689
  }
690
 
691
- #define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
 
692
  template void ggml_cuda_flash_attn_ext_mma_f16_case \
693
- <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
694
-
695
- extern DECL_FATTN_MMA_F16_CASE( 64, 8);
696
- extern DECL_FATTN_MMA_F16_CASE( 80, 8);
697
- extern DECL_FATTN_MMA_F16_CASE( 96, 8);
698
- extern DECL_FATTN_MMA_F16_CASE(112, 8);
699
- extern DECL_FATTN_MMA_F16_CASE(128, 8);
700
- extern DECL_FATTN_MMA_F16_CASE(256, 8);
701
-
702
- extern DECL_FATTN_MMA_F16_CASE( 64, 16);
703
- extern DECL_FATTN_MMA_F16_CASE( 80, 16);
704
- extern DECL_FATTN_MMA_F16_CASE( 96, 16);
705
- extern DECL_FATTN_MMA_F16_CASE(112, 16);
706
- extern DECL_FATTN_MMA_F16_CASE(128, 16);
707
- extern DECL_FATTN_MMA_F16_CASE(256, 16);
708
-
709
- extern DECL_FATTN_MMA_F16_CASE( 64, 32);
710
- extern DECL_FATTN_MMA_F16_CASE( 80, 32);
711
- extern DECL_FATTN_MMA_F16_CASE( 96, 32);
712
- extern DECL_FATTN_MMA_F16_CASE(112, 32);
713
- extern DECL_FATTN_MMA_F16_CASE(128, 32);
714
- extern DECL_FATTN_MMA_F16_CASE(256, 32);
715
-
716
- extern DECL_FATTN_MMA_F16_CASE( 64, 64);
717
- extern DECL_FATTN_MMA_F16_CASE( 80, 64);
718
- extern DECL_FATTN_MMA_F16_CASE( 96, 64);
719
- extern DECL_FATTN_MMA_F16_CASE(112, 64);
720
- extern DECL_FATTN_MMA_F16_CASE(128, 64);
721
- extern DECL_FATTN_MMA_F16_CASE(256, 64);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  using namespace ggml_cuda_mma;
7
 
8
+ typedef tile<16, 8, half2> tile_A;
9
+ typedef tile< 8, 8, half2> tile_B;
10
+ typedef tile<16, 8, half2> tile_B_16;
11
+ typedef tile<16, 8, float> tile_C_KQ;
12
+ typedef tile<16, 16, float> tile_C_KQ_16;
13
+ typedef tile<16, 4, half2> tile_C_VKQ;
14
+ typedef tile<16, 8, half2> tile_C_VKQ_16;
15
+
16
+ template<int D, int nwarps, int KQ_per_iter>
17
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
18
  const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
19
  constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
 
30
  constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
31
  constexpr int stride_i = WARP_SIZE / chunks_per_row;
32
  #pragma unroll
33
+ for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
34
  const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
35
  const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
36
 
 
43
 
44
  // If D is not a power of 2, the rest is loaded synchronously.
45
  // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
46
+ static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
47
  #pragma unroll
48
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
49
  const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
 
55
  }
56
 
57
  #pragma unroll
58
+ for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
59
  const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
60
 
61
  #pragma unroll
 
68
  }
69
  }
70
 
71
+ template<int ncols1, int nwarps, int KQ_per_iter>
72
+ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
73
+ const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
74
+ static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
75
+ #ifdef CP_ASYNC_AVAILABLE
76
+ constexpr int preload = KQ_per_iter * sizeof(half);
77
+ constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
78
+ constexpr int stride_j = nwarps * cols_per_warp;
79
+
80
+ const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
81
+
82
+ #pragma unroll
83
+ for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
84
+ const int j = j0 + threadIdx.y*cols_per_warp +
85
+ (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
86
+
87
+ if (j0 + stride_j > ncols1 && j >= ncols1) {
88
+ break;
89
+ }
90
+
91
+ const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
92
+
93
+ cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
94
+ }
95
+ #else
96
+ constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
97
+ constexpr int stride_j = nwarps * cols_per_warp;
98
+ #pragma unroll
99
+ for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
100
+ const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
101
+
102
+ if (j0 + stride_j > ncols1 && j >= ncols1) {
103
+ break;
104
+ }
105
+
106
+ const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
107
+
108
+ tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
109
+ }
110
+ #endif // CP_ASYNC_AVAILABLE
111
+ }
112
+
113
+ template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
114
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
115
  const float2 * const __restrict__ Q_f2,
116
  const half2 * const __restrict__ K_h2,
117
  const half2 * const __restrict__ V_h2,
118
+ const half2 * const __restrict__ mask_h2,
119
  float2 * const __restrict__ dstk,
120
  float2 * const __restrict__ dstk_fixup,
121
  const float scale,
 
123
  const float logit_softcap,
124
  const int ne01,
125
  const int ne02,
 
126
  const int stride_KV,
127
  const int stride_mask,
128
  const int jt,
129
  half2 * const __restrict__ tile_K,
130
  half2 * const __restrict__ tile_V,
131
+ half2 * const __restrict__ tile_mask,
132
  const tile_B * const __restrict__ Q_B,
133
  tile_C_VKQ * const __restrict__ VKQ_C,
134
+ float * const __restrict__ KQ_max,
135
+ float * const __restrict__ KQ_rowsum,
136
  const int kb0) {
137
  #ifdef NEW_MMA_AVAILABLE
138
+ constexpr int cols_per_warp = ntiles * tile_B::I;
139
+ constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
140
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
141
+ constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
142
+
143
+ const int k_VKQ_0 = kb0 * KQ_per_iter;
144
+ tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
145
 
146
+ // Use wide variants of tiles if ntiles >= 2.
147
+ tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
148
+ tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
149
+ tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
150
 
151
  #ifdef CP_ASYNC_AVAILABLE
152
  cp_async_wait_all();
153
  __syncthreads();
154
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
155
  #else
156
+ if (ncols2 > 1 || mask_h2) {
157
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
158
+ }
159
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
160
  __syncthreads();
161
  #endif // CP_ASYNC_AVAILABLE
162
 
163
  // Calculate tile of KQ:
164
  #pragma unroll
165
+ for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
166
  const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
167
  #pragma unroll
168
  for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
169
  tile_A K_A;
170
  load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
171
+ if (ntiles == 1) {
172
+ mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
173
+ } else {
174
+ #pragma unroll
175
+ for (int t = 0; t < ntiles/2; ++t) {
176
+ // Wide version of KQ_C is column-major => swap A and B.
177
+ mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
178
+ }
179
+ }
180
  }
181
  }
182
 
 
185
  #endif // CP_ASYNC_AVAILABLE
186
 
187
  if (use_logit_softcap) {
188
+ static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
189
  #pragma unroll
190
+ for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
191
  #pragma unroll
192
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
193
  KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
 
195
  }
196
  }
197
 
198
+ float KQ_max_new[cols_per_thread];
199
+ #pragma unroll
200
+ for (int col = 0; col < cols_per_thread; ++col) {
201
+ KQ_max_new[col] = KQ_max[col];
202
+ }
203
+ float KQ_rowsum_add[cols_per_thread] = {0.0f};
204
+
205
+ if (ntiles == 1) {
206
+ if (ncols2 > 1 || mask_h2) {
207
+ #pragma unroll
208
+ for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
209
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
210
+ #pragma unroll
211
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
212
+ const int i = i0 + tile_C_KQ::get_i(l);
213
+ const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
214
+
215
+ KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
216
+ __half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
217
+ }
218
+ }
219
+ }
220
+
221
+ // Calculate softmax for each KQ column using the current max. value.
222
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
223
+ static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
224
+ #pragma unroll
225
+ for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
226
+ #pragma unroll
227
+ for (int l = 0; l < tile_C_KQ::ne; ++l) {
228
+ KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
229
+ }
230
+ }
231
+
232
+ // Values per KQ column are spread across 8 threads, does not need full warp reduce:
233
  #pragma unroll
234
+ for (int col = 0; col < cols_per_thread; ++col) {
235
+ #pragma unroll
236
+ for (int offset = 16; offset >= 4; offset >>= 1) {
237
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
238
+ }
239
+ }
240
+
241
+ static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
242
+
243
+ #pragma unroll
244
+ for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
245
  #pragma unroll
246
  for (int l = 0; l < tile_C_KQ::ne; ++l) {
247
+ KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
 
248
 
249
+ KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
250
+ }
251
+ }
252
+ } else { // ntiles > 1
253
+ if (ncols2 > 1 || mask_h2) {
254
+ #pragma unroll
255
+ for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
256
+ const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
257
+ #pragma unroll
258
+ for (int t = 0; t < ntiles/2; ++t) {
259
+ #pragma unroll
260
+ for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
261
+ const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
262
+ const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
263
+
264
+ const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
265
+ const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
266
+ KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
267
+ KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
268
+ }
269
+ }
270
  }
271
  }
 
272
 
273
+ // Calculate softmax for each KQ column using the current max. value.
274
+ // The divisor is stored in KQ_rowsum and will be applied at the end.
275
+ static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
276
+ #pragma unroll
277
+ for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
278
  #pragma unroll
279
+ for (int t = 0; t < ntiles/2; ++t) {
280
  #pragma unroll
281
+ for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
282
+ const int KQ_index = 2*t + (l/2) % 2;
283
+ KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
284
+ }
285
+ }
286
  }
 
287
 
288
+ // Values per KQ column are spread across 4 threads, does not need full warp reduce:
289
  #pragma unroll
290
+ for (int col = 0; col < cols_per_thread; ++col) {
291
+ #pragma unroll
292
+ for (int offset = 2; offset >= 1; offset >>= 1) {
293
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
294
+ }
295
+ }
296
 
297
+ static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
298
+ #pragma unroll
299
+ for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
300
  #pragma unroll
301
+ for (int t = 0; t < ntiles/2; ++t) {
302
  #pragma unroll
303
+ for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
304
+ const int KQ_index = 2*t + (l/2) % 2;
 
 
305
 
306
+ KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
307
+
308
+ KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
309
+ }
310
  }
311
  }
312
  }
313
 
314
  {
315
+ float KQ_max_scale[cols_per_thread];
316
+ #pragma unroll
317
+ for (int col = 0; col < cols_per_thread; ++col) {
318
+ KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
319
+ KQ_max[col] = KQ_max_new[col];
320
 
321
+ // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
322
+ KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
323
+ }
324
 
325
+ if (ntiles == 1) {
326
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
327
  #pragma unroll
328
+ for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
329
  #pragma unroll
330
+ for (int l = 0; l < tile_C_VKQ::ne; ++l) {
331
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
332
+ }
333
+ }
334
+ } else {
335
+ #pragma unroll
336
+ for (int col = 0; col < cols_per_thread; ++col) {
337
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
338
+ #pragma unroll
339
+ for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
340
+ #pragma unroll
341
+ for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
342
+ VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
343
+ }
344
+ }
345
  }
346
  }
347
  }
348
 
349
  // Convert KQ C tiles into B tiles for VKQ calculation:
350
+ tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
351
+ tile_B_16 * B_16 = (tile_B_16 *) B;
352
+ static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
353
+ if (ntiles == 1) {
354
  #pragma unroll
355
+ for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
356
+ B[k] = get_transposed(get_half2(KQ_C[k]));
357
+ }
358
+ } else {
359
+ for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
360
+ #pragma unroll
361
+ for (int t = 0; t < ntiles/2; ++t) {
362
+ B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
363
+ }
364
+ }
365
  }
366
 
367
  #ifdef CP_ASYNC_AVAILABLE
368
+ // Preload K tile for next iteration:
369
  cp_async_wait_all();
370
  __syncthreads();
371
  if (!last_iter) {
372
+ if (ncols2 > 1 || mask_h2) {
373
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
374
+ }
375
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
376
  }
377
  #else
378
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
379
  __syncthreads();
380
  #endif // CP_ASYNC_AVAILABLE
381
 
382
  // Calculate VKQ tile:
383
  #pragma unroll
384
  for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
385
+ static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
386
  #pragma unroll
387
+ for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
388
  const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
389
 
390
  tile_A A;
391
  load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
392
+ if (ntiles == 1) {
393
+ mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
394
+ } else {
395
+ #pragma unroll
396
+ for (int t = 0; t < ntiles/2; ++t) {
397
+ // Wide version of VKQ_C is column-major => swap A and B.
398
+ mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
399
+ }
400
+ }
401
  }
402
  }
403
 
 
410
  #endif // NEW_MMA_AVAILABLE
411
  }
412
 
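With CP_ASYNC_AVAILABLE the loop above is software-pipelined: while the current K/V tile is consumed, the K tile and mask for the next iteration are already being copied asynchronously into a second shared-memory buffer, and the wait/__syncthreads only happens right before that data is needed. A toy host-side C++ version of the same ping-pong pattern (a plain memcpy stands in for cp.async, the data is made up):

    #include <cstdio>
    #include <cstring>

    int main() {
        const int n_tiles = 4, tile_len = 8;
        float global_data[n_tiles][tile_len];
        for (int t = 0; t < n_tiles; ++t)
            for (int i = 0; i < tile_len; ++i)
                global_data[t][i] = t + 0.1f*i; // arbitrary input

        float buf[2][tile_len];
        std::memcpy(buf[0], global_data[0], sizeof(buf[0])); // preload the first tile

        float acc = 0.0f;
        for (int t = 0; t < n_tiles; ++t) {
            const int cur = t % 2, nxt = (t + 1) % 2;
            if (t + 1 < n_tiles) {
                std::memcpy(buf[nxt], global_data[t + 1], sizeof(buf[nxt])); // "prefetch" the next tile
            }
            for (int i = 0; i < tile_len; ++i) {
                acc += buf[cur][i]; // work on the current tile
            }
        }
        printf("sum over all tiles: %f\n", acc);
        return 0;
    }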
413
+ template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
414
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
415
  const float2 * const __restrict__ Q_f2,
416
  const half2 * const __restrict__ K_h2,
417
  const half2 * const __restrict__ V_h2,
418
+ const half2 * const __restrict__ mask_h2,
419
  float2 * const __restrict__ dstk,
420
  float2 * const __restrict__ dstk_fixup,
421
  const float scale,
 
423
  const float logit_softcap,
424
  const int ne01,
425
  const int ne02,
426
+ const int stride_Q1,
427
+ const int stride_Q2,
428
  const int stride_KV,
429
  const int stride_mask,
430
  const int jt,
 
433
  #ifdef NEW_MMA_AVAILABLE
434
  // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
435
 
436
+ constexpr int ncols = ncols1 * ncols2;
437
+ constexpr int cols_per_warp = ntiles * tile_B::I;
438
+ constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
439
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
440
+
441
+ static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
442
 
443
+ static_assert(D % nwarps == 0, "bad D");
444
+ static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
445
 
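These constants decide how the ncols = ncols1*ncols2 columns of a tile are split across warps. As a concrete check (assuming tile_B::I == 8, i.e. an 8-row B tile, and the nwarps/ntiles choices made in ggml_cuda_flash_attn_ext_mma_f16_case further down): D = 128, ncols1 = 8, ncols2 = 8 gives ncols = 64, ntiles = 2, cols_per_warp = 16 and np = 4*(16/8)/8 = 1, so each warp owns its columns outright; with ncols2 = 1 and ncols1 = 8 the same formula gives np = 4*(8/1)/8 = 4 parallel warps per Q column, which is what the combination code at the end of this function handles.

    // Hypothetical stand-alone recomputation of the partitioning above
    // (the literal 8 stands in for tile_B::I).
    constexpr int cols_per_warp_of(int ntiles) { return ntiles*8; }
    constexpr int np_of(int nwarps, int ntiles, int ncols1, int ncols2) {
        return nwarps*(cols_per_warp_of(ntiles)/ncols2)/ncols1;
    }
    static_assert(np_of(4, 2, 8, 8) == 1, "one warp per Q column");
    static_assert(np_of(4, 1, 8, 1) == 4, "four parallel warps per Q column");
    int main() { return 0; }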
446
  constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
447
 
448
+ // Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
449
  extern __shared__ half2 tile_K[];
450
  #ifdef CP_ASYNC_AVAILABLE
451
+ half2 * tile_V = tile_K + KQ_per_iter*D2_padded;
452
  #else
453
+ half2 * tile_V = tile_K;
454
  #endif // CP_ASYNC_AVAILABLE
455
+ half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
456
 
457
+ tile_B Q_B[D/(2*tile_B::J) * ntiles];
458
+ tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
459
 
460
+ tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
461
+ tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
462
+
463
+ float KQ_rowsum[cols_per_thread] = {0.0f};
464
+ float KQ_max[cols_per_thread];
465
+ #pragma unroll
466
+ for (int col = 0; col < cols_per_thread; ++col) {
467
+ KQ_max[col] = -FLT_MAX/2.0f;
468
+ }
469
 
470
  // Temporarily load Q data into tile_K, will be loaded into registers afterwards.
471
  // The loading is done with decreasing granularity for D for better memory bandwidth.
472
  const half2 scale_h2 = make_half2(scale, scale);
473
  #pragma unroll
474
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
475
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
476
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
477
+ const int stride_jc = WARP_SIZE / stride_k;
478
 
479
  if (k0_start == k0_stop) {
480
  continue;
481
  }
482
 
 
 
 
 
483
  #pragma unroll
484
+ for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
485
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
486
+
487
+ if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
488
+ break;
489
+ }
490
+
491
+ const int j = jc / ncols2;
492
+ const int c = jc % ncols2;
493
 
494
+ if (jt*ncols1 + j < ne01) {
495
  #pragma unroll
496
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
497
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
498
 
499
+ const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
500
+ tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
501
  }
502
  } else {
503
  #pragma unroll
504
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
505
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
506
 
507
+ tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
508
  }
509
  }
510
  }
 
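The k0_start/k0_stop arithmetic above walks the D/2 half2 elements of a row with progressively smaller strides so that head sizes that are not a multiple of 64 are still covered by coalesced accesses. For D = 80 (D/2 = 40), for example, stride 32 covers k = 0..31, stride 16 is skipped because k0_start == k0_stop == 32, and stride 8 covers k = 32..39. A small stand-alone check of that arithmetic (WARP_SIZE assumed to be 32):

    #include <cstdio>
    #include <initializer_list>

    int main() {
        const int WARP_SIZE = 32;
        const int D = 80;            // example head size
        bool covered[D/2] = {false};

        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
            if (k0_start == k0_stop) {
                printf("stride %2d skipped\n", stride_k);
                continue;
            }
            printf("stride %2d covers k = [%2d, %2d)\n", stride_k, k0_start, k0_stop);
            for (int k = k0_start; k < k0_stop; ++k) {
                covered[k] = true;
            }
        }

        for (int k = 0; k < D/2; ++k) {
            if (!covered[k]) {
                printf("gap at k = %d\n", k);
                return 1;
            }
        }
        printf("all %d half2 elements of the row are covered\n", D/2);
        return 0;
    }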
513
  __syncthreads();
514
 
515
  {
516
+ const int j0 = (threadIdx.y / np) * cols_per_warp;
517
 
518
  #pragma unroll
519
  for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
520
+ if (ntiles == 1) {
521
+ load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
522
+ } else {
523
+ #pragma unroll
524
+ for (int t = 0; t < ntiles/2; ++t) {
525
+ load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
526
+ tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
527
+ }
528
+ }
529
  }
530
  }
531
 
532
  __syncthreads();
533
 
534
+ // Preload mask and K data for first iteration when using cp_async:
535
  #ifdef CP_ASYNC_AVAILABLE
536
+ if (ncols2 > 1 || mask_h2) {
537
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
538
+ }
539
+ flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
540
  #endif // CP_ASYNC_AVAILABLE
541
 
542
  // Iterate over ne11 == previous tokens:
543
  for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
544
  constexpr bool last_iter = false;
545
+ flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
546
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
547
+ ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
548
  }
549
  { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
550
  constexpr bool last_iter = true;
551
+ flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
552
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
553
+ ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
554
  }
555
 
556
  // With cp_async there is no __syncthreads at the end of the iter,
557
  // so there can be a race condition on shared memory access for combining/writing back results.
558
  #ifdef CP_ASYNC_AVAILABLE
559
+ if (nwarps*cols_per_warp > KQ_per_iter) {
560
  __syncthreads();
561
  }
562
  #endif // CP_ASYNC_AVAILABLE
563
 
564
  // Finally, sum up partial KQ rowsums.
565
+ // The partial sums are spread across 8 (ntiles == 1) or 4 threads each, so a full warp reduce is not needed.
566
+ {
567
+ constexpr int offset_first = ntiles == 1 ? 16 : 2;
568
+ constexpr int offset_last = ntiles == 1 ? 4 : 1;
569
+ #pragma unroll
570
+ for (int col = 0; col < cols_per_thread; ++col) {
571
  #pragma unroll
572
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
573
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
574
+ }
575
+ }
576
  }
577
 
578
  // Write VKQ accumulators to shared memory in column-major format.
579
  // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
580
  // Also for np > 1 the combination is done via these values in shared memory.
581
+ if (ntiles == 1) {
582
+ const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
583
  #pragma unroll
584
+ for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
585
+ const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
586
 
587
  #pragma unroll
588
+ for (int l = 0; l < tile_B::ne; ++l) {
589
+ const int k = k0 + tile_B::get_j(l);
590
 
591
+ tile_K[jc_cwd*D2_padded + k] = B.x[l];
592
+ }
593
+ }
594
+ } else {
595
+ #pragma unroll
596
+ for (int t = 0; t < ntiles/2; ++t) {
597
+ const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
598
+ #pragma unroll
599
+ for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
600
+ #pragma unroll
601
+ for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
602
+ const int j = j0 + tile_C_VKQ_16::get_i(l);
603
+ const int k = k0 + tile_C_VKQ_16::get_j(l);
604
+
605
+ tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
606
+ }
607
+ }
608
  }
609
  }
610
 
611
+ if constexpr (ntiles == 1) {
612
+ const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
613
+ const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
614
+ const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
615
 
616
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
617
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
618
+ ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
619
+ }
620
 
621
+ __syncthreads();
622
 
623
+ if (np == 1) {
624
+ // No combination is needed, the meta data can be directly written from registers to VRAM.
625
+ if (needs_fixup && threadIdx.x < tile_B::I) {
626
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
627
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
628
+ }
629
+ if (is_fixup && threadIdx.x < tile_B::I) {
630
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
631
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
632
+ }
633
  }
634
+ } else {
635
+ static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
636
+ const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
637
+ + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
638
+ + tile_C_VKQ_16::get_i(threadIdx.x % 4);
639
+ const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
640
+
641
+ if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
642
+ // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
643
+ ((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
644
+ }
645
+
646
+ __syncthreads();
647
+
648
+ if (np == 1) {
649
+ // No combination is needed, the meta data can be directly written from registers to VRAM.
650
+ if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
651
+ float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
652
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
653
+ }
654
+ if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
655
+ float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
656
+ dstk_fixup_meta[jc_cwm] = KQ_cmr;
657
+ }
658
  }
659
+ }
660
+
661
+ static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
662
+ if (np > 1 && threadIdx.y % np == 0) {
663
  // Combine the meta data for parallel warps via shared memory.
664
  // Warps with threadIdx.y % np != 0 must NOT return early.
665
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
666
 
667
+ constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
668
 
669
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
670
+ float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
671
+ float2 meta[nmeta];
672
+ #pragma unroll
673
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
674
+ meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
675
  }
676
 
677
+ float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
678
+ #pragma unroll
679
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
680
+ KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
681
+ }
682
  #pragma unroll
683
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
684
+ if (offset >= WARP_SIZE) {
685
+ continue;
686
+ }
687
  KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
688
  }
689
 
690
+ float KQ_cms[nmeta]; // KQ combine max scale per warp.
691
+ #pragma unroll
692
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
693
+ KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
694
  }
695
+
696
+ float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
697
  #pragma unroll
698
+ for (int imeta = 1; imeta < nmeta; ++imeta) {
699
+ KQ_crs += KQ_cms[imeta]*meta[imeta].y;
700
+ }
701
+ #pragma unroll
702
+ for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
703
+ if (offset >= WARP_SIZE) {
704
+ continue;
705
+ }
706
  KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
707
  }
708
 
709
  // Write back combined meta data:
710
+ #pragma unroll
711
+ for (int imeta = 0; imeta < nmeta; ++imeta) {
712
+ if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
713
+ // Combined KQ max scale + rowsum.
714
+ meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
715
+ }
716
  }
717
+
718
+ // Combined KQ max + rowsum.
719
+ static_assert(cols_per_warp <= WARP_SIZE);
720
+ if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
721
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
722
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
723
  }
724
+ if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
725
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
726
+ dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
727
  }
728
  }
729
 
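When np > 1, several warps hold partial softmax statistics for the same Q column and the block above merges them via the per-row meta data: the combined maximum is the maximum over the per-warp maxima, each per-warp rowsum is rescaled by exp(max_warp - max_combined), and the rescaled sums are added. In isolation the merge rule is just (example numbers are arbitrary):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    int main() {
        const float meta[2][2] = {{1.5f, 3.0f}, {2.5f, 4.0f}}; // per-warp {KQ max, KQ rowsum}
        const int   np         = 2;

        float KQ_cmn = meta[0][0];                   // combined max
        for (int ip = 1; ip < np; ++ip) {
            KQ_cmn = std::max(KQ_cmn, meta[ip][0]);
        }

        float KQ_crs = 0.0f;                         // combined, rescaled rowsum
        for (int ip = 0; ip < np; ++ip) {
            const float KQ_cms = std::exp(meta[ip][0] - KQ_cmn); // per-warp scale
            KQ_crs += KQ_cms*meta[ip][1];
        }

        printf("combined max = %f, combined rowsum = %f\n", KQ_cmn, KQ_crs);
        return 0;
    }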
 
738
 
739
  #pragma unroll
740
  for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
741
+ const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
742
+ const int k0_stop = D/2 - (D/2) % (1*stride_k);
743
+ const int stride_jc = WARP_SIZE / stride_k;
744
 
745
  if (k0_start == k0_stop) {
746
  continue;
747
  }
748
 
 
 
 
 
749
  #pragma unroll
750
+ for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
751
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
752
+
753
+ if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
754
+ break;
755
+ }
756
+
757
+ const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
758
+
759
+ const int j_dst = jc_dst / ncols2;
760
+ const int c_dst = jc_dst % ncols2;
761
 
762
+ if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
763
  continue;
764
  }
765
+
766
+ const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
767
  #pragma unroll
768
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
769
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
 
771
  float2 dstk_val = make_float2(0.0f, 0.0f);
772
  #pragma unroll
773
  for (int ip = 0; ip < np; ++ip) {
774
+ const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
775
+ const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
776
  dstk_val.x += dstk_val_add.x*KQ_crs;
777
  dstk_val.y += dstk_val_add.y*KQ_crs;
778
  }
 
784
  }
785
 
786
  if (is_fixup) {
787
+ dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
788
  } else {
789
+ dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
790
  }
791
  }
792
  }
 
801
  #endif // NEW_MMA_AVAILABLE
802
  }
803
 
804
+ template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
 
805
  __launch_bounds__(nwarps*WARP_SIZE, 2)
 
806
  static __global__ void flash_attn_ext_f16(
807
  const char * __restrict__ Q,
808
  const char * __restrict__ K,
 
850
  return;
851
  }
852
 
853
+ static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
854
 
855
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
856
 
857
+ const int stride_Q1 = nb01 / sizeof(float2);
858
+ const int stride_Q2 = nb02 / sizeof(float2);
859
  const int stride_KV = nb11 / sizeof(half2);
860
+ const int stride_mask = nb31 / sizeof(half2);
861
+
862
+ const int iter_k = ne11 / FATTN_KQ_STRIDE;
863
+ const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
864
 
865
+ constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
 
866
 
867
  // kbc == k block continuous, current index in continuous ijk space.
868
+ int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
869
+ const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
870
 
871
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
872
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
 
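The work split is a flat index space of size iter_k*iter_j*(ne02/ncols2): each CUDA block takes the contiguous slice [kbc, kbc_stop), so an output tile whose K range is cut by a slice boundary is later reconciled by flash_attn_stream_k_fixup. A stand-alone sketch of that partitioning (the sizes are arbitrary examples):

    #include <cstdio>

    int main() {
        const int iter_k   = 10; // K blocks per output tile
        const int iter_j   = 3;  // Q tiles along ne01
        const int nchannel = 4;  // ne02/ncols2
        const int nblocks  = 7;  // gridDim.x

        const int total = iter_k*iter_j*nchannel;
        for (int b = 0; b < nblocks; ++b) {
            const int kbc      = (b + 0)*total / nblocks;
            const int kbc_stop = (b + 1)*total / nblocks;
            // Tile boundaries are the multiples of iter_k; a block that does not
            // start on one only finishes a tile that another block began.
            printf("block %d: [%3d, %3d) starts mid-tile: %s\n",
                   b, kbc, kbc_stop, kbc % iter_k != 0 ? "yes" : "no");
        }
        return 0;
    }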
879
  const int channel = kbc / (iter_k*iter_j);
880
  const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
881
 
882
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
883
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
884
+ const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
885
+ const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
886
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
887
 
888
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
889
+
890
+ const int kb0_start_kernel = kb0_start * kb_niter;
891
+ const int kb0_stop_kernel = kb0_stop * kb_niter;
892
 
893
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
894
  if (kb0_start == 0) {
895
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
896
+ flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
897
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
898
+ ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
899
  } else {
900
  constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
901
+ flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
902
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
903
+ ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
904
  }
905
 
906
  kbc += iter_k;
 
917
  const int channel = kbc / (iter_k*iter_j);
918
  const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
919
 
920
+ const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
921
+ const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
922
+ const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
923
+ const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
924
+ float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
925
+
926
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
927
 
928
+ const int kb0_start_kernel = kb0_start * kb_niter;
929
+ const int kb0_stop_kernel = kb0_stop * kb_niter;
930
 
931
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
932
  constexpr bool needs_fixup = false;
933
+ flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
934
+ (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
935
+ ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
936
  }
937
 
938
+ template <int D, int ncols1, int ncols2>
939
  void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
940
+ constexpr int ncols = ncols1 * ncols2;
941
+ constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32;
942
+ constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
943
+ constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
944
+ constexpr int cols_per_warp = ntiles * tile_B::I;
945
 
946
+ static_assert(D % tile_B::J == 0, "bad D");
947
+ static_assert(ncols % cols_per_warp == 0, "bad ncols");
948
 
949
  const ggml_tensor * KQV = dst;
950
+ const int id = ggml_cuda_get_device();
951
+ const int cc = ggml_cuda_info().devices[id].cc;
952
+
953
+ const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
954
 
955
+ const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half);
956
+ const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half);
957
+ const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half);
958
 
959
+ const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
 
 
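As a concrete instance of this budget (illustrative numbers; cols_per_warp = 16 assumes tile_B::I == 8): for D = 128, ncols1 = 8, ncols2 = 8 the constants above give KQ_per_iter = 64 and nwarps = 4; with cp.async the K/V tile is double buffered, so nbytes_shared_KV = 128*(128 + 8)*2 = 34816 bytes, nbytes_shared_mask = 8*(64 + 8)*2 = 1152 bytes and nbytes_shared_combine = 4*16*(128 + 8)*2 = 17408 bytes, for a total of roughly 35 KiB:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t D = 128, KQ_per_iter = 64, nwarps = 4, ncols1 = 8, cols_per_warp = 16;
        const size_t sizeof_half = 2;

        const size_t KV      = 2*KQ_per_iter        * (D + 8)           * sizeof_half; // double buffered
        const size_t mask    = ncols1               * (KQ_per_iter + 8) * sizeof_half;
        const size_t combine = nwarps*cols_per_warp * (D + 8)           * sizeof_half;

        printf("KV = %zu, mask = %zu, combine = %zu, total = %zu bytes\n",
               KV, mask, combine, std::max(KV + mask, combine));
        return 0;
    }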
960
 
961
  float logit_softcap;
962
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
964
  fattn_kernel_t fattn_kernel;
965
  if (logit_softcap == 0.0f) {
966
  constexpr bool use_logit_softcap = false;
967
+ fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
968
  } else {
969
  constexpr bool use_logit_softcap = true;
970
+ fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
971
  }
972
+
973
+ launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
974
  }
975
 
976
+
977
+ #define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \
978
  template void ggml_cuda_flash_attn_ext_mma_f16_case \
979
+ <D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
980
+
981
+ #define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
982
+ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
983
+ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
984
+ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
985
+ extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
986
+
987
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
988
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
989
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
990
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
991
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
992
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
993
+
994
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
995
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
996
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
997
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
998
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
999
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
1000
+
1001
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
1002
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
1003
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
1004
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
1005
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
1006
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
1007
+
1008
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
1009
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
1010
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
1011
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
1012
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
1013
+ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
1014
+
1015
+ // Kernels with ncols == 128 are only 4% faster due to register pressure.
1016
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
1017
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
1018
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
1019
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
1020
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
1021
+ // DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
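
For reference, a single DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(64, 8) above expands to the following extern template declarations, one per ncols2, which the per-configuration files under template-instances/ then instantiate:

    extern template void ggml_cuda_flash_attn_ext_mma_f16_case< 64, 8, 1>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case< 64, 4, 2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case< 64, 2, 4>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
    extern template void ggml_cuda_flash_attn_ext_mma_f16_case< 64, 1, 8>(ggml_backend_cuda_context & ctx, ggml_tensor * dst);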
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
302
  constexpr int nwarps = 8;
303
  constexpr size_t nbytes_shared = 0;
304
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
305
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
306
  } break;
307
  case 128: {
308
  constexpr int D = 128;
309
  constexpr int nwarps = 8;
310
  constexpr size_t nbytes_shared = 0;
311
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
312
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
313
  } break;
314
  default: {
315
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
302
  constexpr int nwarps = 8;
303
  constexpr size_t nbytes_shared = 0;
304
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
305
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
306
  } break;
307
  case 128: {
308
  constexpr int D = 128;
309
  constexpr int nwarps = 8;
310
  constexpr size_t nbytes_shared = 0;
311
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
312
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
313
  } break;
314
  default: {
315
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
296
  constexpr int nwarps = 8;
297
  constexpr size_t nbytes_shared = 0;
298
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
299
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
300
  } break;
301
  case 128: {
302
  constexpr int D = 128;
303
  constexpr int nwarps = 8;
304
  constexpr size_t nbytes_shared = 0;
305
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
306
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
307
  } break;
308
  default: {
309
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
 
296
  constexpr int nwarps = 8;
297
  constexpr size_t nbytes_shared = 0;
298
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
299
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
300
  } break;
301
  case 128: {
302
  constexpr int D = 128;
303
  constexpr int nwarps = 8;
304
  constexpr size_t nbytes_shared = 0;
305
  fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
306
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
307
  } break;
308
  default: {
309
  GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
310
  constexpr bool need_f16_K = D != 128;
311
  constexpr bool need_f16_V = D != 128 && D != 64;
312
  constexpr size_t nbytes_shared = 0;
313
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
314
  }
315
 
316
  template <int D, ggml_type type_K, ggml_type type_V>
 
310
  constexpr bool need_f16_K = D != 128;
311
  constexpr bool need_f16_V = D != 128 && D != 64;
312
  constexpr size_t nbytes_shared = 0;
313
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
314
  }
315
 
316
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
290
  constexpr bool need_f16_K = D != 128;
291
  constexpr bool need_f16_V = D != 128 && D != 64;
292
  constexpr size_t nbytes_shared = 0;
293
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
294
  }
295
 
296
  template <int D, ggml_type type_K, ggml_type type_V>
 
290
  constexpr bool need_f16_K = D != 128;
291
  constexpr bool need_f16_V = D != 128 && D != 64;
292
  constexpr size_t nbytes_shared = 0;
293
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
294
  }
295
 
296
  template <int D, ggml_type type_K, ggml_type type_V>
ggml/src/ggml-cuda/fattn-wmma-f16.cu CHANGED
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
478
  fattn_kernel = flash_attn_ext_f16<
479
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
480
  }
481
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
482
  return;
483
  }
484
  if (2*blocks_num_pb1 < 2*nsm) {
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
493
  fattn_kernel = flash_attn_ext_f16<
494
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
495
  }
496
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
497
  return;
498
  }
499
  constexpr int parallel_blocks = 1;
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
507
  fattn_kernel = flash_attn_ext_f16<
508
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
509
  }
510
- launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
511
  }
512
 
513
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
478
  fattn_kernel = flash_attn_ext_f16<
479
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
480
  }
481
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
482
  return;
483
  }
484
  if (2*blocks_num_pb1 < 2*nsm) {
 
493
  fattn_kernel = flash_attn_ext_f16<
494
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
495
  }
496
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
497
  return;
498
  }
499
  constexpr int parallel_blocks = 1;
 
507
  fattn_kernel = flash_attn_ext_f16<
508
  D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
509
  }
510
+ launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
511
  }
512
 
513
  void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml/src/ggml-cuda/fattn.cu CHANGED
@@ -8,28 +8,50 @@
8
  #include "fattn-wmma-f16.cuh"
9
  #include "fattn.cuh"
10
 
11
- template <int cols_per_block>
12
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
  const ggml_tensor * Q = dst->src[0];
14
 
15
  switch (Q->ne[0]) {
16
  case 64:
17
- ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
18
  break;
19
  case 80:
20
- ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
21
  break;
22
  case 96:
23
- ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
24
  break;
25
  case 112:
26
- ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
27
  break;
28
  case 128:
29
- ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
30
  break;
31
  case 256:
32
- ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
33
  break;
34
  default:
35
  GGML_ABORT("fatal error");
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
38
  }
39
 
40
  static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
41
- const ggml_tensor * Q = dst->src[0];
42
 
43
- if (Q->ne[1] <= 8) {
44
  ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
45
  return;
46
  }
47
 
48
- if (Q->ne[1] <= 16) {
49
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
50
  return;
51
  }
52
 
53
- if (Q->ne[1] <= 32) {
54
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
55
  return;
56
  }
57
 
58
- ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
59
  }
60
 
61
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
209
  }
210
 
211
  void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
212
- const ggml_tensor * KQV = dst;
213
- const ggml_tensor * Q = dst->src[0];
 
 
 
214
 
215
  ggml_cuda_set_device(ctx.device);
216
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
252
  return;
253
  }
254
 
255
- if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0) {
 
 
 
256
  if (prec == GGML_PREC_DEFAULT) {
257
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
258
  return;
 
8
  #include "fattn-wmma-f16.cuh"
9
  #include "fattn.cuh"
10
 
11
+ template <int D, int ncols2>
12
+ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
13
+ const ggml_tensor * Q = dst->src[0];
14
+
15
+ if (Q->ne[1] <= 8/ncols2) {
16
+ ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
17
+ return;
18
+ }
19
+
20
+ if (Q->ne[1] <= 16/ncols2) {
21
+ ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
22
+ return;
23
+ }
24
+
25
+ if (Q->ne[1] <= 32/ncols2) {
26
+ ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
27
+ return;
28
+ }
29
+
30
+ ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
31
+ }
32
+
33
+ template <int ncols2>
34
  static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
35
  const ggml_tensor * Q = dst->src[0];
36
 
37
  switch (Q->ne[0]) {
38
  case 64:
39
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
40
  break;
41
  case 80:
42
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
43
  break;
44
  case 96:
45
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
46
  break;
47
  case 112:
48
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
49
  break;
50
  case 128:
51
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
52
  break;
53
  case 256:
54
+ ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
55
  break;
56
  default:
57
  GGML_ABORT("fatal error");
 
60
  }
61
 
62
  static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
63
+ const ggml_tensor * KQV = dst;
64
+ const ggml_tensor * Q = dst->src[0];
65
+ const ggml_tensor * K = dst->src[1];
66
+ const ggml_tensor * mask = dst->src[3];
67
+
68
+ float max_bias = 0.0f;
69
+ memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
70
+
71
+ const bool use_gqa_opt = mask && max_bias == 0.0f;
72
+
73
+ GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
74
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
75
 
76
+ if (use_gqa_opt && gqa_ratio % 8 == 0) {
77
  ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
78
  return;
79
  }
80
 
81
+ if (use_gqa_opt && gqa_ratio == 4) {
82
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
83
  return;
84
  }
85
 
86
+ if (use_gqa_opt && gqa_ratio == 2) {
87
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
88
  return;
89
  }
90
 
91
+ ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
92
  }
93
 
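The dispatch above packs ncols2 Q heads that share one K/V head into the same kernel tile whenever a mask is present and ALiBi is disabled. For example, a hypothetical model with 32 Q heads and 4 K/V heads has gqa_ratio = 8, so ncols2 = 8; with Q->ne[1] = 4 tokens the ncols1 switch then picks 32/8 = 4, i.e. a tile of 4 tokens x 8 heads = 32 columns. A stand-alone sketch of that selection logic:

    #include <cstdio>
    #include <initializer_list>

    int main() {
        const int  n_head      = 32;   // Q->ne[2], example value
        const int  n_head_kv   = 4;    // K->ne[2], example value
        const int  n_tokens    = 4;    // Q->ne[1], example value
        const bool use_gqa_opt = true; // mask present and max_bias == 0.0f

        const int gqa_ratio = n_head / n_head_kv;

        int ncols2 = 1;
        if      (use_gqa_opt && gqa_ratio % 8 == 0) ncols2 = 8;
        else if (use_gqa_opt && gqa_ratio == 4)     ncols2 = 4;
        else if (use_gqa_opt && gqa_ratio == 2)     ncols2 = 2;

        int ncols1 = 64/ncols2;
        for (int ncols : {8, 16, 32}) {
            if (n_tokens <= ncols/ncols2) {
                ncols1 = ncols/ncols2;
                break;
            }
        }

        printf("gqa_ratio = %d -> ncols2 = %d, ncols1 = %d (%d columns per tile)\n",
               gqa_ratio, ncols2, ncols1, ncols1*ncols2);
        return 0;
    }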
94
  #define FATTN_VEC_F16_CASE(D, type_K, type_V) \
 
242
  }
243
 
244
  void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
245
+ const ggml_tensor * KQV = dst;
246
+ const ggml_tensor * Q = dst->src[0];
247
+ const ggml_tensor * K = dst->src[1];
248
+ const ggml_tensor * V = dst->src[2];
249
+ const ggml_tensor * mask = dst->src[3];
250
 
251
  ggml_cuda_set_device(ctx.device);
252
  const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
 
288
  return;
289
  }
290
 
291
+ const int gqa_ratio = Q->ne[2] / K->ne[2];
292
+ const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
293
+ K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
294
+ if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
295
  if (prec == GGML_PREC_DEFAULT) {
296
  ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
297
  return;
ggml/src/ggml-cuda/mma.cuh CHANGED
@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
73
  return threadIdx.x / 4;
74
  } else if constexpr (I == 16 && J == 8) {
75
  return (l / 2) * 8 + threadIdx.x / 4;
 
 
76
  } else {
77
  static_assert(I == -1 && J == -1, "template specialization not implemented");
78
  }
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
85
  return 4 * l + threadIdx.x % 4;
86
  } else if constexpr (I == 16 && J == 8) {
87
  return 2 * (threadIdx.x % 4) + l % 2;
 
 
88
  } else {
89
  static_assert(I == -1 && J == -1, "template specialization not implemented");
90
  }
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
289
  #endif // NEW_MMA_AVAILABLE
290
  }
291
 
292
  static __device__ __forceinline__ void mma(
293
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
294
  #ifdef NEW_MMA_AVAILABLE
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
316
  #endif // NEW_MMA_AVAILABLE
317
  }
318
 
319
  }
 
73
  return threadIdx.x / 4;
74
  } else if constexpr (I == 16 && J == 8) {
75
  return (l / 2) * 8 + threadIdx.x / 4;
76
+ } else if constexpr (I == 16 && J == 16) {
77
+ return ((l / 2) % 2) * 8 + threadIdx.x / 4;
78
  } else {
79
  static_assert(I == -1 && J == -1, "template specialization not implemented");
80
  }
 
87
  return 4 * l + threadIdx.x % 4;
88
  } else if constexpr (I == 16 && J == 8) {
89
  return 2 * (threadIdx.x % 4) + l % 2;
90
+ } else if constexpr (I == 16 && J == 16) {
91
+ return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
92
  } else {
93
  static_assert(I == -1 && J == -1, "template specialization not implemented");
94
  }
 
293
  #endif // NEW_MMA_AVAILABLE
294
  }
295
 
296
+ static __device__ __forceinline__ void mma(
297
+ tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
298
+ #ifdef NEW_MMA_AVAILABLE
299
+ const int * Axi = (const int *) A.x;
300
+ const int * Bxi = (const int *) B.x;
301
+ int * Dxi = (int *) D.x;
302
+ #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
303
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
304
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
305
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
306
+ asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
307
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
308
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
309
+ #else
310
+ // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
311
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
312
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
313
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
314
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
315
+ : "+r"(Dxi[0]), "+r"(Dxi[1])
316
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
317
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
318
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
319
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
320
+ asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
321
+ : "+r"(Dxi[2]), "+r"(Dxi[3])
322
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
323
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
324
+ #else
325
+ GGML_UNUSED(D);
326
+ GGML_UNUSED(A);
327
+ GGML_UNUSED(B);
328
+ NO_DEVICE_CODE;
329
+ #endif // NEW_MMA_AVAILABLE
330
+ }
331
+
332
  static __device__ __forceinline__ void mma(
333
  tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
334
  #ifdef NEW_MMA_AVAILABLE
 
356
  #endif // NEW_MMA_AVAILABLE
357
  }
358
 
359
+ static __device__ __forceinline__ void mma(
360
+ tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
361
+ #ifdef NEW_MMA_AVAILABLE
362
+ const int * Axi = (const int *) A.x;
363
+ const int * Bxi = (const int *) B.x;
364
+ int * Dxi = (int *) D.x;
365
+ #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
366
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
367
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
368
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
369
+ asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
370
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
371
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
372
+ #else
373
+ // On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
374
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
375
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
376
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
377
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
378
+ : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
379
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
380
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
381
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
382
+ : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
383
+ asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
384
+ : "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
385
+ : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
386
+ #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
387
+ #else
388
+ GGML_UNUSED(D);
389
+ GGML_UNUSED(A);
390
+ GGML_UNUSED(B);
391
+ NO_DEVICE_CODE;
392
+ #endif // NEW_MMA_AVAILABLE
393
+ }
394
  }
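
The Turing fallback in the two new mma overloads is plain block-matrix arithmetic: an m16n8k16 product can be assembled from four m8n8k8 products by splitting A into 8x8 quadrants and B into two 8-row halves along k. A host-side C++ check of that identity on ordinary floats (the fp16 fragment layouts are not modeled here):

    #include <cmath>
    #include <cstdio>

    const int M = 16, N = 8, K = 16;
    float A[M][K], B[K][N], C_full[M][N] = {}, C_split[M][N] = {};

    // C[mi..mi+8)[0..8) += A[mi..mi+8)[ki..ki+8) * B[ki..ki+8)[0..8)
    void mma8x8x8(float C[M][N], int mi, int ki) {
        for (int m = 0; m < 8; ++m)
            for (int n = 0; n < 8; ++n)
                for (int k = 0; k < 8; ++k)
                    C[mi + m][n] += A[mi + m][ki + k]*B[ki + k][n];
    }

    int main() {
        for (int m = 0; m < M; ++m) for (int k = 0; k < K; ++k) A[m][k] = 0.01f*(m + 2*k);
        for (int k = 0; k < K; ++k) for (int n = 0; n < N; ++n) B[k][n] = 0.02f*(k - n);

        for (int m = 0; m < M; ++m) // reference m16n8k16 product
            for (int n = 0; n < N; ++n)
                for (int k = 0; k < K; ++k)
                    C_full[m][n] += A[m][k]*B[k][n];

        mma8x8x8(C_split, 0, 0); mma8x8x8(C_split, 0, 8); // upper 8 rows of C
        mma8x8x8(C_split, 8, 0); mma8x8x8(C_split, 8, 8); // lower 8 rows of C

        float max_err = 0.0f;
        for (int m = 0; m < M; ++m)
            for (int n = 0; n < N; ++n)
                max_err = std::fmax(max_err, std::fabs(C_full[m][n] - C_split[m][n]));
        printf("max difference between full and split product: %g\n", max_err);
        return 0;
    }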
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb16.cu → fattn-mma-f16-instance-ncols1_1-ncols2_8.cu} RENAMED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 16);
6
- DECL_FATTN_MMA_F16_CASE(80, 16);
7
- DECL_FATTN_MMA_F16_CASE(96, 16);
8
- DECL_FATTN_MMA_F16_CASE(112, 16);
9
- DECL_FATTN_MMA_F16_CASE(128, 16);
10
- DECL_FATTN_MMA_F16_CASE(256, 16);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 1, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 1, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 1, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 1, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 1, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 1, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 16, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 16, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 16, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 16, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 16, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 16, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 16, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 16, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 16, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 16, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 16, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 16, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 16, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 16, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 16, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 16, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 16, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 16, 4);
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb32.cu → fattn-mma-f16-instance-ncols1_2-ncols2_4.cu} RENAMED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 32);
6
- DECL_FATTN_MMA_F16_CASE(80, 32);
7
- DECL_FATTN_MMA_F16_CASE(96, 32);
8
- DECL_FATTN_MMA_F16_CASE(112, 32);
9
- DECL_FATTN_MMA_F16_CASE(128, 32);
10
- DECL_FATTN_MMA_F16_CASE(256, 32);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 2, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 2, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 2, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 2, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 2, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 2, 4);
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb64.cu → fattn-mma-f16-instance-ncols1_2-ncols2_8.cu} RENAMED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 64);
6
- DECL_FATTN_MMA_F16_CASE(80, 64);
7
- DECL_FATTN_MMA_F16_CASE(96, 64);
8
- DECL_FATTN_MMA_F16_CASE(112, 64);
9
- DECL_FATTN_MMA_F16_CASE(128, 64);
10
- DECL_FATTN_MMA_F16_CASE(256, 64);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 2, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 2, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 2, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 2, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 2, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 2, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 32, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 32, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 32, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 32, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 32, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 32, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 32, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 32, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 32, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 32, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 32, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 32, 2);
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb8.cu → fattn-mma-f16-instance-ncols1_4-ncols2_2.cu} RENAMED
@@ -2,9 +2,9 @@
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
- DECL_FATTN_MMA_F16_CASE(64, 8);
6
- DECL_FATTN_MMA_F16_CASE(80, 8);
7
- DECL_FATTN_MMA_F16_CASE(96, 8);
8
- DECL_FATTN_MMA_F16_CASE(112, 8);
9
- DECL_FATTN_MMA_F16_CASE(128, 8);
10
- DECL_FATTN_MMA_F16_CASE(256, 8);
 
2
 
3
  #include "../fattn-mma-f16.cuh"
4
 
5
+ DECL_FATTN_MMA_F16_CASE(64, 4, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 4, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 4, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 4, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 4, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 4, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 4, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 4, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 4, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 4, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 4, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 4, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 4, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 4, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 4, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 4, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 4, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 4, 8);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 64, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 64, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 64, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 64, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 64, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 64, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 8, 1);
6
+ DECL_FATTN_MMA_F16_CASE(80, 8, 1);
7
+ DECL_FATTN_MMA_F16_CASE(96, 8, 1);
8
+ DECL_FATTN_MMA_F16_CASE(112, 8, 1);
9
+ DECL_FATTN_MMA_F16_CASE(128, 8, 1);
10
+ DECL_FATTN_MMA_F16_CASE(256, 8, 1);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 8, 2);
6
+ DECL_FATTN_MMA_F16_CASE(80, 8, 2);
7
+ DECL_FATTN_MMA_F16_CASE(96, 8, 2);
8
+ DECL_FATTN_MMA_F16_CASE(112, 8, 2);
9
+ DECL_FATTN_MMA_F16_CASE(128, 8, 2);
10
+ DECL_FATTN_MMA_F16_CASE(256, 8, 2);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 8, 4);
6
+ DECL_FATTN_MMA_F16_CASE(80, 8, 4);
7
+ DECL_FATTN_MMA_F16_CASE(96, 8, 4);
8
+ DECL_FATTN_MMA_F16_CASE(112, 8, 4);
9
+ DECL_FATTN_MMA_F16_CASE(128, 8, 4);
10
+ DECL_FATTN_MMA_F16_CASE(256, 8, 4);
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu ADDED
@@ -0,0 +1,10 @@
1
+ // This file has been autogenerated by generate_cu_files.py, do not edit manually.
2
+
3
+ #include "../fattn-mma-f16.cuh"
4
+
5
+ DECL_FATTN_MMA_F16_CASE(64, 8, 8);
6
+ DECL_FATTN_MMA_F16_CASE(80, 8, 8);
7
+ DECL_FATTN_MMA_F16_CASE(96, 8, 8);
8
+ DECL_FATTN_MMA_F16_CASE(112, 8, 8);
9
+ DECL_FATTN_MMA_F16_CASE(128, 8, 8);
10
+ DECL_FATTN_MMA_F16_CASE(256, 8, 8);
ggml/src/ggml-cuda/template-instances/generate_cu_files.py CHANGED
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
18
 
19
  """
20
 
21
- SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
- for cols_per_block in [8, 16, 32, 64]:
61
- with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
62
- f.write(SOURCE_FATTN_MMA_START)
63
-
64
- for head_size in [64, 80, 96, 112, 128, 256]:
65
- f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
 
67
  for type in TYPES_MMQ:
68
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
 
18
 
19
  """
20
 
21
+ SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
22
 
23
  TYPES_MMQ = [
24
  "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
 
57
  with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
58
  f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
59
 
60
+ for ncols in [8, 16, 32, 64, 128]:
61
+ for ncols2 in [1, 2, 4, 8]:
62
+ ncols1 = ncols // ncols2
63
+ if ncols == 128:
64
+ continue # Too much register pressure.
65
+ with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
66
+ f.write(SOURCE_FATTN_MMA_START)
67
+
68
+ for head_size in [64, 80, 96, 112, 128, 256]:
69
+ if ncols == 128 and head_size == 256:
70
+ continue # Needs too much shared memory.
71
+ f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
72
 
73
  for type in TYPES_MMQ:
74
  with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: