Commit 6662d54
Parent(s): 6cb8158
CUDA: optimize FA for GQA + large batches (llama/12014)
Changed files:
- ggml/src/ggml-cuda/cp-async.cuh +1 -1
- ggml/src/ggml-cuda/fattn-common.cuh +53 -69
- ggml/src/ggml-cuda/fattn-mma-f16.cuh +552 -252
- ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
- ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +1 -1
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +1 -1
- ggml/src/ggml-cuda/fattn-wmma-f16.cu +3 -3
- ggml/src/ggml-cuda/fattn.cu +56 -17
- ggml/src/ggml-cuda/mma.cuh +75 -0
- ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb16.cu → fattn-mma-f16-instance-ncols1_1-ncols2_8.cu} +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +10 -0
- ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb32.cu → fattn-mma-f16-instance-ncols1_2-ncols2_4.cu} +6 -6
- ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb64.cu → fattn-mma-f16-instance-ncols1_2-ncols2_8.cu} +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu +10 -0
- ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb8.cu → fattn-mma-f16-instance-ncols1_4-ncols2_2.cu} +6 -6
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +10 -0
- ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +10 -0
- ggml/src/ggml-cuda/template-instances/generate_cu_files.py +13 -7
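
On the commit title: with grouped query attention (GQA) several Q heads share one K/V head (the kernels compute gqa_ratio = ne02 / ne12), and the rewritten MMA kernels below are templated on ncols1 (Q columns per tile) and ncols2 (Q heads processed jointly per K/V head), which is what the renamed template-instance files encode. A minimal sketch of how such a joint-head count could be chosen; the helper name and the selection policy are illustrative assumptions, not the actual logic in fattn.cu:

// Hypothetical helper: pick how many Q heads to process jointly per K/V head.
// Mirrors gqa_ratio = ne02 / ne12 from the diff; the instantiated kernels
// support ncols2 in {1, 2, 4, 8}.
static int pick_ncols2(const int ne02, const int ne12) {
    const int gqa_ratio = ne02 / ne12; // Q heads per K/V head.
    if (gqa_ratio % 8 == 0) { return 8; }
    if (gqa_ratio % 4 == 0) { return 4; }
    if (gqa_ratio % 2 == 0) { return 2; }
    return 1;
}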
ggml/src/ggml-cuda/cp-async.cuh
CHANGED

@@ -24,7 +24,7 @@ static __device__ __forceinline__ void cp_async_cg_16(const unsigned int dst, co
     } else
 #endif // CUDART_VERSION >= 11040
     {
-        asm volatile("cp.async.cg.shared.global
+        asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
             : : "r"(dst), "l"(src));
     }
 #else
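
For readers unfamiliar with the instruction patched above: cp.async.cg copies 16 bytes from global to shared memory asynchronously on Ampere and newer GPUs (sm_80+), and the copies must be waited on before the shared data is read. A minimal, self-contained sketch independent of the ggml helpers; the kernel and buffer names are made up for illustration:

#include <cuda_runtime.h>

// Launch with 32 threads per block; copies 32 float4 values via cp.async.
__global__ void async_copy_demo(const float4 * __restrict__ src, float4 * __restrict__ dst) {
    __shared__ float4 buf[32];

    // Convert the generic shared-memory pointer to a shared-state-space address.
    const unsigned int dst_shared = (unsigned int) __cvta_generic_to_shared(buf + threadIdx.x);

    // Issue an asynchronous 16-byte global -> shared copy (cache-global variant).
    asm volatile("cp.async.cg.shared.global [%0], [%1], 16;"
                 : : "r"(dst_shared), "l"(src + threadIdx.x));

    // Wait for all outstanding cp.async copies, then make them visible to the block.
    asm volatile("cp.async.wait_all;" ::: "memory");
    __syncthreads();

    dst[threadIdx.x] = buf[threadIdx.x];
}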
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED

@@ -516,27 +516,25 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic push
-#pragma clang diagnostic ignored "-Wpass-failed"
-#endif // __clang__
-
-template<int D, int ncols, int KQ_stride> // D == head size
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+template<int D, int ncols1, int ncols2, int KQ_stride> // D == head size
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_stream_k_fixup(
         float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
-    const int iter_k = ne11 / KQ_stride;
-    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+    constexpr int ncols = ncols1*ncols2;
 
     const int bidx0 = blockIdx.x;
+    const int j     = blockIdx.y;
+    const int c     = blockIdx.z;
+    const int jc    = j*ncols2 + c;
+    const int tid   = threadIdx.x;
+
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / FATTN_KQ_STRIDE;
+    const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
 
-    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
-    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
 
     const bool did_not_have_any_data   = kbc0 == kbc0_stop;
     const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;

@@ -548,22 +546,22 @@ static __global__ void flash_attn_stream_k_fixup(
     const int channel = kbc0 / (iter_k*iter_j);
     const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
 
-    float dst_val[ncols] = {0.0f};
-    float max_val[ncols] = {0.0f};
-    float rowsum[ncols]  = {0.0f};
-#pragma unroll
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            break;
-        }
-        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+    if (jt*ncols1 + j >= ne01) {
+        return;
+    }
+
+    dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
+
+    // Load the partial result that needs a fixup:
+    float dst_val = 0.0f;
+    float max_val = 0.0f;
+    float rowsum  = 0.0f;
+    {
+        dst_val = *dst;
+
+        const float2 tmp = dst_fixup[bidx0*ncols + jc];
+        max_val = tmp.x;
+        rowsum  = tmp.y;
     }
 
     // Iterate over previous blocks and compute the combined results.

@@ -571,36 +569,30 @@ static __global__ void flash_attn_stream_k_fixup(
     int bidx = bidx0 - 1;
     int kbc_stop = kbc0;
     while(true) {
-        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
         if (kbc == kbc_stop) { // Did not have any data.
            bidx--;
            kbc_stop = kbc;
            continue;
        }
 
-        for (int j = 0; j < ncols; ++j) {
-            if (jt*ncols + j >= ne01) {
-                break;
-            }
-            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
-        }
+        const float dst_add = dst_fixup_data[bidx*ncols*D + jc*D + tid];
+
+        const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + jc];
+
+        // Scale the current and new value accumulators depending on the max. values.
+        const float max_val_new = fmaxf(max_val, tmp.x);
+
+        const float diff_val = max_val - max_val_new;
+        const float diff_add = tmp.x  - max_val_new;
+
+        const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+        const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+        dst_val = scale_val*dst_val + scale_add*dst_add;
+        rowsum  = scale_val*rowsum  + scale_add*tmp.y;
+
+        max_val = max_val_new;
 
         // If this block started in a previous tile we are done and don't need to combine additional partial results.
         if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {

@@ -611,19 +603,9 @@ static __global__ void flash_attn_stream_k_fixup(
     }
 
     // Write back final result:
-    for (int j = 0; j < ncols; ++j) {
-        if (jt*ncols + j >= ne01) {
-            return;
-        }
-        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
-    }
+    *dst = dst_val / rowsum;
 }
 
-#ifdef __clang__
-#pragma clang diagnostic pop
-#endif // __clang__
-
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)

@@ -690,11 +672,13 @@ static void on_no_fattn_vec_case(const int D) {
 }
 
 // parallel_blocks == 0 is stream-k decomposition
-template <int D, int
+template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
         ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
         const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
+    constexpr int ncols = ncols1 * ncols2;
+
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
     const ggml_tensor * V = dst->src[2];

@@ -763,25 +747,26 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
-    const int ntiles_x = ((Q->ne[1] +
-    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+    const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
+    const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
     const dim3 block_dim(WARP_SIZE, nwarps, 1);
     dim3 blocks_num;
     if (parallel_blocks == 0) {
         // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
-        const int
-        const int
+        const int max_blocks = 2*nsm;
+        const int tiles_nwaves = (ntiles_total + max_blocks - 1) / max_blocks;
+        const int tiles_efficiency_percent = 100 * ntiles_total / (max_blocks*tiles_nwaves);
 
-        const int nblocks_stream_k =
+        const int nblocks_stream_k = max_blocks;
 
-        const bool use_stream_k =
+        const bool use_stream_k = cc >= GGML_CUDA_CC_ADA_LOVELACE || tiles_efficiency_percent < 75;
 
         blocks_num.x = use_stream_k ? nblocks_stream_k : ntiles_total;
         blocks_num.y = 1;
         blocks_num.z = 1;
 
-        dst_tmp_meta.alloc(blocks_num.x*
+        dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + D) * sizeof(float));
     } else {
         blocks_num.x = parallel_blocks*ntiles_x;
         blocks_num.y = Q->ne[2];

@@ -793,7 +778,6 @@ void launch_fattn(
         }
     }
 
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;

@@ -832,9 +816,9 @@ void launch_fattn(
     if constexpr (parallel_blocks == 0) {
         if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles.
             const dim3 block_dim_combine(D, 1, 1);
-            const dim3 blocks_num_combine = blocks_num;
+            const dim3 blocks_num_combine = {blocks_num.x, ncols1, ncols2};
 
-            flash_attn_stream_k_fixup<D,
+            flash_attn_stream_k_fixup<D, ncols1, ncols2, KQ_stride>
                 <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
                 ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
         }
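
The rewritten fixup kernel above merges partially reduced softmax results in a numerically stable way: every partial result carries its running maximum and row sum, and the side with the smaller maximum is rescaled before the numerators and denominators are added (the real kernel additionally flushes scales below SOFTMAX_FTZ_THRESHOLD to zero). A minimal standalone sketch of that combination step, with illustrative names rather than the ggml ones:

#include <cmath>

// One partially reduced softmax-attention result for a single output element.
struct Partial {
    float val;     // accumulated V-weighted numerator
    float max_val; // running maximum of the KQ logits seen so far
    float rowsum;  // running sum of exp(logit - max_val)
};

// Combine two partial results, as the stream-k fixup does per thread.
static Partial combine(const Partial a, const Partial b) {
    const float max_new = std::fmax(a.max_val, b.max_val);
    const float scale_a = std::exp(a.max_val - max_new); // rescale the side whose max was smaller
    const float scale_b = std::exp(b.max_val - max_new);

    Partial out;
    out.max_val = max_new;
    out.val     = scale_a*a.val    + scale_b*b.val;
    out.rowsum  = scale_a*a.rowsum + scale_b*b.rowsum;
    return out; // final output is out.val / out.rowsum once all partials are merged
}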
ggml/src/ggml-cuda/fattn-mma-f16.cuh
CHANGED
|
@@ -5,12 +5,15 @@
|
|
| 5 |
|
| 6 |
using namespace ggml_cuda_mma;
|
| 7 |
|
| 8 |
-
typedef tile<16,
|
| 9 |
-
typedef tile< 8,
|
| 10 |
-
typedef tile<16,
|
| 11 |
-
typedef tile<16,
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
| 15 |
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
|
| 16 |
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
|
@@ -27,7 +30,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
| 27 |
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
|
| 28 |
constexpr int stride_i = WARP_SIZE / chunks_per_row;
|
| 29 |
#pragma unroll
|
| 30 |
-
for (int i0 = 0; i0 <
|
| 31 |
const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
|
| 32 |
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
|
| 33 |
|
|
@@ -40,7 +43,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
| 40 |
|
| 41 |
// If D is not a power of 2, the rest is loaded synchronously.
|
| 42 |
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
| 43 |
-
static_assert(
|
| 44 |
#pragma unroll
|
| 45 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 46 |
const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
|
|
@@ -52,7 +55,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
| 52 |
}
|
| 53 |
|
| 54 |
#pragma unroll
|
| 55 |
-
for (int i0 = 0; i0 <
|
| 56 |
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 57 |
|
| 58 |
#pragma unroll
|
|
@@ -65,12 +68,54 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
|
| 65 |
}
|
| 66 |
}
|
| 67 |
|
| 68 |
-
template<int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 70 |
const float2 * const __restrict__ Q_f2,
|
| 71 |
const half2 * const __restrict__ K_h2,
|
| 72 |
const half2 * const __restrict__ V_h2,
|
| 73 |
-
const
|
| 74 |
float2 * const __restrict__ dstk,
|
| 75 |
float2 * const __restrict__ dstk_fixup,
|
| 76 |
const float scale,
|
|
@@ -78,42 +123,60 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 78 |
const float logit_softcap,
|
| 79 |
const int ne01,
|
| 80 |
const int ne02,
|
| 81 |
-
const int stride_Q,
|
| 82 |
const int stride_KV,
|
| 83 |
const int stride_mask,
|
| 84 |
const int jt,
|
| 85 |
half2 * const __restrict__ tile_K,
|
| 86 |
half2 * const __restrict__ tile_V,
|
|
|
|
| 87 |
const tile_B * const __restrict__ Q_B,
|
| 88 |
tile_C_VKQ * const __restrict__ VKQ_C,
|
| 89 |
-
|
| 90 |
-
|
| 91 |
const int kb0) {
|
| 92 |
#ifdef NEW_MMA_AVAILABLE
|
| 93 |
-
constexpr int
|
| 94 |
-
constexpr int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
-
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
|
| 99 |
#ifdef CP_ASYNC_AVAILABLE
|
| 100 |
cp_async_wait_all();
|
| 101 |
__syncthreads();
|
| 102 |
-
flash_attn_ext_f16_load_tile<D, nwarps,
|
| 103 |
#else
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
__syncthreads();
|
| 106 |
#endif // CP_ASYNC_AVAILABLE
|
| 107 |
|
| 108 |
// Calculate tile of KQ:
|
| 109 |
#pragma unroll
|
| 110 |
-
for (int i_KQ_00 = 0; i_KQ_00 <
|
| 111 |
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
|
| 112 |
#pragma unroll
|
| 113 |
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
|
| 114 |
tile_A K_A;
|
| 115 |
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
}
|
| 118 |
}
|
| 119 |
|
|
@@ -122,9 +185,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 122 |
#endif // CP_ASYNC_AVAILABLE
|
| 123 |
|
| 124 |
if (use_logit_softcap) {
|
| 125 |
-
static_assert(
|
| 126 |
#pragma unroll
|
| 127 |
-
for (int i = 0; i <
|
| 128 |
#pragma unroll
|
| 129 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 130 |
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
|
@@ -132,109 +195,209 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 132 |
}
|
| 133 |
}
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
#pragma unroll
|
| 139 |
-
for (int
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
#pragma unroll
|
| 142 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 143 |
-
|
| 144 |
-
const int j = (threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l);
|
| 145 |
|
| 146 |
-
KQ_C[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
}
|
| 148 |
}
|
| 149 |
-
}
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
| 155 |
#pragma unroll
|
| 156 |
-
|
| 157 |
#pragma unroll
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
| 161 |
}
|
| 162 |
-
}
|
| 163 |
|
| 164 |
-
|
| 165 |
#pragma unroll
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
|
|
|
| 173 |
#pragma unroll
|
| 174 |
-
|
| 175 |
#pragma unroll
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
const float diff = KQ_C[k].x[l] - KQ_max_l;
|
| 179 |
-
KQ_C[k].x[l] = expf(diff);
|
| 180 |
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
}
|
| 186 |
}
|
| 187 |
}
|
| 188 |
|
| 189 |
{
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
|
|
|
|
|
|
| 193 |
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
|
| 198 |
-
|
|
|
|
| 199 |
#pragma unroll
|
| 200 |
-
|
| 201 |
#pragma unroll
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
}
|
| 205 |
}
|
| 206 |
}
|
| 207 |
|
| 208 |
// Convert KQ C tiles into B tiles for VKQ calculation:
|
| 209 |
-
tile_B B[
|
| 210 |
-
|
|
|
|
|
|
|
| 211 |
#pragma unroll
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
}
|
| 215 |
|
| 216 |
#ifdef CP_ASYNC_AVAILABLE
|
|
|
|
| 217 |
cp_async_wait_all();
|
| 218 |
__syncthreads();
|
| 219 |
if (!last_iter) {
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
| 221 |
}
|
| 222 |
#else
|
| 223 |
-
flash_attn_ext_f16_load_tile<D, nwarps,
|
| 224 |
__syncthreads();
|
| 225 |
#endif // CP_ASYNC_AVAILABLE
|
| 226 |
|
| 227 |
// Calculate VKQ tile:
|
| 228 |
#pragma unroll
|
| 229 |
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
|
| 230 |
-
static_assert((
|
| 231 |
#pragma unroll
|
| 232 |
-
for (int k00 = 0; k00 <
|
| 233 |
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
| 234 |
|
| 235 |
tile_A A;
|
| 236 |
load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
|
| 237 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
}
|
| 239 |
}
|
| 240 |
|
|
@@ -247,12 +410,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
|
| 247 |
#endif // NEW_MMA_AVAILABLE
|
| 248 |
}
|
| 249 |
|
| 250 |
-
template<int D, int
|
| 251 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 252 |
const float2 * const __restrict__ Q_f2,
|
| 253 |
const half2 * const __restrict__ K_h2,
|
| 254 |
const half2 * const __restrict__ V_h2,
|
| 255 |
-
const
|
| 256 |
float2 * const __restrict__ dstk,
|
| 257 |
float2 * const __restrict__ dstk_fixup,
|
| 258 |
const float scale,
|
|
@@ -260,7 +423,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 260 |
const float logit_softcap,
|
| 261 |
const int ne01,
|
| 262 |
const int ne02,
|
| 263 |
-
const int
|
|
|
|
| 264 |
const int stride_KV,
|
| 265 |
const int stride_mask,
|
| 266 |
const int jt,
|
|
@@ -269,63 +433,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 269 |
#ifdef NEW_MMA_AVAILABLE
|
| 270 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 271 |
|
| 272 |
-
|
| 273 |
-
constexpr int
|
|
|
|
|
|
|
|
|
|
|
|
|
| 274 |
|
| 275 |
-
static_assert(D
|
| 276 |
-
static_assert(
|
| 277 |
|
| 278 |
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 279 |
|
| 280 |
-
// Temporary shared buffer for loading K/V data with
|
| 281 |
extern __shared__ half2 tile_K[];
|
| 282 |
#ifdef CP_ASYNC_AVAILABLE
|
| 283 |
-
half2 * tile_V
|
| 284 |
#else
|
| 285 |
-
half2 * tile_V
|
| 286 |
#endif // CP_ASYNC_AVAILABLE
|
|
|
|
| 287 |
|
| 288 |
-
tile_B
|
| 289 |
-
tile_C_VKQ VKQ_C[D/tile_C_VKQ::I];
|
| 290 |
|
| 291 |
-
|
| 292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
|
| 294 |
// Temporarily load Q data into tile_K, will be loaded into registers afterwards.
|
| 295 |
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
| 296 |
const half2 scale_h2 = make_half2(scale, scale);
|
| 297 |
#pragma unroll
|
| 298 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 299 |
-
const int k0_start
|
| 300 |
-
const int k0_stop
|
| 301 |
-
const int
|
| 302 |
|
| 303 |
if (k0_start == k0_stop) {
|
| 304 |
continue;
|
| 305 |
}
|
| 306 |
|
| 307 |
-
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
| 308 |
-
break;
|
| 309 |
-
}
|
| 310 |
-
|
| 311 |
#pragma unroll
|
| 312 |
-
for (int
|
| 313 |
-
const int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 314 |
|
| 315 |
-
if (jt*
|
| 316 |
#pragma unroll
|
| 317 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 318 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 319 |
|
| 320 |
-
const float2 tmp = Q_f2[(jt*
|
| 321 |
-
tile_K[
|
| 322 |
}
|
| 323 |
} else {
|
| 324 |
#pragma unroll
|
| 325 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 326 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 327 |
|
| 328 |
-
tile_K[
|
| 329 |
}
|
| 330 |
}
|
| 331 |
}
|
|
@@ -334,128 +513,217 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 334 |
__syncthreads();
|
| 335 |
|
| 336 |
{
|
| 337 |
-
const int j0 = (threadIdx.y / np) *
|
| 338 |
|
| 339 |
#pragma unroll
|
| 340 |
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
|
| 341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 342 |
}
|
| 343 |
}
|
| 344 |
|
| 345 |
__syncthreads();
|
| 346 |
|
| 347 |
-
// Preload K data for first iteration when using cp_async:
|
| 348 |
#ifdef CP_ASYNC_AVAILABLE
|
| 349 |
-
|
|
|
|
|
|
|
|
|
|
| 350 |
#endif // CP_ASYNC_AVAILABLE
|
| 351 |
|
| 352 |
// Iterate over ne11 == previous tokens:
|
| 353 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 354 |
constexpr bool last_iter = false;
|
| 355 |
-
flash_attn_ext_f16_iter<D,
|
| 356 |
-
(Q_f2, K_h2, V_h2,
|
| 357 |
-
ne01, ne02,
|
| 358 |
}
|
| 359 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 360 |
constexpr bool last_iter = true;
|
| 361 |
-
flash_attn_ext_f16_iter<D,
|
| 362 |
-
(Q_f2, K_h2, V_h2,
|
| 363 |
-
ne01, ne02,
|
| 364 |
}
|
| 365 |
|
| 366 |
// With cp_async there is no __syncthreads at the end of the iter,
|
| 367 |
// there can be a race condition on shared memory access for combining/writing back results.
|
| 368 |
#ifdef CP_ASYNC_AVAILABLE
|
| 369 |
-
if (nwarps*
|
| 370 |
__syncthreads();
|
| 371 |
}
|
| 372 |
#endif // CP_ASYNC_AVAILABLE
|
| 373 |
|
| 374 |
// Finally, sum up partial KQ rowsums.
|
| 375 |
-
// The partial sums are spread across 8 threads each, does not need full reduce.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 376 |
#pragma unroll
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
|
|
|
| 380 |
}
|
| 381 |
|
| 382 |
// Write VKQ accumulators to shared memory in column-major format.
|
| 383 |
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 384 |
// Also for np > 1 the combination is done via these values in shared memory.
|
| 385 |
-
|
|
|
|
| 386 |
#pragma unroll
|
| 387 |
-
|
| 388 |
-
|
| 389 |
|
| 390 |
#pragma unroll
|
| 391 |
-
|
| 392 |
-
|
| 393 |
|
| 394 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
}
|
| 396 |
}
|
| 397 |
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
|
|
|
| 401 |
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
|
| 407 |
-
|
| 408 |
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
}
|
| 416 |
-
|
| 417 |
-
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
}
|
| 420 |
-
}
|
|
|
|
|
|
|
|
|
|
| 421 |
// Combine the meta data for parallel warps via shared memory.
|
| 422 |
// Warps with threadIdx.y % np != 0 must NOT return early.
|
| 423 |
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
| 424 |
|
| 425 |
-
|
| 426 |
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
|
|
|
|
|
|
|
|
|
| 430 |
}
|
| 431 |
|
| 432 |
-
float KQ_cmn =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 433 |
#pragma unroll
|
| 434 |
-
for (int offset = np*
|
|
|
|
|
|
|
|
|
|
| 435 |
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
| 436 |
}
|
| 437 |
|
| 438 |
-
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
}
|
|
|
|
|
|
|
| 443 |
#pragma unroll
|
| 444 |
-
for (int
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 445 |
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
| 446 |
}
|
| 447 |
|
| 448 |
// Write back combined meta data:
|
| 449 |
-
|
| 450 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
}
|
| 452 |
-
|
|
|
|
|
|
|
|
|
|
| 453 |
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 454 |
-
dstk_fixup_meta[(threadIdx.y/np)*
|
| 455 |
}
|
| 456 |
-
if (is_fixup && threadIdx.x <
|
| 457 |
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 458 |
-
dstk_fixup_meta[(threadIdx.y/np)*
|
| 459 |
}
|
| 460 |
}
|
| 461 |
|
|
@@ -470,27 +738,32 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 470 |
|
| 471 |
#pragma unroll
|
| 472 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 473 |
-
const int k0_start
|
| 474 |
-
const int k0_stop
|
| 475 |
-
const int
|
| 476 |
|
| 477 |
if (k0_start == k0_stop) {
|
| 478 |
continue;
|
| 479 |
}
|
| 480 |
|
| 481 |
-
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
|
| 482 |
-
break;
|
| 483 |
-
}
|
| 484 |
-
|
| 485 |
#pragma unroll
|
| 486 |
-
for (int
|
| 487 |
-
const int
|
| 488 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
-
if (!is_fixup && jt*
|
| 491 |
continue;
|
| 492 |
}
|
| 493 |
-
|
|
|
|
| 494 |
#pragma unroll
|
| 495 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 496 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
|
@@ -498,8 +771,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 498 |
float2 dstk_val = make_float2(0.0f, 0.0f);
|
| 499 |
#pragma unroll
|
| 500 |
for (int ip = 0; ip < np; ++ip) {
|
| 501 |
-
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*
|
| 502 |
-
const float2 dstk_val_add = __half22float2(tile_K[(
|
| 503 |
dstk_val.x += dstk_val_add.x*KQ_crs;
|
| 504 |
dstk_val.y += dstk_val_add.y*KQ_crs;
|
| 505 |
}
|
|
@@ -511,9 +784,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 511 |
}
|
| 512 |
|
| 513 |
if (is_fixup) {
|
| 514 |
-
dstk_fixup_data[
|
| 515 |
} else {
|
| 516 |
-
dstk[(jt*
|
| 517 |
}
|
| 518 |
}
|
| 519 |
}
|
|
@@ -528,10 +801,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
|
| 528 |
#endif // NEW_MMA_AVAILABLE
|
| 529 |
}
|
| 530 |
|
| 531 |
-
template<int D, int
|
| 532 |
-
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 533 |
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
| 534 |
-
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
|
| 535 |
static __global__ void flash_attn_ext_f16(
|
| 536 |
const char * __restrict__ Q,
|
| 537 |
const char * __restrict__ K,
|
|
@@ -579,20 +850,23 @@ static __global__ void flash_attn_ext_f16(
|
|
| 579 |
return;
|
| 580 |
}
|
| 581 |
|
| 582 |
-
static_assert(FATTN_KQ_STRIDE %
|
| 583 |
|
| 584 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 585 |
|
| 586 |
-
const int
|
|
|
|
| 587 |
const int stride_KV = nb11 / sizeof(half2);
|
| 588 |
-
const int stride_mask = nb31 / sizeof(
|
|
|
|
|
|
|
|
|
|
| 589 |
|
| 590 |
-
|
| 591 |
-
const int iter_j = (ne01 + (ncols - 1)) / ncols;
|
| 592 |
|
| 593 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 594 |
-
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
|
| 595 |
-
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
|
| 596 |
|
| 597 |
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
| 598 |
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
@@ -605,25 +879,28 @@ static __global__ void flash_attn_ext_f16(
|
|
| 605 |
const int channel = kbc / (iter_k*iter_j);
|
| 606 |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 607 |
|
| 608 |
-
const float2 * Q_f2
|
| 609 |
-
const half2 * K_h2
|
| 610 |
-
const half2 * V_h2
|
| 611 |
-
const
|
| 612 |
-
float2 * dstk
|
| 613 |
|
| 614 |
-
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
|
|
|
|
|
|
|
|
|
|
| 615 |
|
| 616 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 617 |
if (kb0_start == 0) {
|
| 618 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 619 |
-
flash_attn_ext_f16_process_tile<D,
|
| 620 |
-
(Q_f2, K_h2, V_h2,
|
| 621 |
-
ne01, ne02,
|
| 622 |
} else {
|
| 623 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 624 |
-
flash_attn_ext_f16_process_tile<D,
|
| 625 |
-
(Q_f2, K_h2, V_h2,
|
| 626 |
-
ne01, ne02,
|
| 627 |
}
|
| 628 |
|
| 629 |
kbc += iter_k;
|
|
@@ -640,39 +917,46 @@ static __global__ void flash_attn_ext_f16(
|
|
| 640 |
const int channel = kbc / (iter_k*iter_j);
|
| 641 |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 642 |
|
| 643 |
-
const float2 * Q_f2
|
| 644 |
-
const half2 * K_h2
|
| 645 |
-
const half2 * V_h2
|
| 646 |
-
const
|
| 647 |
-
float2 * dstk
|
|
|
|
|
|
|
| 648 |
|
| 649 |
-
const
|
|
|
|
| 650 |
|
| 651 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 652 |
constexpr bool needs_fixup = false;
|
| 653 |
-
flash_attn_ext_f16_process_tile<D,
|
| 654 |
-
(Q_f2, K_h2, V_h2,
|
| 655 |
-
ne01, ne02,
|
| 656 |
}
|
| 657 |
|
| 658 |
-
template <int D, int
|
| 659 |
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 660 |
-
|
| 661 |
-
|
|
|
|
|
|
|
|
|
|
| 662 |
|
| 663 |
-
static_assert(D
|
| 664 |
-
static_assert(
|
| 665 |
|
| 666 |
const ggml_tensor * KQV = dst;
|
| 667 |
-
const int
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
|
| 672 |
|
| 673 |
-
const
|
| 674 |
-
const int nrows_combine = nwarps*tile_B::J;
|
| 675 |
-
const size_t nbytes_shared = std::max(nrows_KQ, nrows_combine) * (D + 8) * sizeof(half);
|
| 676 |
|
| 677 |
float logit_softcap;
|
| 678 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
@@ -680,42 +964,58 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
|
|
| 680 |
fattn_kernel_t fattn_kernel;
|
| 681 |
if (logit_softcap == 0.0f) {
|
| 682 |
constexpr bool use_logit_softcap = false;
|
| 683 |
-
fattn_kernel = flash_attn_ext_f16<D,
|
| 684 |
} else {
|
| 685 |
constexpr bool use_logit_softcap = true;
|
| 686 |
-
fattn_kernel = flash_attn_ext_f16<D,
|
| 687 |
}
|
| 688 |
-
|
|
|
|
| 689 |
}
|
| 690 |
|
| 691 |
-
|
|
|
|
| 692 |
template void ggml_cuda_flash_attn_ext_mma_f16_case \
|
| 693 |
-
<D,
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 697 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 698 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 699 |
-
extern DECL_FATTN_MMA_F16_CASE(
|
| 700 |
-
|
| 701 |
-
|
| 702 |
-
|
| 703 |
-
|
| 704 |
-
|
| 705 |
-
|
| 706 |
-
|
| 707 |
-
|
| 708 |
-
|
| 709 |
-
|
| 710 |
-
|
| 711 |
-
|
| 712 |
-
|
| 713 |
-
|
| 714 |
-
|
| 715 |
-
|
| 716 |
-
|
| 717 |
-
|
| 718 |
-
|
| 719 |
-
|
| 720 |
-
|
| 721 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
using namespace ggml_cuda_mma;
|
| 7 |
|
| 8 |
+
typedef tile<16, 8, half2> tile_A;
|
| 9 |
+
typedef tile< 8, 8, half2> tile_B;
|
| 10 |
+
typedef tile<16, 8, half2> tile_B_16;
|
| 11 |
+
typedef tile<16, 8, float> tile_C_KQ;
|
| 12 |
+
typedef tile<16, 16, float> tile_C_KQ_16;
|
| 13 |
+
typedef tile<16, 4, half2> tile_C_VKQ;
|
| 14 |
+
typedef tile<16, 8, half2> tile_C_VKQ_16;
|
| 15 |
+
|
| 16 |
+
template<int D, int nwarps, int KQ_per_iter>
|
| 17 |
static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
|
| 18 |
const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV) {
|
| 19 |
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
|
|
|
| 30 |
constexpr int chunks_per_row = k0_sync_start / h2_per_chunk;
|
| 31 |
constexpr int stride_i = WARP_SIZE / chunks_per_row;
|
| 32 |
#pragma unroll
|
| 33 |
+
for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
|
| 34 |
const int i = i0 + threadIdx.y*stride_i + (chunks_per_row == WARP_SIZE ? 0 : threadIdx.x / chunks_per_row);
|
| 35 |
const int k = (chunks_per_row == WARP_SIZE ? threadIdx.x : threadIdx.x % chunks_per_row)*h2_per_chunk;
|
| 36 |
|
|
|
|
| 43 |
|
| 44 |
// If D is not a power of 2, the rest is loaded synchronously.
|
| 45 |
// K/V data is loaded with decreasing granularity for D for better memory bandwidth.
|
| 46 |
+
static_assert(KQ_per_iter % (4*nwarps) == 0, "out of bounds");
|
| 47 |
#pragma unroll
|
| 48 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 49 |
const int k0_start = stride_k == WARP_SIZE ? k0_sync_start : D/2 - (D/2) % (2*stride_k);
|
|
|
|
| 55 |
}
|
| 56 |
|
| 57 |
#pragma unroll
|
| 58 |
+
for (int i0 = 0; i0 < KQ_per_iter; i0 += nwarps*stride_i) {
|
| 59 |
const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 60 |
|
| 61 |
#pragma unroll
|
|
|
|
| 68 |
}
|
| 69 |
}
|
| 70 |
|
| 71 |
+
template<int ncols1, int nwarps, int KQ_per_iter>
|
| 72 |
+
static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
|
| 73 |
+
const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
|
| 74 |
+
static_assert(KQ_per_iter == 2*WARP_SIZE || KQ_per_iter == WARP_SIZE, "bad KQ_per_iter");
|
| 75 |
+
#ifdef CP_ASYNC_AVAILABLE
|
| 76 |
+
constexpr int preload = KQ_per_iter * sizeof(half);
|
| 77 |
+
constexpr int cols_per_warp = 8*WARP_SIZE/KQ_per_iter;
|
| 78 |
+
constexpr int stride_j = nwarps * cols_per_warp;
|
| 79 |
+
|
| 80 |
+
const unsigned int tile_mask_32 = __cvta_generic_to_shared(tile_mask);
|
| 81 |
+
|
| 82 |
+
#pragma unroll
|
| 83 |
+
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
|
| 84 |
+
const int j = j0 + threadIdx.y*cols_per_warp +
|
| 85 |
+
(KQ_per_iter == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/8));
|
| 86 |
+
|
| 87 |
+
if (j0 + stride_j > ncols1 && j >= ncols1) {
|
| 88 |
+
break;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
const int i = 4 * (KQ_per_iter == 2*WARP_SIZE ? threadIdx.x % (WARP_SIZE/4) : threadIdx.x % (WARP_SIZE/8));
|
| 92 |
+
|
| 93 |
+
cp_async_cg_16<preload>(tile_mask_32 + j*(KQ_per_iter*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
|
| 94 |
+
}
|
| 95 |
+
#else
|
| 96 |
+
constexpr int cols_per_warp = 2*WARP_SIZE/KQ_per_iter;
|
| 97 |
+
constexpr int stride_j = nwarps * cols_per_warp;
|
| 98 |
+
#pragma unroll
|
| 99 |
+
for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
|
| 100 |
+
const int j = j0 + threadIdx.y*cols_per_warp + (KQ_per_iter == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/2));
|
| 101 |
+
|
| 102 |
+
if (j0 + stride_j > ncols1 && j >= ncols1) {
|
| 103 |
+
break;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
const int i = KQ_per_iter == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/2);
|
| 107 |
+
|
| 108 |
+
tile_mask[j*(KQ_per_iter/2 + 4) + i] = mask_h2[j*stride_mask + i];
|
| 109 |
+
}
|
| 110 |
+
#endif // CP_ASYNC_AVAILABLE
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup, bool last_iter>
|
| 114 |
static __device__ __forceinline__ void flash_attn_ext_f16_iter(
|
| 115 |
const float2 * const __restrict__ Q_f2,
|
| 116 |
const half2 * const __restrict__ K_h2,
|
| 117 |
const half2 * const __restrict__ V_h2,
|
| 118 |
+
const half2 * const __restrict__ mask_h2,
|
| 119 |
float2 * const __restrict__ dstk,
|
| 120 |
float2 * const __restrict__ dstk_fixup,
|
| 121 |
const float scale,
|
|
|
|
| 123 |
const float logit_softcap,
|
| 124 |
const int ne01,
|
| 125 |
const int ne02,
|
|
|
|
| 126 |
const int stride_KV,
|
| 127 |
const int stride_mask,
|
| 128 |
const int jt,
|
| 129 |
half2 * const __restrict__ tile_K,
|
| 130 |
half2 * const __restrict__ tile_V,
|
| 131 |
+
half2 * const __restrict__ tile_mask,
|
| 132 |
const tile_B * const __restrict__ Q_B,
|
| 133 |
tile_C_VKQ * const __restrict__ VKQ_C,
|
| 134 |
+
float * const __restrict__ KQ_max,
|
| 135 |
+
float * const __restrict__ KQ_rowsum,
|
| 136 |
const int kb0) {
|
| 137 |
#ifdef NEW_MMA_AVAILABLE
|
| 138 |
+
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 139 |
+
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 140 |
+
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
| 141 |
+
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 142 |
+
|
| 143 |
+
const int k_VKQ_0 = kb0 * KQ_per_iter;
|
| 144 |
+
tile_C_KQ KQ_C[KQ_per_iter/(np*tile_C_KQ::I) * ntiles];
|
| 145 |
|
| 146 |
+
// Use wide variants of tiles if ntiles >= 2.
|
| 147 |
+
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 148 |
+
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
| 149 |
+
tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
|
| 150 |
|
| 151 |
#ifdef CP_ASYNC_AVAILABLE
|
| 152 |
cp_async_wait_all();
|
| 153 |
__syncthreads();
|
| 154 |
+
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
|
| 155 |
#else
|
| 156 |
+
if (ncols2 > 1 || mask_h2) {
|
| 157 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
|
| 158 |
+
}
|
| 159 |
+
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + k_VKQ_0*stride_KV, tile_K, stride_KV);
|
| 160 |
__syncthreads();
|
| 161 |
#endif // CP_ASYNC_AVAILABLE
|
| 162 |
|
| 163 |
// Calculate tile of KQ:
|
| 164 |
#pragma unroll
|
| 165 |
+
for (int i_KQ_00 = 0; i_KQ_00 < KQ_per_iter; i_KQ_00 += np*tile_A::I) {
|
| 166 |
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
|
| 167 |
#pragma unroll
|
| 168 |
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += tile_A::J) {
|
| 169 |
tile_A K_A;
|
| 170 |
load_ldmatrix(K_A, tile_K + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
|
| 171 |
+
if (ntiles == 1) {
|
| 172 |
+
mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
|
| 173 |
+
} else {
|
| 174 |
+
#pragma unroll
|
| 175 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 176 |
+
// Wide version of KQ_C is column-major => swap A and B.
|
| 177 |
+
mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
|
| 178 |
+
}
|
| 179 |
+
}
|
| 180 |
}
|
| 181 |
}
|
| 182 |
|
|
|
|
| 185 |
#endif // CP_ASYNC_AVAILABLE
|
| 186 |
|
| 187 |
if (use_logit_softcap) {
|
| 188 |
+
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 189 |
#pragma unroll
|
| 190 |
+
for (int i = 0; i < KQ_per_iter/(np*tile_C_KQ::I) * ntiles; ++i) {
|
| 191 |
#pragma unroll
|
| 192 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 193 |
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
|
|
|
|
| 195 |
}
|
| 196 |
}
|
| 197 |
|
| 198 |
+
float KQ_max_new[cols_per_thread];
|
| 199 |
+
#pragma unroll
|
| 200 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 201 |
+
KQ_max_new[col] = KQ_max[col];
|
| 202 |
+
}
|
| 203 |
+
float KQ_rowsum_add[cols_per_thread] = {0.0f};
|
| 204 |
+
|
| 205 |
+
if (ntiles == 1) {
|
| 206 |
+
if (ncols2 > 1 || mask_h2) {
|
| 207 |
+
#pragma unroll
|
| 208 |
+
for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ::I) {
|
| 209 |
+
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
|
| 210 |
+
#pragma unroll
|
| 211 |
+
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 212 |
+
const int i = i0 + tile_C_KQ::get_i(l);
|
| 213 |
+
const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
|
| 214 |
+
|
| 215 |
+
KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
|
| 216 |
+
__half2float(((const half *) tile_mask)[j*(KQ_per_iter + 8) + i]);
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
// Calculate softmax for each KQ column using the current max. value.
|
| 222 |
+
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 223 |
+
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 224 |
+
#pragma unroll
|
| 225 |
+
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
|
| 226 |
+
#pragma unroll
|
| 227 |
+
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 228 |
+
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
|
| 233 |
#pragma unroll
|
| 234 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 235 |
+
#pragma unroll
|
| 236 |
+
for (int offset = 16; offset >= 4; offset >>= 1) {
|
| 237 |
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 242 |
+
|
| 243 |
+
#pragma unroll
|
| 244 |
+
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ::I); ++k) {
|
| 245 |
#pragma unroll
|
| 246 |
for (int l = 0; l < tile_C_KQ::ne; ++l) {
|
| 247 |
+
KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
|
|
|
|
| 248 |
|
| 249 |
+
KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
} else { // ntiles > 1
|
| 253 |
+
if (ncols2 > 1 || mask_h2) {
|
| 254 |
+
#pragma unroll
|
| 255 |
+
for (int i00 = 0; i00 < KQ_per_iter; i00 += np*tile_C_KQ_16::J) {
|
| 256 |
+
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
|
| 257 |
+
#pragma unroll
|
| 258 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 259 |
+
#pragma unroll
|
| 260 |
+
for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
|
| 261 |
+
const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
|
| 262 |
+
const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
|
| 263 |
+
|
| 264 |
+
const float2 tmp = __half22float2(tile_mask[j*(KQ_per_iter/2 + 4) + i]);
|
| 265 |
+
const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
|
| 266 |
+
KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
|
| 267 |
+
KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
|
| 268 |
+
}
|
| 269 |
+
}
|
| 270 |
}
|
| 271 |
}
|
|
|
|
| 272 |
|
| 273 |
+
// Calculate softmax for each KQ column using the current max. value.
|
| 274 |
+
// The divisor is stored in KQ_rowsum and will be applied at the end.
|
| 275 |
+
static_assert(KQ_per_iter % (np*tile_C_KQ::I) == 0, "bad loop size");
|
| 276 |
+
#pragma unroll
|
| 277 |
+
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
|
| 278 |
#pragma unroll
|
| 279 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 280 |
#pragma unroll
|
| 281 |
+
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
| 282 |
+
const int KQ_index = 2*t + (l/2) % 2;
|
| 283 |
+
KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
}
|
|
|
|
| 287 |
|
| 288 |
+
// Values per KQ column are spread across 4 threads, does not need full warp reduce:
|
| 289 |
#pragma unroll
|
| 290 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 291 |
+
#pragma unroll
|
| 292 |
+
for (int offset = 2; offset >= 1; offset >>= 1) {
|
| 293 |
+
KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
|
| 294 |
+
}
|
| 295 |
+
}
|
| 296 |
|
| 297 |
+
static_assert(KQ_per_iter % (np*tile_C_KQ_16::J) == 0, "bad loop size");
|
| 298 |
+
#pragma unroll
|
| 299 |
+
for (int k = 0; k < KQ_per_iter/(np*tile_C_KQ_16::J); ++k) {
|
| 300 |
#pragma unroll
|
| 301 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 302 |
#pragma unroll
|
| 303 |
+
for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
|
| 304 |
+
const int KQ_index = 2*t + (l/2) % 2;
|
|
|
|
|
|
|
| 305 |
|
| 306 |
+
KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
|
| 307 |
+
|
| 308 |
+
KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
|
| 309 |
+
}
|
| 310 |
}
|
| 311 |
}
|
| 312 |
}
|
| 313 |
|
| 314 |
{
|
| 315 |
+
float KQ_max_scale[cols_per_thread];
|
| 316 |
+
#pragma unroll
|
| 317 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 318 |
+
KQ_max_scale[col] = expf(KQ_max[col] - KQ_max_new[col]);
|
| 319 |
+
KQ_max[col] = KQ_max_new[col];
|
| 320 |
|
| 321 |
+
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
|
| 322 |
+
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
|
| 323 |
+
}
|
| 324 |
|
| 325 |
+
if (ntiles == 1) {
|
| 326 |
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
| 327 |
#pragma unroll
|
| 328 |
+
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
|
| 329 |
#pragma unroll
|
| 330 |
+
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
| 331 |
+
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
} else {
|
| 335 |
+
#pragma unroll
|
| 336 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 337 |
+
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
| 338 |
+
#pragma unroll
|
| 339 |
+
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
|
| 340 |
+
#pragma unroll
|
| 341 |
+
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
| 342 |
+
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
}
|
| 346 |
}
|
| 347 |
}
|
| 348 |
|
| 349 |
// Convert KQ C tiles into B tiles for VKQ calculation:
|
| 350 |
+
tile_B B[KQ_per_iter/(np*2*tile_B::J) * ntiles];
|
| 351 |
+
tile_B_16 * B_16 = (tile_B_16 *) B;
|
| 352 |
+
static_assert(KQ_per_iter % (np*2*tile_B::J) == 0, "bad loop size");
|
| 353 |
+
if (ntiles == 1) {
|
| 354 |
#pragma unroll
|
| 355 |
+
for (int k = 0; k < KQ_per_iter/(np*2*tile_B::J); ++k) {
|
| 356 |
+
B[k] = get_transposed(get_half2(KQ_C[k]));
|
| 357 |
+
}
|
| 358 |
+
} else {
|
| 359 |
+
for (int k = 0; k < KQ_per_iter/(np*2*tile_B_16::J); ++k) {
|
| 360 |
+
#pragma unroll
|
| 361 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 362 |
+
B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
|
| 363 |
+
}
|
| 364 |
+
}
|
| 365 |
}
|
| 366 |
|
| 367 |
#ifdef CP_ASYNC_AVAILABLE
|
| 368 |
+
// Preload K tile for next iteration:
|
| 369 |
cp_async_wait_all();
|
| 370 |
__syncthreads();
|
| 371 |
if (!last_iter) {
|
| 372 |
+
if (ncols2 > 1 || mask_h2) {
|
| 373 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + (k_VKQ_0 + KQ_per_iter)/2, tile_mask, stride_mask);
|
| 374 |
+
}
|
| 375 |
+
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + (k_VKQ_0 + KQ_per_iter)*stride_KV, tile_K, stride_KV);
|
| 376 |
}
|
| 377 |
#else
|
| 378 |
+
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(V_h2 + k_VKQ_0*stride_KV, tile_V, stride_KV);
|
| 379 |
__syncthreads();
|
| 380 |
#endif // CP_ASYNC_AVAILABLE
|
| 381 |
|
| 382 |
// Calculate VKQ tile:
|
| 383 |
#pragma unroll
|
| 384 |
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += tile_C_VKQ::I) {
|
| 385 |
+
static_assert((KQ_per_iter/2) % (np*tile_A::J) == 0, "bad loop size");
|
| 386 |
#pragma unroll
|
| 387 |
+
for (int k00 = 0; k00 < KQ_per_iter/2; k00 += np*tile_A::J) {
|
| 388 |
const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
|
| 389 |
|
| 390 |
tile_A A;
|
| 391 |
load_ldmatrix_trans(A, tile_V + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
|
| 392 |
+
if (ntiles == 1) {
|
| 393 |
+
mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
|
| 394 |
+
} else {
|
| 395 |
+
#pragma unroll
|
| 396 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 397 |
+
// Wide version of VKQ_C is column-major => swap A and B.
|
| 398 |
+
mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
}
|
| 402 |
}
|
| 403 |
|
|
|
|
| 410 |
#endif // NEW_MMA_AVAILABLE
|
| 411 |
}
|
| 412 |
|
| 413 |
+
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
|
| 414 |
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
| 415 |
const float2 * const __restrict__ Q_f2,
|
| 416 |
const half2 * const __restrict__ K_h2,
|
| 417 |
const half2 * const __restrict__ V_h2,
|
| 418 |
+
const half2 * const __restrict__ mask_h2,
|
| 419 |
float2 * const __restrict__ dstk,
|
| 420 |
float2 * const __restrict__ dstk_fixup,
|
| 421 |
const float scale,
|
|
|
|
| 423 |
const float logit_softcap,
|
| 424 |
const int ne01,
|
| 425 |
const int ne02,
|
| 426 |
+
const int stride_Q1,
|
| 427 |
+
const int stride_Q2,
|
| 428 |
const int stride_KV,
|
| 429 |
const int stride_mask,
|
| 430 |
const int jt,
|
|
|
|
| 433 |
#ifdef NEW_MMA_AVAILABLE
|
| 434 |
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
| 435 |
|
| 436 |
+
constexpr int ncols = ncols1 * ncols2;
|
| 437 |
+
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 438 |
+
constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
|
| 439 |
+
constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
|
| 440 |
+
|
| 441 |
+
static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
|
| 442 |
|
| 443 |
+
static_assert(D % nwarps == 0, "bad D");
|
| 444 |
+
static_assert(KQ_per_iter % nwarps == 0, "bad KQ_per_iter");
|
| 445 |
|
| 446 |
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
|
| 447 |
|
| 448 |
+
// Temporary shared buffer for loading K/V data with KQ_per_iter*D logical elements:
|
| 449 |
extern __shared__ half2 tile_K[];
|
| 450 |
#ifdef CP_ASYNC_AVAILABLE
|
| 451 |
+
half2 * tile_V = tile_K + KQ_per_iter*D2_padded;
|
| 452 |
#else
|
| 453 |
+
half2 * tile_V = tile_K;
|
| 454 |
#endif // CP_ASYNC_AVAILABLE
|
| 455 |
+
half2 * tile_mask = tile_V + KQ_per_iter*D2_padded;
|
| 456 |
|
| 457 |
+
tile_B Q_B[D/(2*tile_B::J) * ntiles];
|
| 458 |
+
tile_C_VKQ VKQ_C[D/tile_C_VKQ::I * ntiles];
|
| 459 |
|
| 460 |
+
tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
|
| 461 |
+
tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
|
| 462 |
+
|
| 463 |
+
float KQ_rowsum[cols_per_thread] = {0.0f};
|
| 464 |
+
float KQ_max[cols_per_thread];
|
| 465 |
+
#pragma unroll
|
| 466 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 467 |
+
KQ_max[col] = -FLT_MAX/2.0f;
|
| 468 |
+
}
|
| 469 |
|
| 470 |
// Temporarily load Q data into tile_K, will be loaded into registers afterwards.
|
| 471 |
// The loading is done with decreasing granularity for D for better memory bandwidth.
|
| 472 |
const half2 scale_h2 = make_half2(scale, scale);
|
| 473 |
#pragma unroll
|
| 474 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 475 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
| 476 |
+
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 477 |
+
const int stride_jc = WARP_SIZE / stride_k;
|
| 478 |
|
| 479 |
if (k0_start == k0_stop) {
|
| 480 |
continue;
|
| 481 |
}
|
| 482 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
#pragma unroll
|
| 484 |
+
for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
|
| 485 |
+
const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 486 |
+
|
| 487 |
+
if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
|
| 488 |
+
break;
|
| 489 |
+
}
|
| 490 |
+
|
| 491 |
+
const int j = jc / ncols2;
|
| 492 |
+
const int c = jc % ncols2;
|
| 493 |
|
| 494 |
+
if (jt*ncols1 + j < ne01) {
|
| 495 |
#pragma unroll
|
| 496 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 497 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 498 |
|
| 499 |
+
const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
|
| 500 |
+
tile_K[jc*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
|
| 501 |
}
|
| 502 |
} else {
|
| 503 |
#pragma unroll
|
| 504 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 505 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
| 506 |
|
| 507 |
+
tile_K[jc*D2_padded + k] = make_half2(0.0f, 0.0f);
|
| 508 |
}
|
| 509 |
}
|
| 510 |
}
|
|
|
|
| 513 |
__syncthreads();
|
| 514 |
|
| 515 |
{
|
| 516 |
+
const int j0 = (threadIdx.y / np) * cols_per_warp;
|
| 517 |
|
| 518 |
#pragma unroll
|
| 519 |
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
|
| 520 |
+
if (ntiles == 1) {
|
| 521 |
+
load_ldmatrix(Q_B[k0/tile_B::J], tile_K + j0*D2_padded + k0, D2_padded);
|
| 522 |
+
} else {
|
| 523 |
+
#pragma unroll
|
| 524 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 525 |
+
load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
|
| 526 |
+
tile_K + (j0 + t*tile_B_16::I)*D2_padded + k0, D2_padded);
|
| 527 |
+
}
|
| 528 |
+
}
|
| 529 |
}
|
| 530 |
}
|
| 531 |
|
| 532 |
__syncthreads();
|
| 533 |
|
| 534 |
+
// Preload mask and K data for first iteration when using cp_async:
|
| 535 |
#ifdef CP_ASYNC_AVAILABLE
|
| 536 |
+
if (ncols2 > 1 || mask_h2) {
|
| 537 |
+
flash_attn_ext_f16_load_mask<ncols1, nwarps, KQ_per_iter>(mask_h2 + kb0_start*KQ_per_iter/2, tile_mask, stride_mask);
|
| 538 |
+
}
|
| 539 |
+
flash_attn_ext_f16_load_tile<D, nwarps, KQ_per_iter>(K_h2 + kb0_start*KQ_per_iter*stride_KV, tile_K, stride_KV);
|
| 540 |
#endif // CP_ASYNC_AVAILABLE
|
| 541 |
|
| 542 |
// Iterate over ne11 == previous tokens:
|
| 543 |
for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
|
| 544 |
constexpr bool last_iter = false;
|
| 545 |
+
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 546 |
+
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 547 |
+
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
|
| 548 |
}
|
| 549 |
{ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
|
| 550 |
constexpr bool last_iter = true;
|
| 551 |
+
flash_attn_ext_f16_iter<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup, last_iter>
|
| 552 |
+
(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
|
| 553 |
+
ne01, ne02, stride_KV, stride_mask, jt, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
|
| 554 |
}
|
| 555 |
|
| 556 |
// With cp_async there is no __syncthreads at the end of the iter,
|
| 557 |
// there can be a race condition on shared memory access for combining/writing back results.
|
| 558 |
#ifdef CP_ASYNC_AVAILABLE
|
| 559 |
+
if (nwarps*cols_per_warp > KQ_per_iter) {
|
| 560 |
__syncthreads();
|
| 561 |
}
|
| 562 |
#endif // CP_ASYNC_AVAILABLE
|
| 563 |
|
| 564 |
// Finally, sum up partial KQ rowsums.
|
| 565 |
+
// The partial sums are spread across 8/4 threads each, so a full warp reduction is not needed.
|
| 566 |
+
{
|
| 567 |
+
constexpr int offset_first = ntiles == 1 ? 16 : 2;
|
| 568 |
+
constexpr int offset_last = ntiles == 1 ? 4 : 1;
|
| 569 |
+
#pragma unroll
|
| 570 |
+
for (int col = 0; col < cols_per_thread; ++col) {
|
| 571 |
#pragma unroll
|
| 572 |
+
for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
|
| 573 |
+
KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
|
| 574 |
+
}
|
| 575 |
+
}
|
| 576 |
}
|
| 577 |
|
| 578 |
// Write VKQ accumulators to shared memory in column-major format.
|
| 579 |
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
| 580 |
// Also for np > 1 the combination is done via these values in shared memory.
|
| 581 |
+
if (ntiles == 1) {
|
| 582 |
+
const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
|
| 583 |
#pragma unroll
|
| 584 |
+
for (int k0 = 0; k0 < D/2; k0 += tile_B::J) {
|
| 585 |
+
const tile_B B = get_transposed(VKQ_C[k0/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
|
| 586 |
|
| 587 |
#pragma unroll
|
| 588 |
+
for (int l = 0; l < tile_B::ne; ++l) {
|
| 589 |
+
const int k = k0 + tile_B::get_j(l);
|
| 590 |
|
| 591 |
+
tile_K[jc_cwd*D2_padded + k] = B.x[l];
|
| 592 |
+
}
|
| 593 |
+
}
|
| 594 |
+
} else {
|
| 595 |
+
#pragma unroll
|
| 596 |
+
for (int t = 0; t < ntiles/2; ++t) {
|
| 597 |
+
const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
|
| 598 |
+
#pragma unroll
|
| 599 |
+
for (int k0 = 0; k0 < D/2; k0 += tile_C_VKQ_16::J) {
|
| 600 |
+
#pragma unroll
|
| 601 |
+
for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
|
| 602 |
+
const int j = j0 + tile_C_VKQ_16::get_i(l);
|
| 603 |
+
const int k = k0 + tile_C_VKQ_16::get_j(l);
|
| 604 |
+
|
| 605 |
+
tile_K[j*D2_padded + k] = VKQ_C_16[k0/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
|
| 606 |
+
}
|
| 607 |
+
}
|
| 608 |
}
|
| 609 |
}
|
| 610 |
|
| 611 |
+
if constexpr (ntiles == 1) {
|
| 612 |
+
const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
|
| 613 |
+
const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
|
| 614 |
+
const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
|
| 615 |
|
| 616 |
+
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
|
| 617 |
+
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 618 |
+
((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
|
| 619 |
+
}
|
| 620 |
|
| 621 |
+
__syncthreads();
|
| 622 |
|
| 623 |
+
if (np == 1) {
|
| 624 |
+
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
| 625 |
+
if (needs_fixup && threadIdx.x < tile_B::I) {
|
| 626 |
+
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 627 |
+
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
| 628 |
+
}
|
| 629 |
+
if (is_fixup && threadIdx.x < tile_B::I) {
|
| 630 |
+
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 631 |
+
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
| 632 |
+
}
|
| 633 |
}
|
| 634 |
+
} else {
|
| 635 |
+
static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
|
| 636 |
+
const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
|
| 637 |
+
+ (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
|
| 638 |
+
+ tile_C_VKQ_16::get_i(threadIdx.x % 4);
|
| 639 |
+
const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
|
| 640 |
+
|
| 641 |
+
if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
|
| 642 |
+
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
|
| 643 |
+
((float2 *) tile_K)[jc_cwm*(D2_padded/2) + D/4] = KQ_cmr;
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
__syncthreads();
|
| 647 |
+
|
| 648 |
+
if (np == 1) {
|
| 649 |
+
// No combination is needed, the meta data can be directly written from registers to VRAM.
|
| 650 |
+
if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
|
| 651 |
+
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 652 |
+
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
| 653 |
+
}
|
| 654 |
+
if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
|
| 655 |
+
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 656 |
+
dstk_fixup_meta[jc_cwm] = KQ_cmr;
|
| 657 |
+
}
|
| 658 |
}
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
|
| 662 |
+
if (np > 1 && threadIdx.y % np == 0) {
|
| 663 |
// Combine the meta data for parallel warps via shared memory.
|
| 664 |
// Warps with threadIdx.y % np != 0 must NOT return early.
|
| 665 |
// All threads must return simultaneously to avoid race conditions with work on the next tile.
|
| 666 |
|
| 667 |
+
constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
|
| 668 |
|
| 669 |
+
const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
|
| 670 |
+
float2 * const meta_ptr = ((float2 *) tile_K) + jc_meta*(D2_padded/2) + D/4;
|
| 671 |
+
float2 meta[nmeta];
|
| 672 |
+
#pragma unroll
|
| 673 |
+
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 674 |
+
meta[imeta] = meta_ptr[imeta * WARP_SIZE * D2_padded/2];
|
| 675 |
}
|
| 676 |
|
| 677 |
+
float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
|
| 678 |
+
#pragma unroll
|
| 679 |
+
for (int imeta = 1; imeta < nmeta; ++imeta) {
|
| 680 |
+
KQ_cmn = fmaxf(KQ_cmn, meta[imeta].x);
|
| 681 |
+
}
|
| 682 |
#pragma unroll
|
| 683 |
+
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 684 |
+
if (offset >= WARP_SIZE) {
|
| 685 |
+
continue;
|
| 686 |
+
}
|
| 687 |
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
|
| 688 |
}
|
| 689 |
|
| 690 |
+
float KQ_cms[nmeta]; // KQ combine max scale per warp.
|
| 691 |
+
#pragma unroll
|
| 692 |
+
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 693 |
+
KQ_cms[imeta] = expf(meta[imeta].x - KQ_cmn);
|
| 694 |
}
|
| 695 |
+
|
| 696 |
+
float KQ_crs = KQ_cms[0]*meta[0].y; // KQ combine rowsum, scaled sum of all parallel warps.
|
| 697 |
#pragma unroll
|
| 698 |
+
for (int imeta = 1; imeta < nmeta; ++imeta) {
|
| 699 |
+
KQ_crs += KQ_cms[imeta]*meta[imeta].y;
|
| 700 |
+
}
|
| 701 |
+
#pragma unroll
|
| 702 |
+
for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
|
| 703 |
+
if (offset >= WARP_SIZE) {
|
| 704 |
+
continue;
|
| 705 |
+
}
|
| 706 |
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
|
| 707 |
}
|
| 708 |
|
| 709 |
// Write back combined meta data:
|
| 710 |
+
#pragma unroll
|
| 711 |
+
for (int imeta = 0; imeta < nmeta; ++imeta) {
|
| 712 |
+
if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
|
| 713 |
+
// Combined KQ max scale + rowsum.
|
| 714 |
+
meta_ptr[imeta * WARP_SIZE * D2_padded/2] = make_float2(KQ_cms[imeta], KQ_crs);
|
| 715 |
+
}
|
| 716 |
}
|
| 717 |
+
|
| 718 |
+
// Combined KQ max + rowsum.
|
| 719 |
+
static_assert(cols_per_warp <= WARP_SIZE);
|
| 720 |
+
if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
| 721 |
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
|
| 722 |
+
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
| 723 |
}
|
| 724 |
+
if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
|
| 725 |
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
|
| 726 |
+
dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
|
| 727 |
}
|
| 728 |
}
|
| 729 |
|
|
|
|
| 738 |
|
| 739 |
#pragma unroll
|
| 740 |
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
|
| 741 |
+
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
|
| 742 |
+
const int k0_stop = D/2 - (D/2) % (1*stride_k);
|
| 743 |
+
const int stride_jc = WARP_SIZE / stride_k;
|
| 744 |
|
| 745 |
if (k0_start == k0_stop) {
|
| 746 |
continue;
|
| 747 |
}
|
| 748 |
|
| 749 |
#pragma unroll
|
| 750 |
+
for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
|
| 751 |
+
const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
|
| 752 |
+
|
| 753 |
+
if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
|
| 754 |
+
break;
|
| 755 |
+
}
|
| 756 |
+
|
| 757 |
+
const int jc_tile_K = (jc_dst/cols_per_warp)*(np*cols_per_warp) + jc_dst % cols_per_warp;
|
| 758 |
+
|
| 759 |
+
const int j_dst = jc_dst / ncols2;
|
| 760 |
+
const int c_dst = jc_dst % ncols2;
|
| 761 |
|
| 762 |
+
if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
|
| 763 |
continue;
|
| 764 |
}
|
| 765 |
+
|
| 766 |
+
const float * meta_j = (const float *) tile_K + jc_tile_K*D2_padded + D/2;
|
| 767 |
#pragma unroll
|
| 768 |
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
|
| 769 |
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
|
|
|
|
| 771 |
float2 dstk_val = make_float2(0.0f, 0.0f);
|
| 772 |
#pragma unroll
|
| 773 |
for (int ip = 0; ip < np; ++ip) {
|
| 774 |
+
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*cols_per_warp * D2_padded + 0];
|
| 775 |
+
const float2 dstk_val_add = __half22float2(tile_K[(jc_tile_K + ip*cols_per_warp) * D2_padded + k]);
|
| 776 |
dstk_val.x += dstk_val_add.x*KQ_crs;
|
| 777 |
dstk_val.y += dstk_val_add.y*KQ_crs;
|
| 778 |
}
|
|
|
|
| 784 |
}
|
| 785 |
|
| 786 |
if (is_fixup) {
|
| 787 |
+
dstk_fixup_data[jc_dst*(D/2) + k] = dstk_val;
|
| 788 |
} else {
|
| 789 |
+
dstk[((jt*ncols1 + j_dst)*ne02 + c_dst)*(D/2) + k] = dstk_val;
|
| 790 |
}
|
| 791 |
}
|
| 792 |
}
|
|
|
|
| 801 |
#endif // NEW_MMA_AVAILABLE
|
| 802 |
}
|
| 803 |
|
| 804 |
+
template<int D, int ncols1, int ncols2, int nwarps, int KQ_per_iter, int ntiles, bool use_logit_softcap>
|
|
|
|
| 805 |
__launch_bounds__(nwarps*WARP_SIZE, 2)
|
|
|
|
| 806 |
static __global__ void flash_attn_ext_f16(
|
| 807 |
const char * __restrict__ Q,
|
| 808 |
const char * __restrict__ K,
|
|
|
|
| 850 |
return;
|
| 851 |
}
|
| 852 |
|
| 853 |
+
static_assert(FATTN_KQ_STRIDE % KQ_per_iter == 0, "bad KQ_per_iter");
|
| 854 |
|
| 855 |
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
| 856 |
|
| 857 |
+
const int stride_Q1 = nb01 / sizeof(float2);
|
| 858 |
+
const int stride_Q2 = nb02 / sizeof(float2);
|
| 859 |
const int stride_KV = nb11 / sizeof(half2);
|
| 860 |
+
const int stride_mask = nb31 / sizeof(half2);
|
| 861 |
+
|
| 862 |
+
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
| 863 |
+
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
| 864 |
|
| 865 |
+
constexpr int kb_niter = FATTN_KQ_STRIDE / KQ_per_iter; // Number of kernel iterations per assigned KQ slice.
|
|
|
|
| 866 |
|
| 867 |
// kbc == k block continuous, current index in continuous ijk space.
|
| 868 |
+
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
| 869 |
+
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
| 870 |
|
| 871 |
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
| 872 |
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
|
|
|
| 879 |
const int channel = kbc / (iter_k*iter_j);
|
| 880 |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 881 |
|
| 882 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 883 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 884 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
| 885 |
+
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 886 |
+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
| 887 |
|
| 888 |
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 889 |
+
|
| 890 |
+
const int kb0_start_kernel = kb0_start * kb_niter;
|
| 891 |
+
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
| 892 |
|
| 893 |
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
|
| 894 |
if (kb0_start == 0) {
|
| 895 |
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
| 896 |
+
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 897 |
+
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 898 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 899 |
} else {
|
| 900 |
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
| 901 |
+
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 902 |
+
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 903 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 904 |
}
|
| 905 |
|
| 906 |
kbc += iter_k;
|
|
|
|
| 917 |
const int channel = kbc / (iter_k*iter_j);
|
| 918 |
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
| 919 |
|
| 920 |
+
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
| 921 |
+
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
| 922 |
+
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
| 923 |
+
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
| 924 |
+
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
| 925 |
+
|
| 926 |
+
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
| 927 |
|
| 928 |
+
const int kb0_start_kernel = kb0_start * kb_niter;
|
| 929 |
+
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
| 930 |
|
| 931 |
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
| 932 |
constexpr bool needs_fixup = false;
|
| 933 |
+
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
| 934 |
+
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
| 935 |
+
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
| 936 |
}
|
| 937 |
|
| 938 |
+
template <int D, int ncols1, int ncols2>
|
| 939 |
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 940 |
+
constexpr int ncols = ncols1 * ncols2;
|
| 941 |
+
constexpr int KQ_per_iter = D <= 128 && ncols1 <= 64 ? 64 : 32;
|
| 942 |
+
constexpr int nwarps = (KQ_per_iter == 32 && ncols <= 16) ? 2 : 4;
|
| 943 |
+
constexpr int ntiles = ncols <= 8 ? 1 : (ncols <= 64 ? 2 : 4);
|
| 944 |
+
constexpr int cols_per_warp = ntiles * tile_B::I;
|
| 945 |
|
| 946 |
+
static_assert(D % tile_B::J == 0, "bad D");
|
| 947 |
+
static_assert(ncols % cols_per_warp == 0, "bad ncols");
|
| 948 |
|
| 949 |
const ggml_tensor * KQV = dst;
|
| 950 |
+
const int id = ggml_cuda_get_device();
|
| 951 |
+
const int cc = ggml_cuda_info().devices[id].cc;
|
| 952 |
+
|
| 953 |
+
const int KQ_shared_rows = cp_async_available(cc) ? 2*KQ_per_iter : KQ_per_iter;
|
| 954 |
|
| 955 |
+
const size_t nbytes_shared_KV = KQ_shared_rows * (D + 8) * sizeof(half);
|
| 956 |
+
const size_t nbytes_shared_mask = ncols1 * (KQ_per_iter + 8) * sizeof(half);
|
| 957 |
+
const size_t nbytes_shared_combine = nwarps*cols_per_warp * (D + 8) * sizeof(half);
|
| 958 |
|
| 959 |
+
const size_t nbytes_shared_total = std::max(nbytes_shared_KV + nbytes_shared_mask, nbytes_shared_combine);
|
| 960 |
|
| 961 |
float logit_softcap;
|
| 962 |
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
|
|
|
|
| 964 |
fattn_kernel_t fattn_kernel;
|
| 965 |
if (logit_softcap == 0.0f) {
|
| 966 |
constexpr bool use_logit_softcap = false;
|
| 967 |
+
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
| 968 |
} else {
|
| 969 |
constexpr bool use_logit_softcap = true;
|
| 970 |
+
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
|
| 971 |
}
|
| 972 |
+
|
| 973 |
+
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
|
| 974 |
}
|
| 975 |
|
| 976 |
+
|
| 977 |
+
#define DECL_FATTN_MMA_F16_CASE(D, ncols1, ncols2) \
|
| 978 |
template void ggml_cuda_flash_attn_ext_mma_f16_case \
|
| 979 |
+
<D, ncols1, ncols2>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
|
| 980 |
+
|
| 981 |
+
#define DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(D, ncols) \
|
| 982 |
+
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/1, 1); \
|
| 983 |
+
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/2, 2); \
|
| 984 |
+
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
|
| 985 |
+
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
|
| 986 |
+
|
| 987 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
|
| 988 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
|
| 989 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
|
| 990 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
|
| 991 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
|
| 992 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
|
| 993 |
+
|
| 994 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
|
| 995 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
|
| 996 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
|
| 997 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
|
| 998 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
|
| 999 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
|
| 1000 |
+
|
| 1001 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
|
| 1002 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
|
| 1003 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
|
| 1004 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
|
| 1005 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
|
| 1006 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
|
| 1007 |
+
|
| 1008 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
|
| 1009 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
|
| 1010 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
|
| 1011 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
|
| 1012 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
|
| 1013 |
+
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
|
| 1014 |
+
|
| 1015 |
+
// Kernels with ncols == 128 are only 4% faster due to register pressure.
|
| 1016 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
|
| 1017 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
|
| 1018 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
|
| 1019 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
|
| 1020 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
|
| 1021 |
+
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
|
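
For orientation, the stream-k style work split used by flash_attn_ext_f16 above can be reproduced on the host. The snippet below is a minimal standalone C++ sketch, not part of the commit: the tensor shapes, grid size and FATTN_KQ_STRIDE value are assumed example numbers, while the channel/jt/kb0 index formulas mirror the kernel's kbc decomposition.

// Minimal host-side sketch (not from the commit) of the per-block work assignment above.
// ne01/ne02/ne11, ncols1/ncols2, FATTN_KQ_STRIDE and the grid size are assumed examples.
#include <cstdio>

int main() {
    const int ne01 = 100;              // number of Q rows (tokens in the batch)
    const int ne02 = 32;               // number of Q heads
    const int ne11 = 4096;             // number of K/V rows (context)
    const int ncols1 = 8, ncols2 = 4;  // Q columns per tile: rows x heads sharing one K/V head
    const int FATTN_KQ_STRIDE = 256;
    const int grid = 132;              // e.g. one CUDA block per SM

    const int iter_k = ne11 / FATTN_KQ_STRIDE;
    const int iter_j = (ne01 + ncols1 - 1) / ncols1;

    for (int b = 0; b < grid; ++b) {
        // kbc == k block continuous, same formula as in the kernel:
        const int kbc      = (b + 0)*iter_k*iter_j*(ne02/ncols2) / grid;
        const int kbc_stop = (b + 1)*iter_k*iter_j*(ne02/ncols2) / grid;

        const int channel = kbc / (iter_k*iter_j);                  // which group of ncols2 heads
        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // which tile of ncols1 Q rows
        const int kb0     = kbc % iter_k;                           // first K/V slice inside that tile
        printf("block %3d: kbc [%5d, %5d)  starts at channel %2d, jt %2d, kb0 %2d%s\n",
               b, kbc, kbc_stop, channel, jt, kb0,
               kb0 != 0 ? "  (needs_fixup: another block started this tile)" : "");
    }
    return 0;
}
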
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED
|
@@ -302,14 +302,14 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 302 |
constexpr int nwarps = 8;
|
| 303 |
constexpr size_t nbytes_shared = 0;
|
| 304 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 305 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 306 |
} break;
|
| 307 |
case 128: {
|
| 308 |
constexpr int D = 128;
|
| 309 |
constexpr int nwarps = 8;
|
| 310 |
constexpr size_t nbytes_shared = 0;
|
| 311 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 312 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 313 |
} break;
|
| 314 |
default: {
|
| 315 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
|
|
| 302 |
constexpr int nwarps = 8;
|
| 303 |
constexpr size_t nbytes_shared = 0;
|
| 304 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 305 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 306 |
} break;
|
| 307 |
case 128: {
|
| 308 |
constexpr int D = 128;
|
| 309 |
constexpr int nwarps = 8;
|
| 310 |
constexpr size_t nbytes_shared = 0;
|
| 311 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 312 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 313 |
} break;
|
| 314 |
default: {
|
| 315 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED
|
@@ -296,14 +296,14 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
|
|
| 296 |
constexpr int nwarps = 8;
|
| 297 |
constexpr size_t nbytes_shared = 0;
|
| 298 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 299 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 300 |
} break;
|
| 301 |
case 128: {
|
| 302 |
constexpr int D = 128;
|
| 303 |
constexpr int nwarps = 8;
|
| 304 |
constexpr size_t nbytes_shared = 0;
|
| 305 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 306 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 307 |
} break;
|
| 308 |
default: {
|
| 309 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
|
|
|
| 296 |
constexpr int nwarps = 8;
|
| 297 |
constexpr size_t nbytes_shared = 0;
|
| 298 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 299 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 300 |
} break;
|
| 301 |
case 128: {
|
| 302 |
constexpr int D = 128;
|
| 303 |
constexpr int nwarps = 8;
|
| 304 |
constexpr size_t nbytes_shared = 0;
|
| 305 |
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
|
| 306 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
|
| 307 |
} break;
|
| 308 |
default: {
|
| 309 |
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
|
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
|
@@ -310,7 +310,7 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
|
|
| 310 |
constexpr bool need_f16_K = D != 128;
|
| 311 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 312 |
constexpr size_t nbytes_shared = 0;
|
| 313 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
| 314 |
}
|
| 315 |
|
| 316 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 310 |
constexpr bool need_f16_K = D != 128;
|
| 311 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 312 |
constexpr size_t nbytes_shared = 0;
|
| 313 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
| 314 |
}
|
| 315 |
|
| 316 |
template <int D, ggml_type type_K, ggml_type type_V>
|
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
|
@@ -290,7 +290,7 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
|
|
| 290 |
constexpr bool need_f16_K = D != 128;
|
| 291 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 292 |
constexpr size_t nbytes_shared = 0;
|
| 293 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
| 294 |
}
|
| 295 |
|
| 296 |
template <int D, ggml_type type_K, ggml_type type_V>
|
|
|
|
| 290 |
constexpr bool need_f16_K = D != 128;
|
| 291 |
constexpr bool need_f16_V = D != 128 && D != 64;
|
| 292 |
constexpr size_t nbytes_shared = 0;
|
| 293 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
|
| 294 |
}
|
| 295 |
|
| 296 |
template <int D, ggml_type type_K, ggml_type type_V>
|
ggml/src/ggml-cuda/fattn-wmma-f16.cu
CHANGED
|
@@ -478,7 +478,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|
| 478 |
fattn_kernel = flash_attn_ext_f16<
|
| 479 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 480 |
}
|
| 481 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 482 |
return;
|
| 483 |
}
|
| 484 |
if (2*blocks_num_pb1 < 2*nsm) {
|
|
@@ -493,7 +493,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|
| 493 |
fattn_kernel = flash_attn_ext_f16<
|
| 494 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 495 |
}
|
| 496 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 497 |
return;
|
| 498 |
}
|
| 499 |
constexpr int parallel_blocks = 1;
|
|
@@ -507,7 +507,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
|
|
| 507 |
fattn_kernel = flash_attn_ext_f16<
|
| 508 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 509 |
}
|
| 510 |
-
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 511 |
}
|
| 512 |
|
| 513 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
|
|
|
| 478 |
fattn_kernel = flash_attn_ext_f16<
|
| 479 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 480 |
}
|
| 481 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 482 |
return;
|
| 483 |
}
|
| 484 |
if (2*blocks_num_pb1 < 2*nsm) {
|
|
|
|
| 493 |
fattn_kernel = flash_attn_ext_f16<
|
| 494 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 495 |
}
|
| 496 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 497 |
return;
|
| 498 |
}
|
| 499 |
constexpr int parallel_blocks = 1;
|
|
|
|
| 507 |
fattn_kernel = flash_attn_ext_f16<
|
| 508 |
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
|
| 509 |
}
|
| 510 |
+
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
|
| 511 |
}
|
| 512 |
|
| 513 |
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
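
The launcher changes in fattn-tile-f16.cu, fattn-tile-f32.cu, fattn-vec-f16.cuh, fattn-vec-f32.cuh and fattn-wmma-f16.cu above are all the same mechanical edit: launch_fattn gains an ncols2 template parameter in third position, and these non-MMA paths pass 1 for it. The stub below only mirrors the parameter order inferred from those call sites; the real launch_fattn is defined in fattn-common.cuh and is not reproduced here, and the example argument values are illustrative assumptions.

// Stub mirroring the template-parameter order inferred from the call sites above;
// this is not the real launch_fattn from fattn-common.cuh.
#include <cstdio>

template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
static void launch_fattn_stub(const char * who) {
    printf("%s: launch_fattn<D=%d, ncols1=%d, ncols2=%d, parallel_blocks=%d, KQ_stride=%d>\n",
           who, D, ncols1, ncols2, parallel_blocks, KQ_stride);
}

int main() {
    launch_fattn_stub< 64, 32, 1, 4, -1>("tile/vec/wmma path (illustrative numbers, ncols2 is always 1)");
    launch_fattn_stub<128,  8, 8, 0, 64>("mma path (ncols1/ncols2 split from the GQA ratio)");
    return 0;
}
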
ggml/src/ggml-cuda/fattn.cu
CHANGED
|
@@ -8,28 +8,50 @@
|
|
| 8 |
#include "fattn-wmma-f16.cuh"
|
| 9 |
#include "fattn.cuh"
|
| 10 |
|
| 11 |
-
template <int
|
|
|
|
| 12 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 13 |
const ggml_tensor * Q = dst->src[0];
|
| 14 |
|
| 15 |
switch (Q->ne[0]) {
|
| 16 |
case 64:
|
| 17 |
-
|
| 18 |
break;
|
| 19 |
case 80:
|
| 20 |
-
|
| 21 |
break;
|
| 22 |
case 96:
|
| 23 |
-
|
| 24 |
break;
|
| 25 |
case 112:
|
| 26 |
-
|
| 27 |
break;
|
| 28 |
case 128:
|
| 29 |
-
|
| 30 |
break;
|
| 31 |
case 256:
|
| 32 |
-
|
| 33 |
break;
|
| 34 |
default:
|
| 35 |
GGML_ABORT("fatal error");
|
|
@@ -38,24 +60,35 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context
|
|
| 38 |
}
|
| 39 |
|
| 40 |
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 41 |
-
const ggml_tensor *
|
|
|
| 42 |
|
| 43 |
-
if (
|
| 44 |
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
| 45 |
return;
|
| 46 |
}
|
| 47 |
|
| 48 |
-
if (
|
| 49 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<
|
| 50 |
return;
|
| 51 |
}
|
| 52 |
|
| 53 |
-
if (
|
| 54 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<
|
| 55 |
return;
|
| 56 |
}
|
| 57 |
|
| 58 |
-
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<
|
| 59 |
}
|
| 60 |
|
| 61 |
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
|
@@ -209,8 +242,11 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
|
|
| 209 |
}
|
| 210 |
|
| 211 |
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 212 |
-
const ggml_tensor * KQV
|
| 213 |
-
const ggml_tensor * Q
|
|
|
| 214 |
|
| 215 |
ggml_cuda_set_device(ctx.device);
|
| 216 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
@@ -252,7 +288,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
|
| 252 |
return;
|
| 253 |
}
|
| 254 |
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
| 256 |
if (prec == GGML_PREC_DEFAULT) {
|
| 257 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 258 |
return;
|
|
|
|
| 8 |
#include "fattn-wmma-f16.cuh"
|
| 9 |
#include "fattn.cuh"
|
| 10 |
|
| 11 |
+
template <int D, int ncols2>
|
| 12 |
+
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 13 |
+
const ggml_tensor * Q = dst->src[0];
|
| 14 |
+
|
| 15 |
+
if (Q->ne[1] <= 8/ncols2) {
|
| 16 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
|
| 17 |
+
return;
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
if (Q->ne[1] <= 16/ncols2) {
|
| 21 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
|
| 22 |
+
return;
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
if (Q->ne[1] <= 32/ncols2) {
|
| 26 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
|
| 27 |
+
return;
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
template <int ncols2>
|
| 34 |
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 35 |
const ggml_tensor * Q = dst->src[0];
|
| 36 |
|
| 37 |
switch (Q->ne[0]) {
|
| 38 |
case 64:
|
| 39 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
|
| 40 |
break;
|
| 41 |
case 80:
|
| 42 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
|
| 43 |
break;
|
| 44 |
case 96:
|
| 45 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
|
| 46 |
break;
|
| 47 |
case 112:
|
| 48 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
|
| 49 |
break;
|
| 50 |
case 128:
|
| 51 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
|
| 52 |
break;
|
| 53 |
case 256:
|
| 54 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
|
| 55 |
break;
|
| 56 |
default:
|
| 57 |
GGML_ABORT("fatal error");
|
|
|
|
| 60 |
}
|
| 61 |
|
| 62 |
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 63 |
+
const ggml_tensor * KQV = dst;
|
| 64 |
+
const ggml_tensor * Q = dst->src[0];
|
| 65 |
+
const ggml_tensor * K = dst->src[1];
|
| 66 |
+
const ggml_tensor * mask = dst->src[3];
|
| 67 |
+
|
| 68 |
+
float max_bias = 0.0f;
|
| 69 |
+
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
|
| 70 |
+
|
| 71 |
+
const bool use_gqa_opt = mask && max_bias == 0.0f;
|
| 72 |
+
|
| 73 |
+
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
|
| 74 |
+
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
| 75 |
|
| 76 |
+
if (use_gqa_opt && gqa_ratio % 8 == 0) {
|
| 77 |
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
|
| 78 |
return;
|
| 79 |
}
|
| 80 |
|
| 81 |
+
if (use_gqa_opt && gqa_ratio == 4) {
|
| 82 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
|
| 83 |
return;
|
| 84 |
}
|
| 85 |
|
| 86 |
+
if (use_gqa_opt && gqa_ratio == 2) {
|
| 87 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
|
| 88 |
return;
|
| 89 |
}
|
| 90 |
|
| 91 |
+
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
|
| 92 |
}
|
| 93 |
|
| 94 |
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
|
|
|
|
| 242 |
}
|
| 243 |
|
| 244 |
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
| 245 |
+
const ggml_tensor * KQV = dst;
|
| 246 |
+
const ggml_tensor * Q = dst->src[0];
|
| 247 |
+
const ggml_tensor * K = dst->src[1];
|
| 248 |
+
const ggml_tensor * V = dst->src[2];
|
| 249 |
+
const ggml_tensor * mask = dst->src[3];
|
| 250 |
|
| 251 |
ggml_cuda_set_device(ctx.device);
|
| 252 |
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
|
|
|
|
| 288 |
return;
|
| 289 |
}
|
| 290 |
|
| 291 |
+
const int gqa_ratio = Q->ne[2] / K->ne[2];
|
| 292 |
+
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
|
| 293 |
+
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
|
| 294 |
+
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
|
| 295 |
if (prec == GGML_PREC_DEFAULT) {
|
| 296 |
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
|
| 297 |
return;
|
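
The dispatch added in fattn.cu above can be summarized as: pick ncols2 (Q heads handled together per K/V head) from the GQA ratio whenever a mask is present and ALiBi is off, then pick ncols1 from the batch size so that ncols1*ncols2 stays at 8/16/32/64 columns per kernel instance. The standalone sketch below restates that selection; the example head counts and batch size are assumptions, only the thresholds come from the code above.

// Sketch (illustrative only) of the ncols1/ncols2 selection performed by
// ggml_cuda_flash_attn_ext_mma_f16 and ..._switch_ncols1 above.
#include <cstdio>

static int pick_ncols2(bool has_mask, float max_bias, int gqa_ratio) {
    const bool use_gqa_opt = has_mask && max_bias == 0.0f;
    if (use_gqa_opt && gqa_ratio % 8 == 0) return 8;
    if (use_gqa_opt && gqa_ratio == 4)     return 4;
    if (use_gqa_opt && gqa_ratio == 2)     return 2;
    return 1;
}

static int pick_ncols1(int n_q_rows, int ncols2) {
    if (n_q_rows <=  8/ncols2) return  8/ncols2;
    if (n_q_rows <= 16/ncols2) return 16/ncols2;
    if (n_q_rows <= 32/ncols2) return 32/ncols2;
    return 64/ncols2;
}

int main() {
    // Example: 32 Q heads over 8 K/V heads (gqa_ratio = 4), batch of 1, mask present, no ALiBi.
    const int ncols2 = pick_ncols2(/*has_mask=*/true, /*max_bias=*/0.0f, /*gqa_ratio=*/4);
    const int ncols1 = pick_ncols1(/*n_q_rows=*/1, ncols2);
    printf("selected instance: ncols1=%d, ncols2=%d (%d columns per tile)\n",
           ncols1, ncols2, ncols1*ncols2);
    return 0;
}
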
ggml/src/ggml-cuda/mma.cuh
CHANGED
|
@@ -73,6 +73,8 @@ namespace ggml_cuda_mma {
|
|
| 73 |
return threadIdx.x / 4;
|
| 74 |
} else if constexpr (I == 16 && J == 8) {
|
| 75 |
return (l / 2) * 8 + threadIdx.x / 4;
|
| 76 |
} else {
|
| 77 |
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 78 |
}
|
|
@@ -85,6 +87,8 @@ namespace ggml_cuda_mma {
|
|
| 85 |
return 4 * l + threadIdx.x % 4;
|
| 86 |
} else if constexpr (I == 16 && J == 8) {
|
| 87 |
return 2 * (threadIdx.x % 4) + l % 2;
|
| 88 |
} else {
|
| 89 |
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 90 |
}
|
|
@@ -289,6 +293,42 @@ namespace ggml_cuda_mma {
|
|
| 289 |
#endif // NEW_MMA_AVAILABLE
|
| 290 |
}
|
| 291 |
|
|
|
| 292 |
static __device__ __forceinline__ void mma(
|
| 293 |
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
| 294 |
#ifdef NEW_MMA_AVAILABLE
|
|
@@ -316,4 +356,39 @@ namespace ggml_cuda_mma {
|
|
| 316 |
#endif // NEW_MMA_AVAILABLE
|
| 317 |
}
|
| 318 |
|
|
|
| 319 |
}
|
|
|
|
| 73 |
return threadIdx.x / 4;
|
| 74 |
} else if constexpr (I == 16 && J == 8) {
|
| 75 |
return (l / 2) * 8 + threadIdx.x / 4;
|
| 76 |
+
} else if constexpr (I == 16 && J == 16) {
|
| 77 |
+
return ((l / 2) % 2) * 8 + threadIdx.x / 4;
|
| 78 |
} else {
|
| 79 |
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 80 |
}
|
|
|
|
| 87 |
return 4 * l + threadIdx.x % 4;
|
| 88 |
} else if constexpr (I == 16 && J == 8) {
|
| 89 |
return 2 * (threadIdx.x % 4) + l % 2;
|
| 90 |
+
} else if constexpr (I == 16 && J == 16) {
|
| 91 |
+
return 8 * (l / 4) + 2 * (threadIdx.x % 4) + l % 2;
|
| 92 |
} else {
|
| 93 |
static_assert(I == -1 && J == -1, "template specialization not implemented");
|
| 94 |
}
|
|
|
|
| 293 |
#endif // NEW_MMA_AVAILABLE
|
| 294 |
}
|
| 295 |
|
| 296 |
+
static __device__ __forceinline__ void mma(
|
| 297 |
+
tile<16, 8, half2> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
| 298 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 299 |
+
const int * Axi = (const int *) A.x;
|
| 300 |
+
const int * Bxi = (const int *) B.x;
|
| 301 |
+
int * Dxi = (int *) D.x;
|
| 302 |
+
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 303 |
+
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
| 304 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 305 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
| 306 |
+
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
|
| 307 |
+
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 308 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
| 309 |
+
#else
|
| 310 |
+
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
| 311 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 312 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 313 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 314 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 315 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1])
|
| 316 |
+
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
| 317 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 318 |
+
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 319 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
| 320 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
|
| 321 |
+
: "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 322 |
+
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
| 323 |
+
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 324 |
+
#else
|
| 325 |
+
GGML_UNUSED(D);
|
| 326 |
+
GGML_UNUSED(A);
|
| 327 |
+
GGML_UNUSED(B);
|
| 328 |
+
NO_DEVICE_CODE;
|
| 329 |
+
#endif // NEW_MMA_AVAILABLE
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
static __device__ __forceinline__ void mma(
|
| 333 |
tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
|
| 334 |
#ifdef NEW_MMA_AVAILABLE
|
|
|
|
| 356 |
#endif // NEW_MMA_AVAILABLE
|
| 357 |
}
|
| 358 |
|
| 359 |
+
static __device__ __forceinline__ void mma(
|
| 360 |
+
tile<16, 16, float> & D, const tile<16, 8, half2> & A, const tile<16, 8, half2> & B) {
|
| 361 |
+
#ifdef NEW_MMA_AVAILABLE
|
| 362 |
+
const int * Axi = (const int *) A.x;
|
| 363 |
+
const int * Bxi = (const int *) B.x;
|
| 364 |
+
int * Dxi = (int *) D.x;
|
| 365 |
+
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 366 |
+
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 367 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 368 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[2]));
|
| 369 |
+
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
| 370 |
+
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
| 371 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]), "r"(Bxi[3]));
|
| 372 |
+
#else
|
| 373 |
+
// On Turing m16n8k16 mma is not available, use 4x m8n8k8 mma instead:
|
| 374 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 375 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 376 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
|
| 377 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 378 |
+
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3])
|
| 379 |
+
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[2]));
|
| 380 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 381 |
+
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
| 382 |
+
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[1]));
|
| 383 |
+
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
|
| 384 |
+
: "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
|
| 385 |
+
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
|
| 386 |
+
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
|
| 387 |
+
#else
|
| 388 |
+
GGML_UNUSED(D);
|
| 389 |
+
GGML_UNUSED(A);
|
| 390 |
+
GGML_UNUSED(B);
|
| 391 |
+
NO_DEVICE_CODE;
|
| 392 |
+
#endif // NEW_MMA_AVAILABLE
|
| 393 |
+
}
|
| 394 |
}
|
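
The new tile<16, 16> index helpers added to mma.cuh above can be sanity-checked on the host: across the 32 threads of a warp, each (threadIdx.x, l) pair must own exactly one element of the 16x16 accumulator tile. The check below is a sketch; only the two return expressions are taken from the code above, everything else (tile::ne == 8, the warp size of 32) is stated as an assumption.

// Host-side check that the I==16, J==16 mapping covers the tile exactly once (illustrative).
#include <cstdio>

int main() {
    int hits[16][16] = {};
    for (int tid = 0; tid < 32; ++tid) {       // threadIdx.x within the warp
        for (int l = 0; l < 8; ++l) {          // per-thread element index (assumed tile::ne == 8)
            const int i = ((l / 2) % 2) * 8 + tid / 4;          // get_i(l) for tile<16, 16>
            const int j = 8 * (l / 4) + 2 * (tid % 4) + l % 2;  // get_j(l) for tile<16, 16>
            hits[i][j]++;
        }
    }
    int bad = 0;
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            bad += hits[i][j] != 1;
        }
    }
    printf("%s\n", bad == 0 ? "every element is owned by exactly one (thread, l) pair" : "mapping error");
    return 0;
}
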
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb16.cu → fattn-mma-f16-instance-ncols1_1-ncols2_8.cu}
RENAMED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64,
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80,
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96,
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112,
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128,
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256,
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 1, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 1, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 1, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 1, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 1, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 1, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_1.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 16, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 16, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 16, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 16, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 16, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 16, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_2.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 16, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 16, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 16, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 16, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 16, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 16, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 16, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 16, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 16, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 16, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 16, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 16, 4);
|
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb32.cu → fattn-mma-f16-instance-ncols1_2-ncols2_4.cu}
RENAMED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64,
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80,
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96,
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112,
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128,
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256,
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 2, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 2, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 2, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 2, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 2, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 2, 4);
|
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb64.cu → fattn-mma-f16-instance-ncols1_2-ncols2_8.cu}
RENAMED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64,
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80,
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96,
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112,
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128,
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256,
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 2, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 2, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 2, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 2, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 2, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 2, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_1.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 32, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 32, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 32, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 32, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 32, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 32, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_32-ncols2_2.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 32, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 32, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 32, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 32, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 32, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 32, 2);
|
ggml/src/ggml-cuda/template-instances/{fattn-mma-f16-instance-cpb8.cu → fattn-mma-f16-instance-ncols1_4-ncols2_2.cu}
RENAMED
|
@@ -2,9 +2,9 @@
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
-
DECL_FATTN_MMA_F16_CASE(64,
|
| 6 |
-
DECL_FATTN_MMA_F16_CASE(80,
|
| 7 |
-
DECL_FATTN_MMA_F16_CASE(96,
|
| 8 |
-
DECL_FATTN_MMA_F16_CASE(112,
|
| 9 |
-
DECL_FATTN_MMA_F16_CASE(128,
|
| 10 |
-
DECL_FATTN_MMA_F16_CASE(256,
|
|
|
|
| 2 |
|
| 3 |
#include "../fattn-mma-f16.cuh"
|
| 4 |
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 4, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 4, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 4, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 4, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 4, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 4, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 4, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 4, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 4, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 4, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 4, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 4, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 4, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 4, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 4, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 4, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 4, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 4, 8);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_64-ncols2_1.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 64, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 64, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 64, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 64, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 64, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 64, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_1.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 8, 1);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 8, 1);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 8, 1);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 8, 1);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 8, 1);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 8, 1);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_2.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 8, 2);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 8, 2);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 8, 2);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 8, 2);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 8, 2);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 8, 2);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 8, 4);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 8, 4);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 8, 4);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 8, 4);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 8, 4);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 8, 4);
|
ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu
ADDED
|
@@ -0,0 +1,10 @@
|
|
| 1 |
+
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
|
| 2 |
+
|
| 3 |
+
#include "../fattn-mma-f16.cuh"
|
| 4 |
+
|
| 5 |
+
DECL_FATTN_MMA_F16_CASE(64, 8, 8);
|
| 6 |
+
DECL_FATTN_MMA_F16_CASE(80, 8, 8);
|
| 7 |
+
DECL_FATTN_MMA_F16_CASE(96, 8, 8);
|
| 8 |
+
DECL_FATTN_MMA_F16_CASE(112, 8, 8);
|
| 9 |
+
DECL_FATTN_MMA_F16_CASE(128, 8, 8);
|
| 10 |
+
DECL_FATTN_MMA_F16_CASE(256, 8, 8);
|
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
CHANGED
|
@@ -18,7 +18,7 @@ SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_f
|
|
| 18 |
|
| 19 |
"""
|
| 20 |
|
| 21 |
-
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {
|
| 22 |
|
| 23 |
TYPES_MMQ = [
|
| 24 |
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
|
@@ -57,12 +57,18 @@ for vkq_size in [16, 32]:
|
|
| 57 |
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
| 58 |
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
| 59 |
|
| 60 |
-
for
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
|
|
| 66 |
|
| 67 |
for type in TYPES_MMQ:
|
| 68 |
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|
|
|
|
| 18 |
|
| 19 |
"""
|
| 20 |
|
| 21 |
+
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {ncols1}, {ncols2});\n"
|
| 22 |
|
| 23 |
TYPES_MMQ = [
|
| 24 |
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
|
|
|
|
| 57 |
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
|
| 58 |
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
|
| 59 |
|
| 60 |
+
for ncols in [8, 16, 32, 64, 128]:
|
| 61 |
+
for ncols2 in [1, 2, 4, 8]:
|
| 62 |
+
ncols1 = ncols // ncols2
|
| 63 |
+
if ncols == 128:
|
| 64 |
+
continue # Too much register pressure.
|
| 65 |
+
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
|
| 66 |
+
f.write(SOURCE_FATTN_MMA_START)
|
| 67 |
+
|
| 68 |
+
for head_size in [64, 80, 96, 112, 128, 256]:
|
| 69 |
+
if ncols == 128 and head_size == 256:
|
| 70 |
+
continue # Needs too much shared memory.
|
| 71 |
+
f.write(SOURCE_FATTN_MMA_CASE.format(ncols1=ncols1, ncols2=ncols2, head_size=head_size))
|
| 72 |
|
| 73 |
for type in TYPES_MMQ:
|
| 74 |
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
|