mrfatso committed on
Commit
8ece1ee
·
1 Parent(s): 1a0281c

OpenCL: add initial FA support (llama/14987)

Browse files

* add F16/F16 fa support

* fix kernel init

* use mad instead of fma

* use inline function

* mark FA with sinks as unsupported for now

* add pragma unroll to loops

ggml/src/ggml-opencl/CMakeLists.txt CHANGED
@@ -112,6 +112,9 @@ set(GGML_OPENCL_KERNELS
112
  mul_mat_f16_f32
113
  conv2d
114
  conv2d_f16_f32
 
 
 
115
  )
116
 
117
  foreach (K ${GGML_OPENCL_KERNELS})
 
112
  mul_mat_f16_f32
113
  conv2d
114
  conv2d_f16_f32
115
+ flash_attn_f32_f16
116
+ flash_attn_f16
117
+ flash_attn_f32
118
  )
119
 
120
  foreach (K ${GGML_OPENCL_KERNELS})
ggml/src/ggml-opencl/ggml-opencl.cpp CHANGED
@@ -25,6 +25,7 @@
25
  #include <vector>
26
  #include <string>
27
  #include <cmath>
 
28
  #include <memory>
29
  #include <charconv>
30
  #include <mutex>
@@ -424,6 +425,14 @@ struct ggml_backend_opencl_context {
424
  cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
425
  cl_kernel kernel_soft_max, kernel_soft_max_4;
426
  cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
 
 
 
 
 
 
 
 
427
  cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
428
  cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
429
  cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
@@ -1308,6 +1317,73 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
1308
  GGML_LOG_CONT(".");
1309
  }
1310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1311
  // argsort
1312
  {
1313
  #ifdef GGML_OPENCL_EMBED_KERNELS
@@ -2636,6 +2712,45 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
2636
  return op->src[0]->type == GGML_TYPE_F32;
2637
  case GGML_OP_SUM_ROWS:
2638
  return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2639
  default:
2640
  return false;
2641
  }
@@ -5451,6 +5566,133 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
5451
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
5452
  }
5453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5454
  static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5455
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5456
 
@@ -7607,6 +7849,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
7607
  }
7608
  func = ggml_cl_sum_rows;
7609
  break;
 
 
 
 
 
 
7610
  default:
7611
  return false;
7612
  }
 
25
  #include <vector>
26
  #include <string>
27
  #include <cmath>
28
+ #include <map>
29
  #include <memory>
30
  #include <charconv>
31
  #include <mutex>
 
425
  cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
426
  cl_kernel kernel_soft_max, kernel_soft_max_4;
427
  cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16;
428
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16;
429
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f16_q1;
430
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32;
431
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_q1;
432
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16;
433
+ std::map<std::pair<int, int>, cl_kernel> kernels_flash_attn_f32_f16_q1;
434
+ std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
435
+ std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
436
  cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
437
  cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
438
  cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
 
1317
  GGML_LOG_CONT(".");
1318
  }
1319
 
1320
+ // flash_attn
1321
+ {
1322
+ #ifdef GGML_OPENCL_EMBED_KERNELS
1323
+ const std::string kernel_src_f16 {
1324
+ #include "flash_attn_f16.cl.h"
1325
+ };
1326
+ const std::string kernel_src_f32 {
1327
+ #include "flash_attn_f32.cl.h"
1328
+ };
1329
+ const std::string kernel_src_f32_f16 {
1330
+ #include "flash_attn_f32_f16.cl.h"
1331
+ };
1332
+ #else
1333
+ const std::string kernel_src_f16 = read_file("flash_attn_f16.cl");
1334
+ const std::string kernel_src_f32 = read_file("flash_attn_f32.cl");
1335
+ const std::string kernel_src_f32_f16 = read_file("flash_attn_f32_f16.cl");
1336
+ #endif
1337
+
1338
+ if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
1339
+ const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
1340
+ { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
1341
+ {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
1342
+ {192, 192, 16, 16}, {256, 256, 16, 16},
1343
+ };
1344
+
1345
+ for (size_t i = 0; i < sizeof(fa_dims)/sizeof(fa_dims[0]); ++i) {
1346
+ const int dk = fa_dims[i].dk;
1347
+ const int dv = fa_dims[i].dv;
1348
+ const int bm = fa_dims[i].bm;
1349
+ const int bn = fa_dims[i].bn;
1350
+ std::string OPTS = compile_opts +
1351
+ " -D DK=" + std::to_string(dk) +
1352
+ " -D DV=" + std::to_string(dv) +
1353
+ " -D BLOCK_M=" + std::to_string(bm) +
1354
+ " -D BLOCK_N=" + std::to_string(bn);
1355
+
1356
+ cl_program prog_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16.c_str(), OPTS);
1357
+ cl_kernel k_f16, k_f16_q1;
1358
+ CL_CHECK((k_f16 = clCreateKernel(prog_f16, "flash_attn_f16", &err), err));
1359
+ CL_CHECK((k_f16_q1 = clCreateKernel(prog_f16, "flash_attn_f16_q1", &err), err));
1360
+ backend_ctx->kernels_flash_attn_f16[{dk, dv}] = k_f16;
1361
+ backend_ctx->kernels_flash_attn_f16_q1[{dk, dv}] = k_f16_q1;
1362
+ CL_CHECK(clReleaseProgram(prog_f16));
1363
+
1364
+ cl_program prog_f32 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32.c_str(), OPTS);
1365
+ cl_kernel k_f32, k_f32_q1;
1366
+ CL_CHECK((k_f32 = clCreateKernel(prog_f32, "flash_attn_f32", &err), err));
1367
+ CL_CHECK((k_f32_q1 = clCreateKernel(prog_f32, "flash_attn_f32_q1", &err), err));
1368
+ backend_ctx->kernels_flash_attn_f32[{dk, dv}] = k_f32;
1369
+ backend_ctx->kernels_flash_attn_f32_q1[{dk, dv}] = k_f32_q1;
1370
+ CL_CHECK(clReleaseProgram(prog_f32));
1371
+
1372
+ cl_program prog_f32_f16 = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f32_f16.c_str(), OPTS);
1373
+ cl_kernel k_f32_f16, k_f32_f16_q1;
1374
+ CL_CHECK((k_f32_f16 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16", &err), err));
1375
+ CL_CHECK((k_f32_f16_q1 = clCreateKernel(prog_f32_f16, "flash_attn_f32_f16_q1", &err), err));
1376
+ backend_ctx->kernels_flash_attn_f32_f16[{dk, dv}] = k_f32_f16;
1377
+ backend_ctx->kernels_flash_attn_f32_f16_q1[{dk, dv}] = k_f32_f16_q1;
1378
+ CL_CHECK(clReleaseProgram(prog_f32_f16));
1379
+
1380
+ backend_ctx->kernels_flash_attn_bm[{dk, dv}] = bm;
1381
+ backend_ctx->kernels_flash_attn_bn[{dk, dv}] = bn;
1382
+ }
1383
+ GGML_LOG_CONT(".");
1384
+ }
1385
+ }
1386
+
1387
  // argsort
