ggerganov committed
Commit 26c019a · unverified · 1 Parent(s): b5903fc

ggml : add ALiBi support for ggml_soft_max_ext (llama/5488)

Files changed (6):
  1. ggml-alloc.c +7 -7
  2. ggml-cuda.cu +55 -202
  3. ggml-metal.m +27 -8
  4. ggml-metal.metal +41 -6
  5. ggml.c +78 -38
  6. ggml.h +9 -4
ggml-alloc.c CHANGED
@@ -468,7 +468,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         struct ggml_tensor * parent = node->src[i];
         if (parent == NULL) {
-            break;
+            continue;
         }
 
         // if the node's data is external, then we cannot re-use it
@@ -565,7 +565,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * src = node->src[j];
         if (src == NULL) {
-            break;
+            continue;
         }
 
         ggml_gallocr_hash_get(galloc, src)->n_children += 1;
@@ -599,7 +599,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * parent = node->src[j];
         if (parent == NULL) {
-            break;
+            continue;
         }
         ggml_gallocr_allocate_node(galloc, parent, buffer_id);
     }
@@ -611,7 +611,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * parent = node->src[j];
         if (parent == NULL) {
-            break;
+            continue;
        }
         AT_PRINTF("%s", parent->name);
         if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
@@ -624,7 +624,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * parent = node->src[j];
         if (parent == NULL) {
-            break;
+            continue;
         }
         struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
         p_hn->n_children -= 1;
@@ -810,7 +810,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * src = node->src[j];
         if (src == NULL) {
-            break;
+            continue;
         }
         if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
 #ifndef NDEBUG
@@ -857,7 +857,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
     for (int j = 0; j < GGML_MAX_SRC; j++) {
         struct ggml_tensor * src = node->src[j];
         if (src == NULL) {
-            break;
+            continue;
         }
         ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
     }
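The break → continue change above is needed because a node's source slots are no longer guaranteed to be contiguous: with the extended soft_max in this commit, src[2] (the positions tensor) can be set while src[1] (the mask) is NULL, so stopping at the first NULL slot would skip real sources. A minimal standalone sketch (plain C, illustrative only — the array contents are hypothetical, not from ggml) of why the loops must skip empty slots instead of breaking:

#include <stdio.h>

#define MAX_SRC 10

int main(void) {
    // gap at src[1], mirroring a soft_max node with a pos tensor but no mask
    const char * src[MAX_SRC] = { "a", NULL, "pos" };

    int visited = 0;
    for (int i = 0; i < MAX_SRC; i++) {
        if (src[i] == NULL) {
            continue; // 'break' here would never reach src[2]
        }
        visited++;
        printf("visiting src[%d] = %s\n", i, src[i]);
    }
    printf("visited %d sources\n", visited); // 2 with 'continue', only 1 with 'break'
    return 0;
}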
ggml-cuda.cu CHANGED
@@ -5956,148 +5956,30 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
-static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
-    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    extern __shared__ half data_soft_max_f16[];
-    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
-    // (shared memory) buffer to cache values between iterations:
-    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
-    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
-    // in that case col_smem == col_data must be enforced to avoid race conditions
-
-    half2 max_val = make_half2(-INFINITY, -INFINITY);
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
-        const int col_smem = vals_smem ? col0 + tid : col_data;
-
-        const int ix = rowx*ncols_data + col_data;
-        const int iy = rowy*ncols_data + col_data;
-
-        half2 val;
-        if (need_check && col_data + 0 >= ncols_data) {
-            val.x = -INFINITY;
-        } else {
-            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
-        }
-        if (need_check && col_data + WARP_SIZE >= ncols_data) {
-            val.y = -INFINITY;
-        } else {
-            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
-        }
-        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
-            vals[col_smem] = val;
-        }
-        max_val = __hmax2(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
-        }
-        __syncthreads();
-
-        max_val = __half2half2(buf_iw[lane_id]);
-        max_val = warp_reduce_max(max_val);
-    } else {
-        max_val = __half2half2(__hmax(max_val.x, max_val.y));
-    }
-
-    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+    float slope = 0.0f;
 
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
-
-        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
-            break;
-        }
-
-        const half2 val = h2exp(vals[col_smem] - max_val);
-
-        tmp += val;
-        vals[col_smem] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp.x + tmp.y;
-        }
-        __syncthreads();
-
-        tmp = __half2half2(buf_iw[lane_id]);
-        tmp = warp_reduce_sum(tmp);
-    } else {
-        tmp = __half2half2(tmp.x + tmp.y);
-    }
-
-    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
-        const int col_smem = vals_smem ? col0 + tid : col_data;
-
-        const int idst = rowx*ncols_data + col_data;
-        const half2 result = vals[col_smem] * inv_sum;
-
-        if (need_check && col_data + 0 >= ncols_data) {
-            return;
-        }
-        dst[idst] = result.x;
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = rowx/nrows_y; // head index
 
-        if (need_check && col_data + WARP_SIZE >= ncols_data) {
-            return;
-        }
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        dst[idst + WARP_SIZE] = result.y;
+        slope = powf(base, exp);
     }
-#else
-    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
-
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid  = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
-
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
+
         vals[col] = val;
         max_val = max(max_val, val);
     }
@@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const dim3 block_dims(nth,     1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
-    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    if (shmem <= g_device_caps[g_main_device].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 64:
-                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 128:
-                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 256:
-                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 512:
-                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 1024:
-                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 2048:
-                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 4096:
-                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            default:
-                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-        }
-    } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(half);
-        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-    }
-}
-
-static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
     if (shmem < g_device_caps[g_main_device].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
            case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
            case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
            default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
         const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
@@ -9090,30 +8937,36 @@ static void ggml_cuda_op_soft_max(
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
+    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src0->ne[1];
 
-    float scale = 1.0f;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
-#ifdef GGML_CUDA_F16
-    const bool use_f16_soft_max = true;
-#else
-    const bool use_f16_soft_max = false;
-#endif // GGML_CUDA_F16
-#else
-    const bool use_f16_soft_max = false;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    if (use_f16_soft_max) {
-        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
-    } else {
-        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    // positions tensor
+    float * src2_dd = dst_dd; // default to avoid null checks in the kernel
+    cuda_pool_alloc<float> src2_f;
+
+    ggml_tensor * src2 = dst->src[2];
+    const bool use_src2 = src2 != nullptr;
+
+    if (use_src2) {
+        const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
+        ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
+
+        if (src2_on_device) {
+            src2_dd = (float *) src2_extra->data_device[g_main_device];
+        } else {
+            src2_dd = src2_f.alloc(ggml_nelements(src2));
+            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
+        }
     }
 
-    (void) dst;
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
 }
 
 static void ggml_cuda_op_scale(
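For reference, this is what each row of the fused kernel above computes, written as a scalar sketch in plain C (not part of the commit). `mask` may be NULL; `pos` is assumed to point at valid memory (the CUDA path defaults it to `dst_dd` when there is no positions tensor) and `slope` is 0.0f when ALiBi is disabled:

#include <math.h>

// Scalar reference for one row of the fused op: dst = softmax(x*scale + mask + slope*pos).
void soft_max_ext_row_ref(const float * x, const float * mask, const float * pos,
                          float * dst, int ncols, float scale, float slope) {
    float max_val = -INFINITY;
    for (int i = 0; i < ncols; i++) {
        dst[i] = x[i]*scale + (mask ? mask[i] : 0.0f) + slope*pos[i];
        if (dst[i] > max_val) {
            max_val = dst[i];
        }
    }

    float sum = 0.0f;
    for (int i = 0; i < ncols; i++) {
        dst[i] = expf(dst[i] - max_val); // subtract the row max for numerical stability
        sum += dst[i];
    }

    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < ncols; i++) {
        dst[i] *= inv_sum;
    }
}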
ggml-metal.m CHANGED
@@ -737,6 +737,7 @@ static bool ggml_metal_graph_compute(
 
         size_t offs_src0 = 0;
         size_t offs_src1 = 0;
+        size_t offs_src2 = 0;
         size_t offs_dst  = 0;
 
         id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
@@ -755,6 +756,7 @@ static bool ggml_metal_graph_compute(
 
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
             struct ggml_tensor * dst  = gf->nodes[i];
 
             switch (dst->op) {
@@ -816,6 +818,7 @@ static bool ggml_metal_graph_compute(
 
             id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
             id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+            id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
             id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
             //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1197,7 +1200,16 @@ static bool ggml_metal_graph_compute(
                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
                     }
 
-                    const float scale = ((float *) dst->op_params)[0];
+                    const float scale    = ((float *) dst->op_params)[0];
+                    const float max_bias = ((float *) dst->op_params)[1];
+
+                    const int64_t nrows_x = ggml_nrows(src0);
+                    const int64_t nrows_y = src0->ne[1];
+                    const uint32_t n_head_kv   = nrows_x/nrows_y;
+                    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+                    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                     [encoder setComputePipelineState:pipeline];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1206,11 +1218,20 @@ static bool ggml_metal_graph_compute(
                    } else {
                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                    }
-                    [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                    [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                    if (id_src2) {
+                        [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+                    } else {
+                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
+                    }
+                    [encoder setBuffer:id_dst offset:offs_dst atIndex:3];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
+                    [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:7];
+                    [encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
+                    [encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
+                    [encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
+                    [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
                     [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1523,8 +1544,6 @@ static bool ggml_metal_graph_compute(
                     // max size of the src1ids array in the kernel stack
                     GGML_ASSERT(ne11 <= 512);
 
-                    struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
                     const int64_t ne20 = src2 ? src2->ne[0] : 0;
                     const int64_t ne21 = src2 ? src2->ne[1] : 0;
                     const int64_t ne22 = src2 ? src2->ne[2] : 0;
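The host-side ALiBi parameters above (`n_head_kv`, `n_head_log2`, `m0`, `m1`) are computed the same way in the CUDA and Metal paths. A standalone sketch (plain C, with assumed example values) that prints the resulting per-head slopes — for n_head_kv = 8 and max_bias = 8.0f they come out to 1/2, 1/4, ..., 1/256, the usual ALiBi schedule:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    const uint32_t n_head_kv = 8;    // example value, not from the commit
    const float    max_bias  = 8.0f; // example value, not from the commit

    // same formulas as the host code in this commit
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    for (uint32_t h = 0; h < n_head_kv; h++) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   e    = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        // for these inputs: 0.5, 0.25, 0.125, ..., 0.00390625 (= 1/256)
        printf("head %u: slope = %f\n", (unsigned) h, powf(base, e));
    }
    return 0;
}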
ggml-metal.metal CHANGED
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
 kernel void kernel_soft_max(
     device const float * src0,
     device const float * src1,
+    device const float * src2,
     device       float * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
     constant     float & scale,
-    threadgroup  float * buf [[threadgroup(0)]],
+    constant     float & max_bias,
+    constant     float & m0,
+    constant     float & m1,
+    constant  uint32_t & n_head_log2,
+    threadgroup  float * buf [[threadgroup(0)]],
     uint  tgpig[[threadgroup_position_in_grid]],
     uint  tpitg[[thread_position_in_threadgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]],
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+    device const float * ppos  = src2 != src0 ? src2            : nullptr;
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
+    float slope = 0.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
     // parallel max
     float lmax = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
     }
 
     // find the max value in the block
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
 kernel void kernel_soft_max_4(
     device const float * src0,
     device const float * src1,
+    device const float * src2,
     device       float * dst,
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
     constant     float & scale,
-    threadgroup  float * buf [[threadgroup(0)]],
+    constant     float & max_bias,
+    constant     float & m0,
+    constant     float & m1,
+    constant  uint32_t & n_head_log2,
+    threadgroup  float * buf [[threadgroup(0)]],
     uint  tgpig[[threadgroup_position_in_grid]],
     uint  tpitg[[thread_position_in_threadgroup]],
     uint  sgitg[[simdgroup_index_in_threadgroup]],
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2) : nullptr;
     device       float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
+    float slope = 0.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
     // parallel max
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
ggml.c CHANGED
@@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
+        struct ggml_tensor  * pos,
         float                 scale,
+        float                 max_bias,
         bool                  inplace) {
     GGML_ASSERT(ggml_is_contiguous(a));
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == 1);
-        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_is_matrix(mask));
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
     }
 
+    if (pos) {
+        GGML_ASSERT(ggml_is_vector(pos));
+        GGML_ASSERT(pos->type == GGML_TYPE_F32);
+        GGML_ASSERT(pos->ne[0] == a->ne[0]);
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(pos);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = mask;
+    result->src[2] = pos;
 
     return result;
 }
@@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
+    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
 }
 
 struct ggml_tensor * ggml_soft_max_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
-        float                 scale) {
-    return ggml_soft_max_impl(ctx, a, mask, scale, false);
+        struct ggml_tensor  * pos,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
 }
 
 // ggml_soft_max_back
@@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
         struct ggml_tensor * dst) {
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -11503,16 +11519,29 @@ static void ggml_compute_forward_soft_max_f32(
         return;
     }
 
-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_TENSOR_UNARY_OP_LOCALS
+
     const int64_t ne11 = src1 ? src1->ne[1] : 1;
 
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head_kv   = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
+    // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
+    float * pos = src2 ? (float *) src2->data : src0->data;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data + i1*dst->nb[1]);
@@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
             ggml_vec_acc_f32(nc, wp, mp);
         }
 
+        // ALiBi bias
+        if (max_bias > 0.0f) {
+            const uint32_t h  = (i1/ne01)%ne02; // head
+            const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
+
+            for (int i = 0; i < nc; i++) {
+                wp[i] = wp[i] + slope*pos[i];
+            }
+        }
+
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
@@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
         struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
             } break;
         default:
            {
@@ -11730,22 +11773,20 @@ static void ggml_compute_forward_alibi_f32(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int64_t i = 0; i < ne0; i++) {
-        for (int64_t j = 0; j < ne1; j++) {
-            for (int64_t k = 0; k < ne2_ne3; k++) {
-                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-
-                // TODO: k*nb2 or k*nb3
+    for (int64_t k = 0; k < ne2_ne3; k++) {
+        // TODO: k*nb2 or k*nb3
+        float m_k;
 
-                float m_k;
-
-                if (k < n_heads_log2_floor) {
-                    m_k = powf(m0, k + 1);
-                } else {
-                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-                }
+        if (k < n_heads_log2_floor) {
+            m_k = powf(m0, k + 1);
+        } else {
+            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+        }
 
+        for (int64_t i = 0; i < ne0; i++) {
+            for (int64_t j = 0; j < ne1; j++) {
+                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
                 pdst[0] = i * m_k + src[0];
             }
         }
@@ -11790,21 +11831,20 @@ static void ggml_compute_forward_alibi_f16(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int i = 0; i < ne0; i++) {
-        for (int j = 0; j < ne1; j++) {
-            for (int k = 0; k < ne2_ne3; k++) {
-                ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                float *            pdst =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-
-                // TODO: k*nb2 or k*nb3
+    for (int k = 0; k < ne2_ne3; k++) {
+        // TODO: k*nb2 or k*nb3
+        float m_k;
 
-                float m_k;
+        if (k < n_heads_log2_floor) {
+            m_k = powf(m0, k + 1);
+        } else {
+            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+        }
 
-                if (k < n_heads_log2_floor) {
-                    m_k = powf(m0, k + 1);
-                } else {
-                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-                }
+        for (int i = 0; i < ne0; i++) {
+            for (int j = 0; j < ne1; j++) {
+                ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *            pdst =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
 
                 // we return F32
                 pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
@@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
            } break;
        case GGML_OP_SOFT_MAX_BACK:
            {
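The alibi forward passes above are also restructured so that the per-head factor m_k, which depends only on the head index k, is computed once per head instead of once per element. A simplified sketch (plain C, with a hypothetical contiguous buffer layout that is not the strided layout ggml actually uses) of the hoisted loop order:

#include <math.h>
#include <stdint.h>

// Adds the ALiBi bias i*m_k to a [ne0 x ne1 x ne2_ne3] buffer in place.
// m_k is hoisted out of the inner loops because it only depends on the head index k.
void alibi_add_bias(float * data, int64_t ne0, int64_t ne1, int64_t ne2_ne3,
                    float m0, float m1, int n_heads_log2_floor) {
    for (int64_t k = 0; k < ne2_ne3; k++) {
        const float m_k = k < n_heads_log2_floor
            ? powf(m0, k + 1)
            : powf(m1, 2*(k - n_heads_log2_floor) + 1);

        for (int64_t j = 0; j < ne1; j++) {
            for (int64_t i = 0; i < ne0; i++) {
                // the bias grows linearly with the position index i, scaled per head
                data[(k*ne1 + j)*ne0 + i] += i*m_k;
            }
        }
    }
}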
ggml.h CHANGED
@@ -1383,13 +1383,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // fused soft_max(a*scale + mask)
+    // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
     // mask is optional
+    // pos is required when max_bias > 0.0f
+    // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * mask,
-            float                 scale);
+            struct ggml_tensor  * pos,
+            float                 scale,
+            float                 max_bias);
 
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
@@ -1491,12 +1495,13 @@ extern "C" {
 
     // alibi position embedding
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_alibi(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
            int                   n_head,
-            float                 bias_max);
+            float                 bias_max),
+        "use ggml_soft_max_ext instead (will be removed in Mar 2024)");
 
     // clamp
     // in-place, returns view(a)
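A minimal usage sketch of the extended API (illustrative only — the tensor names and shape comments below are assumptions, not part of this commit):

#include "ggml.h"

// kq      : [n_kv, n_tokens, n_head] attention scores
// kq_mask : [n_kv, n_tokens] optional mask, may be NULL
// kq_pos  : [n_kv] F32 positions, required when max_bias > 0.0f
// With max_bias > 0.0f the fused op applies the ALiBi bias, replacing a separate
// (now deprecated) ggml_alibi() step; max_bias = 0.0f disables it.
struct ggml_tensor * build_attn_probs(
        struct ggml_context * ctx,
        struct ggml_tensor  * kq,
        struct ggml_tensor  * kq_mask,
        struct ggml_tensor  * kq_pos,
        float                 kq_scale,
        float                 max_bias) {
    return ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, max_bias);
}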