1388
  {
1389
  #ifdef GGML_OPENCL_EMBED_KERNELS
 
2712
  return op->src[0]->type == GGML_TYPE_F32;
2713
  case GGML_OP_SUM_ROWS:
2714
  return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
2715
+ case GGML_OP_FLASH_ATTN_EXT:
2716
+ {
2717
+ if (op->src[4]) {
2718
+ return false;
2719
+ }
2720
+
2721
+ const ggml_tensor * q = op->src[0];
2722
+ const ggml_tensor * k = op->src[1];
2723
+ const ggml_tensor * v = op->src[2];
2724
+
2725
+ const int dk = q->ne[0];
2726
+ const int dv = v->ne[0];
2727
+
2728
+ const struct { int dk; int dv; } supported_dims[] = {
2729
+ { 64, 64}, { 80, 80}, { 96, 96},
2730
+ {112, 112}, {128, 128}, {192, 128},
2731
+ {192, 192}, {256, 256},
2732
+ };
2733
+
2734
+ bool dims_supported = false;
2735
+ for (size_t i = 0; i < sizeof(supported_dims)/sizeof(supported_dims[0]); ++i) {
2736
+ if (supported_dims[i].dk == dk && supported_dims[i].dv == dv) {
2737
+ dims_supported = true;
2738
+ break;
2739
+ }
2740
+ }
2741
+ if (!dims_supported) {
2742
+ return false;
2743
+ }
2744
+
2745
+ const bool is_f32_f32 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F32 &&
2746
+ v->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
2747
+ const bool is_f16_f16 = q->type == GGML_TYPE_F16 && k->type == GGML_TYPE_F16 &&
2748
+ v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16;
2749
+ const bool is_f32_f16 = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16 &&
2750
+ v->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F32;
2751
+
2752
+ return is_f32_f32 || is_f16_f16 || is_f32_f16;
2753
+ }
2754
  default:
2755
  return false;
2756
  }
 
5566
  backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, NULL, dst);
5567
  }
5568
 
5569
+ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
5570
+ const ggml_tensor * v = dst->src[2];
5571
+ const ggml_tensor * mask = dst->src[3];
5572
+ GGML_ASSERT(q->extra);
5573
+ GGML_ASSERT(k->extra);
5574
+ GGML_ASSERT(v->extra);
5575
+ GGML_ASSERT(dst->extra);
5576
+ if (mask) {
5577
+ GGML_ASSERT(mask->extra);
5578
+ }
5579
+
5580
+ ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5581
+
5582
+ const int n_q = q->ne[1];
5583
+ const int n_kv = k->ne[1];
5584
+ const int d_head_q = q->ne[0];
5585
+ const int d_head_v = v->ne[0];
5586
+ const int n_head = q->ne[2];
5587
+ const int n_head_kv = k->ne[2];
5588
+ const int n_batch = q->ne[3];
5589
+
5590
+ cl_kernel kernel = NULL;
5591
+
5592
+ const bool is_f16 = q->type == GGML_TYPE_F16;
5593
+ const bool is_mixed = q->type == GGML_TYPE_F32 && k->type == GGML_TYPE_F16;
5594
+ const std::pair<int, int> dk_dv = {d_head_q, d_head_v};
5595
+
5596
+ if (n_q == 1) {
5597
+ if (is_mixed) {
5598
+ kernel = backend_ctx->kernels_flash_attn_f32_f16_q1.at(dk_dv);
5599
+ } else if (is_f16) {
5600
+ kernel = backend_ctx->kernels_flash_attn_f16_q1.at(dk_dv);
5601
+ } else {
5602
+ kernel = backend_ctx->kernels_flash_attn_f32_q1.at(dk_dv);
5603
+ }
5604
+ } else {
5605
+ if (is_mixed) {
5606
+ kernel = backend_ctx->kernels_flash_attn_f32_f16.at(dk_dv);
5607
+ } else if (is_f16) {
5608
+ kernel = backend_ctx->kernels_flash_attn_f16.at(dk_dv);
5609
+ } else {
5610
+ kernel = backend_ctx->kernels_flash_attn_f32.at(dk_dv);
5611
+ }
5612
+ }
5613
+ GGML_ASSERT(kernel != NULL);
5614
+
5615
+ ggml_tensor_extra_cl * extra_q = (ggml_tensor_extra_cl *)q->extra;
5616
+ ggml_tensor_extra_cl * extra_k = (ggml_tensor_extra_cl *)k->extra;
5617
+ ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
5618
+ ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
5619
+ ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
5620
+
5621
+ cl_ulong offset_q = extra_q->offset + q->view_offs;
5622
+ cl_ulong offset_k = extra_k->offset + k->view_offs;
5623
+ cl_ulong offset_v = extra_v->offset + v->view_offs;
5624
+ cl_ulong offset_o = extra_o->offset + dst->view_offs;
5625
+ cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
5626
+ cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
5627
+
5628
+ const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
5629
+ const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
5630
+ const cl_ulong v_nb1 = v->nb[1], v_nb2 = v->nb[2], v_nb3 = v->nb[3];
5631
+ const cl_ulong o_nb1 = dst->nb[1], o_nb2 = dst->nb[2], o_nb3 = dst->nb[3];
5632
+ const cl_ulong mask_nb1 = mask ? mask->nb[1] : 0;
5633
+ const cl_ulong mask_nb2 = mask ? mask->nb[2] : 0;
5634
+ const cl_ulong mask_nb3 = mask ? mask->nb[3] : 0;
5635
+ const int mask_ne2 = mask ? mask->ne[2] : 0;
5636
+ const int mask_ne3 = mask ? mask->ne[3] : 0;
5637
+
5638
+ float scale, max_bias, logit_softcap;
5639
+ const float * params = (const float *)dst->op_params;
5640
+ scale = params[0];
5641
+ max_bias = params[1];
5642
+ logit_softcap = params[2];
5643
+
5644
+ const int is_causal = (mask == NULL && n_q > 1 && n_q == n_kv);
5645
+
5646
+ const int n_head_log2_val = n_head > 0 ? 1u << (int)floorf(log2f((float)n_head)) : 0;
5647
+ const float n_head_log2_f = n_head_log2_val > 0 ? (float)n_head_log2_val : 1.0f;
5648
+ const float m0 = powf(2.0f, -(max_bias) / n_head_log2_f);
5649
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2_f);
5650
+
5651
+ CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_q->data_device));
5652
+ CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset_q));
5653
+ CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_k->data_device));
5654
+ CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset_k));
5655
+ CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra_v->data_device));
5656
+ CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset_v));
5657
+ CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extra_o->data_device));
5658
+ CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offset_o));
5659
+ CL_CHECK(clSetKernelArg(kernel, 8, sizeof(float), &scale));
5660
+ CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &n_q));
5661
+ CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &n_kv));
5662
+ CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &is_causal));
5663
+ CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &n_head));
5664
+ CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &q_nb1)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &q_nb2)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &q_nb3));
5665
+ CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &k_nb1)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &k_nb2)); CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &k_nb3));
5666
+ CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &v_nb1)); CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &v_nb2)); CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &v_nb3));
5667
+ CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &o_nb1)); CL_CHECK(clSetKernelArg(kernel, 23, sizeof(cl_ulong), &o_nb2)); CL_CHECK(clSetKernelArg(kernel, 24, sizeof(cl_ulong), &o_nb3));
5668
+ CL_CHECK(clSetKernelArg(kernel, 25, sizeof(float), &max_bias));
5669
+ CL_CHECK(clSetKernelArg(kernel, 26, sizeof(float), &m0));
5670
+ CL_CHECK(clSetKernelArg(kernel, 27, sizeof(float), &m1));
5671
+ CL_CHECK(clSetKernelArg(kernel, 28, sizeof(int), &n_head_log2_val));
5672
+ CL_CHECK(clSetKernelArg(kernel, 29, sizeof(float), &logit_softcap));
5673
+ CL_CHECK(clSetKernelArg(kernel, 30, sizeof(int), &n_head_kv));
5674
+ CL_CHECK(clSetKernelArg(kernel, 31, sizeof(cl_mem), &mask_buffer));
5675
+ CL_CHECK(clSetKernelArg(kernel, 32, sizeof(cl_ulong), &offset_mask));
5676
+ CL_CHECK(clSetKernelArg(kernel, 33, sizeof(cl_ulong), &mask_nb1));
5677
+ CL_CHECK(clSetKernelArg(kernel, 34, sizeof(cl_ulong), &mask_nb2));
5678
+ CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
5679
+ CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
5680
+ CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
5681
+
5682
+ if (n_q == 1) {
5683
+ const size_t wg_size = 64;
5684
+ size_t local_work_size[] = { wg_size, 1 };
5685
+ size_t global_work_size[] = { wg_size, (size_t)(n_head * n_batch) };
5686
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5687
+ } else {
5688
+ const int block_m = backend_ctx->kernels_flash_attn_bm.at(dk_dv);
5689
+ const size_t wg_size = block_m;
5690
+ size_t local_work_size[] = { wg_size, 1 };
5691
+ size_t global_work_size[] = { (size_t)((n_q + block_m - 1) / block_m) * wg_size, (size_t)(n_head * n_batch) };
5692
+ backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst);
5693
+ }
5694
+ }
5695
+
5696
  static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
5697
  ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
5698
 
 
7849
  }
7850
  func = ggml_cl_sum_rows;
7851
  break;
7852
+ case GGML_OP_FLASH_ATTN_EXT:
7853
+ if (!any_on_device) {
7854
+ return false;
7855
+ }
7856
+ ggml_cl_flash_attn(backend, tensor->src[0], tensor->src[1], tensor);
7857
+ return true;
7858
  default:
7859
  return false;
7860
  }
ggml/src/ggml-opencl/kernels/flash_attn_f16.cl ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+
3
+ #define ACC_TYPE float
4
+ #define ACC_TYPE4 float4
5
+ #define DATA_TYPE half
6
+ #define DATA_TYPE4 half4
7
+ #define CONVERT_ACC4(x) convert_float4(x)
8
+ #define CONVERT_DATA4(x) convert_half4(x)
9
+
10
+ #define DK_VEC (DK/4)
11
+ #define DV_VEC (DV/4)
12
+ #define WG_SIZE (BLOCK_M)
13
+ #define Q1_WG_SIZE 64
14
+
15
// ALiBi per-head slope: heads below n_head_log2 use base m0, the rest use m1
// with a recomputed odd exponent. Returns 1.0f when ALiBi is disabled.
inline float get_alibi_slope(
    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const float base = h < n_head_log2 ? m0 : m1;
    const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

    // exph is an integer: pown is the exact integer-power form of pow in OpenCL C.
    return pown(base, exph);
}
26
// Tiled flash attention (multi-row query path, F16 in / F16 out).
// Each work-item owns one query row of a BLOCK_M tile; K/V tiles of BLOCK_N rows
// are staged through local memory, and softmax is computed online (running max
// m_i, running sum l_i, rescaled output accumulator o_acc).
__kernel void flash_attn_f16(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void* mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int tid = get_local_id(0);
    const int block_q_idx = get_group_id(0);
    const int head_batch_idx = get_global_id(1);

    // query row owned by this work-item (may exceed n_q in the last block)
    const int my_query_row = block_q_idx * BLOCK_M + tid;

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // grouped-query attention: several Q heads share one KV head
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        // the mask may broadcast over heads (ne2) and batches (ne3)
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }

    // load this row's query vector into private memory, widened to float
    ACC_TYPE4 q_priv[DK_VEC];
    if (my_query_row < n_q) {
        const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
        const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
        #pragma unroll
        for (int i = 0; i < DK_VEC; ++i) {
            q_priv[i] = CONVERT_ACC4(q_ptr[i]);
        }
    }

    // online-softmax state
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) {
        o_acc[i] = (ACC_TYPE4)(0.0f);
    }
    ACC_TYPE m_i = -INFINITY;
    ACC_TYPE l_i = 0.0f;

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
    __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];

    for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
        // cooperative load of one K tile into local memory
        for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
            const int row = i / DK_VEC;
            const int col = i % DK_VEC;
            const int k_row_idx = k_start + row;
            if (k_row_idx < n_kv) {
                const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
                l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
            }
        }
        // cooperative load of the matching V tile
        for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
            const int row = i / DV_VEC;
            const int col = i % DV_VEC;
            const int v_row_idx = k_start + row;
            if (v_row_idx < n_kv) {
                const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
                l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);

        // rows past n_q skip the compute phase but must stay in the loop so
        // every work-item reaches the trailing barrier below
        if (my_query_row < n_q) {
            // process two keys per iteration
            for (int j = 0; j < BLOCK_N; j += 2) {
                const int k_row0 = k_start + j;
                const int k_row1 = k_start + j + 1;

                ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
                ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
                #pragma unroll
                for (int k = 0; k < DK_VEC; k++) {
                    dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
                    dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
                }
                ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
                ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;

                if (is_causal) {
                    if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
                    if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
                }

                // out-of-range rows contribute exp(-inf) == 0 below
                if (k_row0 >= n_kv) score0 = -INFINITY;
                if (k_row1 >= n_kv) score1 = -INFINITY;

                if (mask_base != NULL) {
                    const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
                    if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
                    if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
                }

                if (logit_softcap > 0.0f) {
                    score0 = logit_softcap * tanh(score0 / logit_softcap);
                    score1 = logit_softcap * tanh(score1 / logit_softcap);
                }

                // online softmax update: rescale previous accumulator to the new max
                const ACC_TYPE m_new = max(m_i, max(score0, score1));
                const ACC_TYPE p0 = exp(score0 - m_new);
                const ACC_TYPE p1 = exp(score1 - m_new);
                const ACC_TYPE scale_prev = exp(m_i - m_new);

                #pragma unroll
                for (int i = 0; i < DV_VEC; ++i) {
                    o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
                }
                l_i = l_i * scale_prev + p0 + p1;
                m_i = m_new;
            }
        }

        // BUGFIX: synchronize before the next iteration overwrites l_k/l_v.
        // Previously, work-items with my_query_row >= n_q `continue`d straight
        // to the next tile load and could clobber local memory while slower
        // work-items were still reading the current tile. All work-items reach
        // this barrier (the compute phase is guarded by `if` instead of
        // `continue`), so there is no barrier divergence.
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    if (my_query_row < n_q) {
        const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
        global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
        if (l_i > 0.0f) {
            // final softmax normalization
            const ACC_TYPE l_inv = 1.0f / l_i;
            #pragma unroll
            for (int i = 0; i < DV_VEC; ++i) {
                o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
            }
        } else {
            // fully-masked row: emit zeros instead of NaN
            #pragma unroll
            for (int i = 0; i < DV_VEC; ++i) {
                o_row[i] = (DATA_TYPE4)(0.0f);
            }
        }
    }
}
190
+
191
// Single-query flash attention (n_q == 1, decode path, F16 in / F16 out).
// One workgroup of Q1_WG_SIZE work-items handles one (head, batch) pair; each
// work-item strides over KV rows, then the workgroup reduces max/sum/output
// through local memory. Two passes over K are made: pass 1 finds the global
// score max, pass 2 accumulates exp-weighted V.
// NOTE(review): n_q/is_causal are unused here — presumably irrelevant for a
// single query row; confirm against the host-side dispatch.
__kernel void flash_attn_f16_q1(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,
    const int n_head_kv,
    const global void* mask_void,
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // grouped-query attention: several Q heads share one KV head
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        // the mask may broadcast over heads (ne2) and batches (ne3)
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }

    // every work-item keeps the full query vector in private memory (row 0 only)
    ACC_TYPE4 q_priv[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
    #pragma unroll
    for (int i = 0; i < DK_VEC; ++i) {
        q_priv[i] = CONVERT_ACC4(q_ptr[i]);
    }

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // pass 1: per-work-item max over its strided subset of KV rows
    ACC_TYPE m_i = -INFINITY;
    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        m_i = max(m_i, score);
    }

    // tree-reduce the max across the workgroup (barrier outside the `if`:
    // every work-item executes it each step, so no divergence)
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_final = local_m[0];

    // pass 2: accumulate exp(score - m_final) weights and weighted V rows
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
        }
        // scores are recomputed rather than stored (saves local memory)
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        const ACC_TYPE p = exp(score - m_final);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; i++) {
            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // tree-reduce the normalizer l_i across the workgroup
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
    const ACC_TYPE l_final = local_l[0];

    // l_final is identical for all work-items, so this branch is uniform
    // across the workgroup and the barriers inside cannot diverge
    if (l_final > 0.0f) {
        const ACC_TYPE l_inv = 1.0f / l_final;
        // reduce the output one float4 component at a time, reusing local_o_comp
        for (int i = 0; i < DV_VEC; i++) {
            local_o_comp[tid] = o_acc[i];
            barrier(CLK_LOCAL_MEM_FENCE);
            #pragma unroll
            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
                barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (tid == 0) {
                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
            }
        }
    } else if (tid == 0) {
        // fully-masked query: emit zeros instead of NaN
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
    }
}
ggml/src/ggml-opencl/kernels/flash_attn_f32.cl ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+
3
+ #define ACC_TYPE float
4
+ #define ACC_TYPE4 float4
5
+ #define DATA_TYPE float
6
+ #define DATA_TYPE4 float4
7
+ #define CONVERT_ACC4(x) (x)
8
+ #define CONVERT_DATA4(x) (x)
9
+
10
+ #define DK_VEC (DK/4)
11
+ #define DV_VEC (DV/4)
12
+ #define WG_SIZE (BLOCK_M)
13
+ #define Q1_WG_SIZE 64
14
+
15
// Compute the per-head ALiBi slope.
// Returns 1.0f (no positional bias) when max_bias is disabled (<= 0).
// Heads below n_head_log2 use base m0 with exponent h+1; the remaining
// heads use base m1 with odd exponents 2*(h - n_head_log2) + 1.
inline float get_alibi_slope(
    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        // ALiBi disabled: neutral slope for every head.
        return 1.0f;
    }
    float base;
    int exponent;
    if (h < n_head_log2) {
        base = m0;
        exponent = (int)(h + 1);
    } else {
        base = m1;
        exponent = (int)(2*(h - n_head_log2) + 1);
    }
    return pow(base, exponent);
}
26
+ __kernel void flash_attn_f32(
27
+ const global void * q_void, ulong q_offset,
28
+ const global void * k_void, ulong k_offset,
29
+ const global void * v_void, ulong v_offset,
30
+ global void * o_void, ulong o_offset,
31
+ const float scale,
32
+ const int n_q,
33
+ const int n_kv,
34
+ const int is_causal,
35
+ const int n_head,
36
+ const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
37
+ const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
38
+ const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
39
+ const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
40
+ const float max_bias,
41
+ const float m0,
42
+ const float m1,
43
+ const int n_head_log2,
44
+ const float logit_softcap,
45
+ const int n_head_kv,
46
+ const global void* mask_void,
47
+ const ulong mask_offset,
48
+ const ulong mask_nb1,
49
+ const ulong mask_nb2,
50
+ const ulong mask_nb3,
51
+ const int mask_ne2,
52
+ const int mask_ne3
53
+ ) {
54
+ const int tid = get_local_id(0);
55
+ const int block_q_idx = get_group_id(0);
56
+ const int head_batch_idx = get_global_id(1);
57
+
58
+ const int my_query_row = block_q_idx * BLOCK_M + tid;
59
+
60
+ const int batch_idx = head_batch_idx / n_head;
61
+ const int head_idx = head_batch_idx % n_head;
62
+
63
+ const int gqa_ratio = n_head / n_head_kv;
64
+ const int head_kv_idx = head_idx / gqa_ratio;
65
+
66
+ const global char* q_base = (const global char*)q_void + q_offset;
67
+ const global char* k_base = (const global char*)k_void + k_offset;
68
+ const global char* v_base = (const global char*)v_void + v_offset;
69
+ global char* o_base = (global char*)o_void + o_offset;
70
+
71
+ const global char* mask_base = NULL;
72
+ if (mask_void != NULL) {
73
+ const int mask_head_idx = head_idx % mask_ne2;
74
+ const int mask_batch_idx = batch_idx % mask_ne3;
75
+ mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
76
+ }
77
+
78
+ ACC_TYPE4 q_priv[DK_VEC];
79
+ if (my_query_row < n_q) {
80
+ const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
81
+ const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
82
+ #pragma unroll
83
+ for (int i = 0; i < DK_VEC; ++i) {
84
+ q_priv[i] = CONVERT_ACC4(q_ptr[i]);
85
+ }
86
+ }
87
+
88
+ ACC_TYPE4 o_acc[DV_VEC];
89
+ #pragma unroll
90
+ for (int i = 0; i < DV_VEC; ++i) {
91
+ o_acc[i] = (ACC_TYPE4)(0.0f);
92
+ }
93
+ ACC_TYPE m_i = -INFINITY;
94
+ ACC_TYPE l_i = 0.0f;
95
+
96
+ float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
97
+
98
+ __local DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
99
+ __local DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
100
+
101
+ for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
102
+ for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
103
+ const int row = i / DK_VEC;
104
+ const int col = i % DK_VEC;
105
+ const int k_row_idx = k_start + row;
106
+ if (k_row_idx < n_kv) {
107
+ const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
108
+ l_k[row][col] = ((__global DATA_TYPE4*)(k_base + k_row_offset))[col];
109
+ }
110
+ }
111
+ for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
112
+ const int row = i / DV_VEC;
113
+ const int col = i % DV_VEC;
114
+ const int v_row_idx = k_start + row;
115
+ if (v_row_idx < n_kv) {
116
+ const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
117
+ l_v[row][col] = ((__global DATA_TYPE4*)(v_base + v_row_offset))[col];
118
+ }
119
+ }
120
+ barrier(CLK_LOCAL_MEM_FENCE);
121
+
122
+ if (my_query_row >= n_q) {
123
+ continue;
124
+ }
125
+
126
+ for (int j = 0; j < BLOCK_N; j += 2) {
127
+ const int k_row0 = k_start + j;
128
+ const int k_row1 = k_start + j + 1;
129
+
130
+ ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
131
+ ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
132
+ #pragma unroll
133
+ for (int k = 0; k < DK_VEC; k++) {
134
+ dot_acc0 = mad(q_priv[k], CONVERT_ACC4(l_k[j][k]), dot_acc0);
135
+ dot_acc1 = mad(q_priv[k], CONVERT_ACC4(l_k[j+1][k]), dot_acc1);
136
+ }
137
+ ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
138
+ ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
139
+
140
+ if (is_causal) {
141
+ if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
142
+ if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
143
+ }
144
+
145
+ if (k_row0 >= n_kv) score0 = -INFINITY;
146
+ if (k_row1 >= n_kv) score1 = -INFINITY;
147
+
148
+ if (mask_base != NULL) {
149
+ const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
150
+ if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
151
+ if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
152
+ }
153
+
154
+ if (logit_softcap > 0.0f) {
155
+ score0 = logit_softcap * tanh(score0 / logit_softcap);
156
+ score1 = logit_softcap * tanh(score1 / logit_softcap);
157
+ }
158
+
159
+ const ACC_TYPE m_new = max(m_i, max(score0, score1));
160
+ const ACC_TYPE p0 = exp(score0 - m_new);
161
+ const ACC_TYPE p1 = exp(score1 - m_new);
162
+ const ACC_TYPE scale_prev = exp(m_i - m_new);
163
+
164
+ #pragma unroll
165
+ for (int i = 0; i < DV_VEC; ++i) {
166
+ o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_ACC4(l_v[j][i]) + p1 * CONVERT_ACC4(l_v[j+1][i]);
167
+ }
168
+ l_i = l_i * scale_prev + p0 + p1;
169
+ m_i = m_new;
170
+ }
171
+ }
172
+
173
+ if (my_query_row < n_q) {
174
+ const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
175
+ global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
176
+ if (l_i > 0.0f) {
177
+ const ACC_TYPE l_inv = 1.0f / l_i;
178
+ #pragma unroll
179
+ for (int i = 0; i < DV_VEC; ++i) {
180
+ o_row[i] = CONVERT_DATA4(o_acc[i] * l_inv);
181
+ }
182
+ } else {
183
+ #pragma unroll
184
+ for (int i = 0; i < DV_VEC; ++i) {
185
+ o_row[i] = (DATA_TYPE4)(0.0f);
186
+ }
187
+ }
188
+ }
189
+ }
190
+
191
// Flash attention specialized for a single query row (n_q == 1), F32 in/out.
// One work-group of Q1_WG_SIZE work-items covers one (batch, head) pair and
// strides over the whole KV cache. Uses a two-pass softmax: pass 1 finds the
// global score maximum, pass 2 accumulates exp-weighted V rows, then tree
// reductions in local memory combine the per-work-item partials.
// n_q and is_causal are accepted for signature parity with flash_attn_f32
// but are not read here (the single row attends to the full cache).
// Q1_WG_SIZE is assumed to be a power of two (it is #defined as 64).
__kernel void flash_attn_f32_q1(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,            // softmax scaling applied to QK^T
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,         // ALiBi max bias; <= 0 disables ALiBi
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,    // > 0 enables tanh logit soft-capping
    const int n_head_kv,
    const global void* mask_void, // optional attention mask (may be NULL)
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // Grouped-query attention: map this Q head to its shared KV head.
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    // Mask head/batch dims broadcast via modulo (ggml-style broadcasting).
    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }

    // Every work-item keeps a private copy of the (single) query row.
    ACC_TYPE4 q_priv[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
    const global DATA_TYPE4* q_ptr = (const global DATA_TYPE4*)(q_base + q_row_offset);
    #pragma unroll
    for (int i = 0; i < DK_VEC; ++i) {
        q_priv[i] = CONVERT_ACC4(q_ptr[i]);
    }

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // Pass 1: each work-item scans a strided subset of KV rows and tracks
    // the maximum (masked, optionally soft-capped) score it sees.
    ACC_TYPE m_i = -INFINITY;
    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            // Single query row: no mask_nb1 row offset needed.
            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        m_i = max(m_i, score);
    }

    // Tree-reduce the per-work-item maxima to the global max m_final.
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_final = local_m[0];

    // Pass 2: recompute scores and accumulate exp-weighted V rows plus the
    // softmax normalizer over the same strided subset of the cache.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
        const global DATA_TYPE4* v_ptr = (const global DATA_TYPE4*)(v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global DATA_TYPE* mask_ptr = (const global DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        const ACC_TYPE p = exp(score - m_final);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; i++) {
            o_acc[i] = mad(p, CONVERT_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // Tree-reduce the normalizer across the work-group.
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
    global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
    const ACC_TYPE l_final = local_l[0];

    // l_final is uniform across the work-group (read after the reduction's
    // final barrier), so this branch and its barriers are uniform. A
    // non-positive (or NaN, from a fully masked row) normalizer yields an
    // all-zero output row.
    if (l_final > 0.0f) {
        const ACC_TYPE l_inv = 1.0f / l_final;
        // Reduce each vec4 chunk of the output in turn through local memory.
        for (int i = 0; i < DV_VEC; i++) {
            local_o_comp[tid] = o_acc[i];
            barrier(CLK_LOCAL_MEM_FENCE);
            #pragma unroll
            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
                barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (tid == 0) {
                o_row[i] = CONVERT_DATA4(local_o_comp[0] * l_inv);
            }
        }
    } else if (tid == 0) {
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (DATA_TYPE4)(0.0f);
    }
}
ggml/src/ggml-opencl/kernels/flash_attn_f32_f16.cl ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #pragma OPENCL EXTENSION cl_khr_fp16 : enable
2
+
3
+ #define ACC_TYPE float
4
+ #define ACC_TYPE4 float4
5
+ #define Q_DATA_TYPE4 float4
6
+ #define KV_DATA_TYPE4 half4
7
+ #define O_DATA_TYPE4 float4
8
+ #define MASK_DATA_TYPE half
9
+ #define CONVERT_Q_ACC4(x) (x)
10
+ #define CONVERT_KV_ACC4(x) convert_float4(x)
11
+ #define CONVERT_O_DATA4(x) (x)
12
+
13
+ #define DK_VEC (DK/4)
14
+ #define DV_VEC (DV/4)
15
+ #define WG_SIZE (BLOCK_M)
16
+ #define Q1_WG_SIZE 64
17
+
18
// Per-head ALiBi slope.
// A non-positive max_bias means ALiBi is off and every head gets slope 1.
// Otherwise heads in the first group (h < n_head_log2) use m0^(h+1) and the
// remaining heads use m1^(2*(h - n_head_log2) + 1).
inline float get_alibi_slope(
    const float max_bias, const uint h, const uint n_head_log2, const float m0, const float m1
) {
    if (max_bias <= 0.0f) {
        return 1.0f;
    }
    const int in_first_group = h < n_head_log2;
    const float b = in_first_group ? m0 : m1;
    const int e = in_first_group ? (int)(h + 1) : (int)(2*(h - n_head_log2) + 1);
    return pow(b, e);
}
29
+ __kernel void flash_attn_f32_f16(
30
+ const global void * q_void, ulong q_offset,
31
+ const global void * k_void, ulong k_offset,
32
+ const global void * v_void, ulong v_offset,
33
+ global void * o_void, ulong o_offset,
34
+ const float scale,
35
+ const int n_q,
36
+ const int n_kv,
37
+ const int is_causal,
38
+ const int n_head,
39
+ const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
40
+ const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
41
+ const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
42
+ const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
43
+ const float max_bias,
44
+ const float m0,
45
+ const float m1,
46
+ const int n_head_log2,
47
+ const float logit_softcap,
48
+ const int n_head_kv,
49
+ const global void* mask_void,
50
+ const ulong mask_offset,
51
+ const ulong mask_nb1,
52
+ const ulong mask_nb2,
53
+ const ulong mask_nb3,
54
+ const int mask_ne2,
55
+ const int mask_ne3
56
+ ) {
57
+ const int tid = get_local_id(0);
58
+ const int block_q_idx = get_group_id(0);
59
+ const int head_batch_idx = get_global_id(1);
60
+
61
+ const int my_query_row = block_q_idx * BLOCK_M + tid;
62
+
63
+ const int batch_idx = head_batch_idx / n_head;
64
+ const int head_idx = head_batch_idx % n_head;
65
+
66
+ const int gqa_ratio = n_head / n_head_kv;
67
+ const int head_kv_idx = head_idx / gqa_ratio;
68
+
69
+ const global char* q_base = (const global char*)q_void + q_offset;
70
+ const global char* k_base = (const global char*)k_void + k_offset;
71
+ const global char* v_base = (const global char*)v_void + v_offset;
72
+ global char* o_base = (global char*)o_void + o_offset;
73
+
74
+ const global char* mask_base = NULL;
75
+ if (mask_void != NULL) {
76
+ const int mask_head_idx = head_idx % mask_ne2;
77
+ const int mask_batch_idx = batch_idx % mask_ne3;
78
+ mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
79
+ }
80
+
81
+ ACC_TYPE4 q_priv[DK_VEC];
82
+ if (my_query_row < n_q) {
83
+ const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2 + my_query_row * q_nb1;
84
+ const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
85
+ #pragma unroll
86
+ for (int i = 0; i < DK_VEC; ++i) {
87
+ q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
88
+ }
89
+ }
90
+
91
+ ACC_TYPE4 o_acc[DV_VEC];
92
+ #pragma unroll
93
+ for (int i = 0; i < DV_VEC; ++i) {
94
+ o_acc[i] = (ACC_TYPE4)(0.0f);
95
+ }
96
+ ACC_TYPE m_i = -INFINITY;
97
+ ACC_TYPE l_i = 0.0f;
98
+
99
+ float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
100
+
101
+ __local KV_DATA_TYPE4 l_k[BLOCK_N][DK_VEC];
102
+ __local KV_DATA_TYPE4 l_v[BLOCK_N][DV_VEC];
103
+
104
+ for (int k_start = 0; k_start < n_kv; k_start += BLOCK_N) {
105
+ for (int i = tid; i < BLOCK_N * DK_VEC; i += WG_SIZE) {
106
+ const int row = i / DK_VEC;
107
+ const int col = i % DK_VEC;
108
+ const int k_row_idx = k_start + row;
109
+ if (k_row_idx < n_kv) {
110
+ const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_row_idx * k_nb1;
111
+ l_k[row][col] = ((__global KV_DATA_TYPE4*)(k_base + k_row_offset))[col];
112
+ }
113
+ }
114
+ for (int i = tid; i < BLOCK_N * DV_VEC; i += WG_SIZE) {
115
+ const int row = i / DV_VEC;
116
+ const int col = i % DV_VEC;
117
+ const int v_row_idx = k_start + row;
118
+ if (v_row_idx < n_kv) {
119
+ const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + v_row_idx * v_nb1;
120
+ l_v[row][col] = ((__global KV_DATA_TYPE4*)(v_base + v_row_offset))[col];
121
+ }
122
+ }
123
+ barrier(CLK_LOCAL_MEM_FENCE);
124
+
125
+ if (my_query_row >= n_q) {
126
+ continue;
127
+ }
128
+
129
+ for (int j = 0; j < BLOCK_N; j += 2) {
130
+ const int k_row0 = k_start + j;
131
+ const int k_row1 = k_start + j + 1;
132
+
133
+ ACC_TYPE4 dot_acc0 = (ACC_TYPE4)(0.0f);
134
+ ACC_TYPE4 dot_acc1 = (ACC_TYPE4)(0.0f);
135
+ #pragma unroll
136
+ for (int k = 0; k < DK_VEC; k++) {
137
+ dot_acc0 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j][k]), dot_acc0);
138
+ dot_acc1 = mad(q_priv[k], CONVERT_KV_ACC4(l_k[j+1][k]), dot_acc1);
139
+ }
140
+ ACC_TYPE score0 = (dot_acc0.s0 + dot_acc0.s1 + dot_acc0.s2 + dot_acc0.s3) * scale;
141
+ ACC_TYPE score1 = (dot_acc1.s0 + dot_acc1.s1 + dot_acc1.s2 + dot_acc1.s3) * scale;
142
+
143
+ if (is_causal) {
144
+ if (k_row0 > (n_kv - n_q + my_query_row)) score0 = -INFINITY;
145
+ if (k_row1 > (n_kv - n_q + my_query_row)) score1 = -INFINITY;
146
+ }
147
+
148
+ if (k_row0 >= n_kv) score0 = -INFINITY;
149
+ if (k_row1 >= n_kv) score1 = -INFINITY;
150
+
151
+ if (mask_base != NULL) {
152
+ const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base + my_query_row * mask_nb1);
153
+ if (k_row0 < n_kv) score0 += slope * (ACC_TYPE)mask_ptr[k_row0];
154
+ if (k_row1 < n_kv) score1 += slope * (ACC_TYPE)mask_ptr[k_row1];
155
+ }
156
+
157
+ if (logit_softcap > 0.0f) {
158
+ score0 = logit_softcap * tanh(score0 / logit_softcap);
159
+ score1 = logit_softcap * tanh(score1 / logit_softcap);
160
+ }
161
+
162
+ const ACC_TYPE m_new = max(m_i, max(score0, score1));
163
+ const ACC_TYPE p0 = exp(score0 - m_new);
164
+ const ACC_TYPE p1 = exp(score1 - m_new);
165
+ const ACC_TYPE scale_prev = exp(m_i - m_new);
166
+
167
+ #pragma unroll
168
+ for (int i = 0; i < DV_VEC; ++i) {
169
+ o_acc[i] = o_acc[i] * scale_prev + p0 * CONVERT_KV_ACC4(l_v[j][i]) + p1 * CONVERT_KV_ACC4(l_v[j+1][i]);
170
+ }
171
+ l_i = l_i * scale_prev + p0 + p1;
172
+ m_i = m_new;
173
+ }
174
+ }
175
+
176
+ if (my_query_row < n_q) {
177
+ const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
178
+ global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
179
+ if (l_i > 0.0f) {
180
+ const ACC_TYPE l_inv = 1.0f / l_i;
181
+ #pragma unroll
182
+ for (int i = 0; i < DV_VEC; ++i) {
183
+ o_row[i] = CONVERT_O_DATA4(o_acc[i] * l_inv);
184
+ }
185
+ } else {
186
+ #pragma unroll
187
+ for (int i = 0; i < DV_VEC; ++i) {
188
+ o_row[i] = (O_DATA_TYPE4)(0.0f);
189
+ }
190
+ }
191
+ }
192
+ }
193
+
194
// Flash attention for a single query row (n_q == 1), F32 Q/output with
// F16 K/V/mask. One work-group of Q1_WG_SIZE work-items covers one
// (batch, head) pair and strides over the whole KV cache. Uses a two-pass
// softmax: pass 1 finds the global score maximum, pass 2 accumulates
// exp-weighted V rows, then tree reductions in local memory combine the
// per-work-item partials.
// n_q and is_causal are accepted for signature parity with
// flash_attn_f32_f16 but are not read here.
// Q1_WG_SIZE is assumed to be a power of two (it is #defined as 64).
__kernel void flash_attn_f32_f16_q1(
    const global void * q_void, ulong q_offset,
    const global void * k_void, ulong k_offset,
    const global void * v_void, ulong v_offset,
    global void * o_void, ulong o_offset,
    const float scale,            // softmax scaling applied to QK^T
    const int n_q,
    const int n_kv,
    const int is_causal,
    const int n_head,
    const ulong q_nb1, const ulong q_nb2, const ulong q_nb3,
    const ulong k_nb1, const ulong k_nb2, const ulong k_nb3,
    const ulong v_nb1, const ulong v_nb2, const ulong v_nb3,
    const ulong o_nb1, const ulong o_nb2, const ulong o_nb3,
    const float max_bias,         // ALiBi max bias; <= 0 disables ALiBi
    const float m0,
    const float m1,
    const int n_head_log2,
    const float logit_softcap,    // > 0 enables tanh logit soft-capping
    const int n_head_kv,
    const global void* mask_void, // optional F16 attention mask (may be NULL)
    const ulong mask_offset,
    const ulong mask_nb1,
    const ulong mask_nb2,
    const ulong mask_nb3,
    const int mask_ne2,
    const int mask_ne3
) {
    const int tid = get_local_id(0);
    const int head_batch_idx = get_global_id(1);

    const int batch_idx = head_batch_idx / n_head;
    const int head_idx = head_batch_idx % n_head;

    // Grouped-query attention: map this Q head to its shared KV head.
    const int gqa_ratio = n_head / n_head_kv;
    const int head_kv_idx = head_idx / gqa_ratio;

    const global char* q_base = (const global char*)q_void + q_offset;
    const global char* k_base = (const global char*)k_void + k_offset;
    const global char* v_base = (const global char*)v_void + v_offset;
    global char* o_base = (global char*)o_void + o_offset;

    // Mask head/batch dims broadcast via modulo (ggml-style broadcasting).
    const global char* mask_base = NULL;
    if (mask_void != NULL) {
        const int mask_head_idx = head_idx % mask_ne2;
        const int mask_batch_idx = batch_idx % mask_ne3;
        mask_base = (const global char*)mask_void + mask_offset + mask_batch_idx * mask_nb3 + mask_head_idx * mask_nb2;
    }

    // Every work-item keeps a private copy of the (single) query row.
    ACC_TYPE4 q_priv[DK_VEC];
    const ulong q_row_offset = batch_idx * q_nb3 + head_idx * q_nb2;
    const global Q_DATA_TYPE4* q_ptr = (const global Q_DATA_TYPE4*)(q_base + q_row_offset);
    #pragma unroll
    for (int i = 0; i < DK_VEC; ++i) {
        q_priv[i] = CONVERT_Q_ACC4(q_ptr[i]);
    }

    float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);

    // Pass 1: each work-item scans a strided subset of KV rows and tracks
    // the maximum (masked, optionally soft-capped) score it sees.
    ACC_TYPE m_i = -INFINITY;
    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            // Single query row: no mask_nb1 row offset needed.
            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        m_i = max(m_i, score);
    }

    // Tree-reduce the per-work-item maxima to the global max m_final.
    __local ACC_TYPE local_m[Q1_WG_SIZE];
    local_m[tid] = m_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_m[tid] = max(local_m[tid], local_m[tid + s]);
        barrier(CLK_LOCAL_MEM_FENCE);
    }
    const ACC_TYPE m_final = local_m[0];

    // Pass 2: recompute scores and accumulate exp-weighted V rows plus the
    // softmax normalizer over the same strided subset of the cache.
    ACC_TYPE4 o_acc[DV_VEC];
    #pragma unroll
    for (int i = 0; i < DV_VEC; ++i) o_acc[i] = (ACC_TYPE4)(0.0f);
    ACC_TYPE l_i = 0.0f;

    for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
        const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
        const ulong v_row_offset = batch_idx * v_nb3 + head_kv_idx * v_nb2 + k_idx * v_nb1;
        const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
        const global KV_DATA_TYPE4* v_ptr = (const global KV_DATA_TYPE4*)(v_base + v_row_offset);
        ACC_TYPE4 dot_acc = (ACC_TYPE4)(0.0f);
        #pragma unroll
        for (int k = 0; k < DK_VEC; k++) {
            dot_acc = mad(q_priv[k], CONVERT_KV_ACC4(k_ptr[k]), dot_acc);
        }
        ACC_TYPE score = (dot_acc.s0 + dot_acc.s1 + dot_acc.s2 + dot_acc.s3) * scale;
        if (mask_base != NULL) {
            const global MASK_DATA_TYPE* mask_ptr = (const global MASK_DATA_TYPE*)(mask_base);
            score += slope * (ACC_TYPE)mask_ptr[k_idx];
        }
        if (logit_softcap > 0.0f) {
            score = logit_softcap * tanh(score / logit_softcap);
        }
        const ACC_TYPE p = exp(score - m_final);
        l_i += p;
        #pragma unroll
        for (int i = 0; i < DV_VEC; i++) {
            o_acc[i] = mad(p, CONVERT_KV_ACC4(v_ptr[i]), o_acc[i]);
        }
    }

    // Tree-reduce the normalizer across the work-group.
    __local ACC_TYPE local_l[Q1_WG_SIZE];
    __local ACC_TYPE4 local_o_comp[Q1_WG_SIZE];
    local_l[tid] = l_i;
    barrier(CLK_LOCAL_MEM_FENCE);
    #pragma unroll
    for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) local_l[tid] += local_l[tid + s];
        barrier(CLK_LOCAL_MEM_FENCE);
    }

    const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
    global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
    const ACC_TYPE l_final = local_l[0];

    // l_final is uniform across the work-group (read after the reduction's
    // final barrier), so this branch and its barriers are uniform. A
    // non-positive (or NaN, from a fully masked row) normalizer yields an
    // all-zero output row.
    if (l_final > 0.0f) {
        const ACC_TYPE l_inv = 1.0f / l_final;
        // Reduce each vec4 chunk of the output in turn through local memory.
        for (int i = 0; i < DV_VEC; i++) {
            local_o_comp[tid] = o_acc[i];
            barrier(CLK_LOCAL_MEM_FENCE);
            #pragma unroll
            for (int s = Q1_WG_SIZE / 2; s > 0; s >>= 1) {
                if (tid < s) local_o_comp[tid] += local_o_comp[tid + s];
                barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (tid == 0) {
                o_row[i] = CONVERT_O_DATA4(local_o_comp[0] * l_inv);
            }
        }
    } else if (tid == 0) {
        #pragma unroll
        for (int i = 0; i < DV_VEC; ++i) o_row[i] = (O_DATA_TYPE4)(0.0f);
    }
}