Spaces:
Running
Running
ggml : add ALiBi support for ggml_soft_max_ext (llama/5488)
Browse files- ggml-alloc.c +7 -7
- ggml-cuda.cu +55 -202
- ggml-metal.m +27 -8
- ggml-metal.metal +41 -6
- ggml.c +78 -38
- ggml.h +9 -4
ggml-alloc.c
CHANGED
|
@@ -468,7 +468,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor
|
|
| 468 |
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 469 |
struct ggml_tensor * parent = node->src[i];
|
| 470 |
if (parent == NULL) {
|
| 471 |
-
|
| 472 |
}
|
| 473 |
|
| 474 |
// if the node's data is external, then we cannot re-use it
|
|
@@ -565,7 +565,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|
| 565 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 566 |
struct ggml_tensor * src = node->src[j];
|
| 567 |
if (src == NULL) {
|
| 568 |
-
|
| 569 |
}
|
| 570 |
|
| 571 |
ggml_gallocr_hash_get(galloc, src)->n_children += 1;
|
|
@@ -599,7 +599,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|
| 599 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 600 |
struct ggml_tensor * parent = node->src[j];
|
| 601 |
if (parent == NULL) {
|
| 602 |
-
|
| 603 |
}
|
| 604 |
ggml_gallocr_allocate_node(galloc, parent, buffer_id);
|
| 605 |
}
|
|
@@ -611,7 +611,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|
| 611 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 612 |
struct ggml_tensor * parent = node->src[j];
|
| 613 |
if (parent == NULL) {
|
| 614 |
-
|
| 615 |
}
|
| 616 |
AT_PRINTF("%s", parent->name);
|
| 617 |
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
|
|
@@ -624,7 +624,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
|
|
| 624 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 625 |
struct ggml_tensor * parent = node->src[j];
|
| 626 |
if (parent == NULL) {
|
| 627 |
-
|
| 628 |
}
|
| 629 |
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
| 630 |
p_hn->n_children -= 1;
|
|
@@ -810,7 +810,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
|
|
| 810 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 811 |
struct ggml_tensor * src = node->src[j];
|
| 812 |
if (src == NULL) {
|
| 813 |
-
|
| 814 |
}
|
| 815 |
if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
|
| 816 |
#ifndef NDEBUG
|
|
@@ -857,7 +857,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
|
|
| 857 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 858 |
struct ggml_tensor * src = node->src[j];
|
| 859 |
if (src == NULL) {
|
| 860 |
-
|
| 861 |
}
|
| 862 |
ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
|
| 863 |
}
|
|
|
|
| 468 |
for (int i = 0; i < GGML_MAX_SRC; i++) {
|
| 469 |
struct ggml_tensor * parent = node->src[i];
|
| 470 |
if (parent == NULL) {
|
| 471 |
+
continue;
|
| 472 |
}
|
| 473 |
|
| 474 |
// if the node's data is external, then we cannot re-use it
|
|
|
|
| 565 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 566 |
struct ggml_tensor * src = node->src[j];
|
| 567 |
if (src == NULL) {
|
| 568 |
+
continue;
|
| 569 |
}
|
| 570 |
|
| 571 |
ggml_gallocr_hash_get(galloc, src)->n_children += 1;
|
|
|
|
| 599 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 600 |
struct ggml_tensor * parent = node->src[j];
|
| 601 |
if (parent == NULL) {
|
| 602 |
+
continue;
|
| 603 |
}
|
| 604 |
ggml_gallocr_allocate_node(galloc, parent, buffer_id);
|
| 605 |
}
|
|
|
|
| 611 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 612 |
struct ggml_tensor * parent = node->src[j];
|
| 613 |
if (parent == NULL) {
|
| 614 |
+
continue;
|
| 615 |
}
|
| 616 |
AT_PRINTF("%s", parent->name);
|
| 617 |
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
|
|
|
|
| 624 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 625 |
struct ggml_tensor * parent = node->src[j];
|
| 626 |
if (parent == NULL) {
|
| 627 |
+
continue;
|
| 628 |
}
|
| 629 |
struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent);
|
| 630 |
p_hn->n_children -= 1;
|
|
|
|
| 810 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 811 |
struct ggml_tensor * src = node->src[j];
|
| 812 |
if (src == NULL) {
|
| 813 |
+
continue;
|
| 814 |
}
|
| 815 |
if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
|
| 816 |
#ifndef NDEBUG
|
|
|
|
| 857 |
for (int j = 0; j < GGML_MAX_SRC; j++) {
|
| 858 |
struct ggml_tensor * src = node->src[j];
|
| 859 |
if (src == NULL) {
|
| 860 |
+
continue;
|
| 861 |
}
|
| 862 |
ggml_gallocr_init_tensor(galloc, src, node_alloc->buffer_id, &node_alloc->src[j]);
|
| 863 |
}
|
ggml-cuda.cu
CHANGED
|
@@ -5956,148 +5956,30 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
|
|
| 5956 |
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
| 5957 |
}
|
| 5958 |
|
| 5959 |
-
template <bool vals_smem, int ncols_template, int block_size_template
|
| 5960 |
-
static __global__ void
|
| 5961 |
-
|
| 5962 |
-
const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
|
| 5963 |
-
const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
|
| 5964 |
|
| 5965 |
const int tid = threadIdx.x;
|
| 5966 |
const int rowx = blockIdx.x;
|
| 5967 |
-
const int rowy = rowx % nrows_y; // broadcast the mask
|
| 5968 |
|
| 5969 |
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
| 5970 |
|
| 5971 |
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 5972 |
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 5973 |
|
| 5974 |
-
|
| 5975 |
-
half * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
|
| 5976 |
-
// (shared memory) buffer to cache values between iterations:
|
| 5977 |
-
half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
|
| 5978 |
-
// if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
|
| 5979 |
-
// in that case col_smem == col_data must be enforced to avoid race conditions
|
| 5980 |
-
|
| 5981 |
-
half2 max_val = make_half2(-INFINITY, -INFINITY);
|
| 5982 |
-
|
| 5983 |
-
#pragma unroll
|
| 5984 |
-
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
| 5985 |
-
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
| 5986 |
-
const int col_smem = vals_smem ? col0 + tid : col_data;
|
| 5987 |
-
|
| 5988 |
-
const int ix = rowx*ncols_data + col_data;
|
| 5989 |
-
const int iy = rowy*ncols_data + col_data;
|
| 5990 |
-
|
| 5991 |
-
half2 val;
|
| 5992 |
-
if (need_check && col_data + 0 >= ncols_data) {
|
| 5993 |
-
val.x = -INFINITY;
|
| 5994 |
-
} else {
|
| 5995 |
-
val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
|
| 5996 |
-
}
|
| 5997 |
-
if (need_check && col_data + WARP_SIZE >= ncols_data) {
|
| 5998 |
-
val.y = -INFINITY;
|
| 5999 |
-
} else {
|
| 6000 |
-
val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
|
| 6001 |
-
}
|
| 6002 |
-
if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
|
| 6003 |
-
vals[col_smem] = val;
|
| 6004 |
-
}
|
| 6005 |
-
max_val = __hmax2(max_val, val);
|
| 6006 |
-
}
|
| 6007 |
-
|
| 6008 |
-
// find the max value in the block
|
| 6009 |
-
max_val = warp_reduce_max(max_val);
|
| 6010 |
-
if (block_size > WARP_SIZE) {
|
| 6011 |
-
if (warp_id == 0) {
|
| 6012 |
-
buf_iw[lane_id] = -INFINITY;
|
| 6013 |
-
}
|
| 6014 |
-
__syncthreads();
|
| 6015 |
-
|
| 6016 |
-
if (lane_id == 0) {
|
| 6017 |
-
buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
|
| 6018 |
-
}
|
| 6019 |
-
__syncthreads();
|
| 6020 |
-
|
| 6021 |
-
max_val = __half2half2(buf_iw[lane_id]);
|
| 6022 |
-
max_val = warp_reduce_max(max_val);
|
| 6023 |
-
} else {
|
| 6024 |
-
max_val = __half2half2(__hmax(max_val.x, max_val.y));
|
| 6025 |
-
}
|
| 6026 |
-
|
| 6027 |
-
half2 tmp = make_half2(0.0f, 0.0f); // partial sums
|
| 6028 |
|
| 6029 |
-
|
| 6030 |
-
|
| 6031 |
-
const int
|
| 6032 |
-
|
| 6033 |
-
if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
|
| 6034 |
-
break;
|
| 6035 |
-
}
|
| 6036 |
-
|
| 6037 |
-
const half2 val = h2exp(vals[col_smem] - max_val);
|
| 6038 |
-
|
| 6039 |
-
tmp += val;
|
| 6040 |
-
vals[col_smem] = val;
|
| 6041 |
-
}
|
| 6042 |
-
|
| 6043 |
-
// find the sum of exps in the block
|
| 6044 |
-
tmp = warp_reduce_sum(tmp);
|
| 6045 |
-
if (block_size > WARP_SIZE) {
|
| 6046 |
-
if (warp_id == 0) {
|
| 6047 |
-
buf_iw[lane_id] = 0.0f;
|
| 6048 |
-
}
|
| 6049 |
-
__syncthreads();
|
| 6050 |
-
|
| 6051 |
-
if (lane_id == 0) {
|
| 6052 |
-
buf_iw[warp_id] = tmp.x + tmp.y;
|
| 6053 |
-
}
|
| 6054 |
-
__syncthreads();
|
| 6055 |
-
|
| 6056 |
-
tmp = __half2half2(buf_iw[lane_id]);
|
| 6057 |
-
tmp = warp_reduce_sum(tmp);
|
| 6058 |
-
} else {
|
| 6059 |
-
tmp = __half2half2(tmp.x + tmp.y);
|
| 6060 |
-
}
|
| 6061 |
-
|
| 6062 |
-
const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
|
| 6063 |
-
|
| 6064 |
-
#pragma unroll
|
| 6065 |
-
for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
|
| 6066 |
-
const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
|
| 6067 |
-
const int col_smem = vals_smem ? col0 + tid : col_data;
|
| 6068 |
-
|
| 6069 |
-
const int idst = rowx*ncols_data + col_data;
|
| 6070 |
-
const half2 result = vals[col_smem] * inv_sum;
|
| 6071 |
-
|
| 6072 |
-
if (need_check && col_data + 0 >= ncols_data) {
|
| 6073 |
-
return;
|
| 6074 |
-
}
|
| 6075 |
-
dst[idst] = result.x;
|
| 6076 |
|
| 6077 |
-
|
| 6078 |
-
|
| 6079 |
-
}
|
| 6080 |
|
| 6081 |
-
|
| 6082 |
}
|
| 6083 |
-
#else
|
| 6084 |
-
(void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
|
| 6085 |
-
NO_DEVICE_CODE;
|
| 6086 |
-
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
|
| 6087 |
-
}
|
| 6088 |
-
|
| 6089 |
-
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 6090 |
-
static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
|
| 6091 |
-
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
| 6092 |
-
|
| 6093 |
-
const int tid = threadIdx.x;
|
| 6094 |
-
const int rowx = blockIdx.x;
|
| 6095 |
-
const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
|
| 6096 |
-
|
| 6097 |
-
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
| 6098 |
-
|
| 6099 |
-
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 6100 |
-
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 6101 |
|
| 6102 |
extern __shared__ float data_soft_max_f32[];
|
| 6103 |
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
|
@@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
|
|
| 6117 |
const int ix = rowx*ncols + col;
|
| 6118 |
const int iy = rowy*ncols + col;
|
| 6119 |
|
| 6120 |
-
const float val = x[ix]*scale + (
|
|
|
|
| 6121 |
vals[col] = val;
|
| 6122 |
max_val = max(max_val, val);
|
| 6123 |
}
|
|
@@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
|
|
| 7589 |
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
| 7590 |
}
|
| 7591 |
|
| 7592 |
-
static void
|
| 7593 |
-
int nth = WARP_SIZE;
|
| 7594 |
-
while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
| 7595 |
-
const dim3 block_dims(nth, 1, 1);
|
| 7596 |
-
const dim3 block_nums(nrows_x, 1, 1);
|
| 7597 |
-
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
|
| 7598 |
-
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
| 7599 |
-
if (shmem <= g_device_caps[g_main_device].smpb) {
|
| 7600 |
-
switch (ncols_x) {
|
| 7601 |
-
case 32:
|
| 7602 |
-
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7603 |
-
break;
|
| 7604 |
-
case 64:
|
| 7605 |
-
soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7606 |
-
break;
|
| 7607 |
-
case 128:
|
| 7608 |
-
soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7609 |
-
break;
|
| 7610 |
-
case 256:
|
| 7611 |
-
soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7612 |
-
break;
|
| 7613 |
-
case 512:
|
| 7614 |
-
soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7615 |
-
break;
|
| 7616 |
-
case 1024:
|
| 7617 |
-
soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7618 |
-
break;
|
| 7619 |
-
case 2048:
|
| 7620 |
-
soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7621 |
-
break;
|
| 7622 |
-
case 4096:
|
| 7623 |
-
soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7624 |
-
break;
|
| 7625 |
-
default:
|
| 7626 |
-
soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7627 |
-
break;
|
| 7628 |
-
}
|
| 7629 |
-
} else {
|
| 7630 |
-
const size_t shmem_low = WARP_SIZE*sizeof(half);
|
| 7631 |
-
soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
|
| 7632 |
-
}
|
| 7633 |
-
}
|
| 7634 |
-
|
| 7635 |
-
static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
|
| 7636 |
int nth = WARP_SIZE;
|
| 7637 |
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
| 7638 |
const dim3 block_dims(nth, 1, 1);
|
| 7639 |
const dim3 block_nums(nrows_x, 1, 1);
|
| 7640 |
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
| 7641 |
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7642 |
if (shmem < g_device_caps[g_main_device].smpb) {
|
| 7643 |
switch (ncols_x) {
|
| 7644 |
case 32:
|
| 7645 |
-
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7646 |
break;
|
| 7647 |
case 64:
|
| 7648 |
-
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7649 |
break;
|
| 7650 |
case 128:
|
| 7651 |
-
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7652 |
break;
|
| 7653 |
case 256:
|
| 7654 |
-
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7655 |
break;
|
| 7656 |
case 512:
|
| 7657 |
-
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7658 |
break;
|
| 7659 |
case 1024:
|
| 7660 |
-
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7661 |
break;
|
| 7662 |
case 2048:
|
| 7663 |
-
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7664 |
break;
|
| 7665 |
case 4096:
|
| 7666 |
-
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7667 |
break;
|
| 7668 |
default:
|
| 7669 |
-
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x,
|
| 7670 |
break;
|
| 7671 |
}
|
| 7672 |
} else {
|
| 7673 |
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
| 7674 |
-
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x,
|
| 7675 |
}
|
| 7676 |
}
|
| 7677 |
|
|
@@ -9090,30 +8937,36 @@ static void ggml_cuda_op_soft_max(
|
|
| 9090 |
|
| 9091 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
| 9092 |
|
| 9093 |
-
const int64_t ne00
|
| 9094 |
const int64_t nrows_x = ggml_nrows(src0);
|
| 9095 |
-
const int64_t nrows_y =
|
| 9096 |
|
| 9097 |
-
float scale
|
| 9098 |
-
|
| 9099 |
|
| 9100 |
-
|
| 9101 |
-
|
| 9102 |
-
const bool use_f16_soft_max = true;
|
| 9103 |
-
#else
|
| 9104 |
-
const bool use_f16_soft_max = false;
|
| 9105 |
-
#endif // GGML_CUDA_F16
|
| 9106 |
-
#else
|
| 9107 |
-
const bool use_f16_soft_max = false;
|
| 9108 |
-
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
|
| 9109 |
|
| 9110 |
-
|
| 9111 |
-
|
| 9112 |
-
|
| 9113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9114 |
}
|
| 9115 |
|
| 9116 |
-
(
|
| 9117 |
}
|
| 9118 |
|
| 9119 |
static void ggml_cuda_op_scale(
|
|
|
|
| 5956 |
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
|
| 5957 |
}
|
| 5958 |
|
| 5959 |
+
template <bool vals_smem, int ncols_template, int block_size_template>
|
| 5960 |
+
static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
|
| 5961 |
+
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
|
|
|
|
|
|
| 5962 |
|
| 5963 |
const int tid = threadIdx.x;
|
| 5964 |
const int rowx = blockIdx.x;
|
| 5965 |
+
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
|
| 5966 |
|
| 5967 |
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
| 5968 |
|
| 5969 |
const int warp_id = threadIdx.x / WARP_SIZE;
|
| 5970 |
const int lane_id = threadIdx.x % WARP_SIZE;
|
| 5971 |
|
| 5972 |
+
float slope = 0.0f;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5973 |
|
| 5974 |
+
// ALiBi
|
| 5975 |
+
if (max_bias > 0.0f) {
|
| 5976 |
+
const int h = rowx/nrows_y; // head index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5977 |
|
| 5978 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 5979 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
|
|
|
| 5980 |
|
| 5981 |
+
slope = powf(base, exp);
|
| 5982 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5983 |
|
| 5984 |
extern __shared__ float data_soft_max_f32[];
|
| 5985 |
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
|
|
|
| 5999 |
const int ix = rowx*ncols + col;
|
| 6000 |
const int iy = rowy*ncols + col;
|
| 6001 |
|
| 6002 |
+
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
|
| 6003 |
+
|
| 6004 |
vals[col] = val;
|
| 6005 |
max_val = max(max_val, val);
|
| 6006 |
}
|
|
|
|
| 7472 |
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
|
| 7473 |
}
|
| 7474 |
|
| 7475 |
+
static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7476 |
int nth = WARP_SIZE;
|
| 7477 |
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
| 7478 |
const dim3 block_dims(nth, 1, 1);
|
| 7479 |
const dim3 block_nums(nrows_x, 1, 1);
|
| 7480 |
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
| 7481 |
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
| 7482 |
+
|
| 7483 |
+
const uint32_t n_head_kv = nrows_x/nrows_y;
|
| 7484 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
| 7485 |
+
|
| 7486 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 7487 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 7488 |
+
|
| 7489 |
if (shmem < g_device_caps[g_main_device].smpb) {
|
| 7490 |
switch (ncols_x) {
|
| 7491 |
case 32:
|
| 7492 |
+
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7493 |
break;
|
| 7494 |
case 64:
|
| 7495 |
+
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7496 |
break;
|
| 7497 |
case 128:
|
| 7498 |
+
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7499 |
break;
|
| 7500 |
case 256:
|
| 7501 |
+
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7502 |
break;
|
| 7503 |
case 512:
|
| 7504 |
+
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7505 |
break;
|
| 7506 |
case 1024:
|
| 7507 |
+
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7508 |
break;
|
| 7509 |
case 2048:
|
| 7510 |
+
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7511 |
break;
|
| 7512 |
case 4096:
|
| 7513 |
+
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7514 |
break;
|
| 7515 |
default:
|
| 7516 |
+
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7517 |
break;
|
| 7518 |
}
|
| 7519 |
} else {
|
| 7520 |
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
| 7521 |
+
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
|
| 7522 |
}
|
| 7523 |
}
|
| 7524 |
|
|
|
|
| 8937 |
|
| 8938 |
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
| 8939 |
|
| 8940 |
+
const int64_t ne00 = src0->ne[0];
|
| 8941 |
const int64_t nrows_x = ggml_nrows(src0);
|
| 8942 |
+
const int64_t nrows_y = src0->ne[1];
|
| 8943 |
|
| 8944 |
+
float scale = 1.0f;
|
| 8945 |
+
float max_bias = 0.0f;
|
| 8946 |
|
| 8947 |
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
| 8948 |
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8949 |
|
| 8950 |
+
// positions tensor
|
| 8951 |
+
float * src2_dd = dst_dd; // default to avoid null checks in the kernel
|
| 8952 |
+
cuda_pool_alloc<float> src2_f;
|
| 8953 |
+
|
| 8954 |
+
ggml_tensor * src2 = dst->src[2];
|
| 8955 |
+
const bool use_src2 = src2 != nullptr;
|
| 8956 |
+
|
| 8957 |
+
if (use_src2) {
|
| 8958 |
+
const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
|
| 8959 |
+
ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
|
| 8960 |
+
|
| 8961 |
+
if (src2_on_device) {
|
| 8962 |
+
src2_dd = (float *) src2_extra->data_device[g_main_device];
|
| 8963 |
+
} else {
|
| 8964 |
+
src2_dd = src2_f.alloc(ggml_nelements(src2));
|
| 8965 |
+
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
|
| 8966 |
+
}
|
| 8967 |
}
|
| 8968 |
|
| 8969 |
+
soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
|
| 8970 |
}
|
| 8971 |
|
| 8972 |
static void ggml_cuda_op_scale(
|
ggml-metal.m
CHANGED
|
@@ -737,6 +737,7 @@ static bool ggml_metal_graph_compute(
|
|
| 737 |
|
| 738 |
size_t offs_src0 = 0;
|
| 739 |
size_t offs_src1 = 0;
|
|
|
|
| 740 |
size_t offs_dst = 0;
|
| 741 |
|
| 742 |
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
@@ -755,6 +756,7 @@ static bool ggml_metal_graph_compute(
|
|
| 755 |
|
| 756 |
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
| 757 |
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
|
|
| 758 |
struct ggml_tensor * dst = gf->nodes[i];
|
| 759 |
|
| 760 |
switch (dst->op) {
|
|
@@ -816,6 +818,7 @@ static bool ggml_metal_graph_compute(
|
|
| 816 |
|
| 817 |
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
| 818 |
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
|
|
|
| 819 |
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
| 820 |
|
| 821 |
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
@@ -1197,7 +1200,16 @@ static bool ggml_metal_graph_compute(
|
|
| 1197 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
| 1198 |
}
|
| 1199 |
|
| 1200 |
-
const float scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1201 |
|
| 1202 |
[encoder setComputePipelineState:pipeline];
|
| 1203 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
@@ -1206,11 +1218,20 @@ static bool ggml_metal_graph_compute(
|
|
| 1206 |
} else {
|
| 1207 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 1208 |
}
|
| 1209 |
-
|
| 1210 |
-
|
| 1211 |
-
|
| 1212 |
-
|
| 1213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1214 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 1215 |
|
| 1216 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
@@ -1523,8 +1544,6 @@ static bool ggml_metal_graph_compute(
|
|
| 1523 |
// max size of the src1ids array in the kernel stack
|
| 1524 |
GGML_ASSERT(ne11 <= 512);
|
| 1525 |
|
| 1526 |
-
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
| 1527 |
-
|
| 1528 |
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
| 1529 |
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
| 1530 |
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
|
|
|
| 737 |
|
| 738 |
size_t offs_src0 = 0;
|
| 739 |
size_t offs_src1 = 0;
|
| 740 |
+
size_t offs_src2 = 0;
|
| 741 |
size_t offs_dst = 0;
|
| 742 |
|
| 743 |
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
|
|
| 756 |
|
| 757 |
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
| 758 |
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
| 759 |
+
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
| 760 |
struct ggml_tensor * dst = gf->nodes[i];
|
| 761 |
|
| 762 |
switch (dst->op) {
|
|
|
|
| 818 |
|
| 819 |
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
| 820 |
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
| 821 |
+
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
| 822 |
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
|
| 823 |
|
| 824 |
//GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
|
|
|
|
| 1200 |
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
|
| 1201 |
}
|
| 1202 |
|
| 1203 |
+
const float scale = ((float *) dst->op_params)[0];
|
| 1204 |
+
const float max_bias = ((float *) dst->op_params)[1];
|
| 1205 |
+
|
| 1206 |
+
const int64_t nrows_x = ggml_nrows(src0);
|
| 1207 |
+
const int64_t nrows_y = src0->ne[1];
|
| 1208 |
+
const uint32_t n_head_kv = nrows_x/nrows_y;
|
| 1209 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
|
| 1210 |
+
|
| 1211 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 1212 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 1213 |
|
| 1214 |
[encoder setComputePipelineState:pipeline];
|
| 1215 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
|
|
|
| 1218 |
} else {
|
| 1219 |
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
| 1220 |
}
|
| 1221 |
+
if (id_src2) {
|
| 1222 |
+
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
| 1223 |
+
} else {
|
| 1224 |
+
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
| 1225 |
+
}
|
| 1226 |
+
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
| 1227 |
+
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4];
|
| 1228 |
+
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:5];
|
| 1229 |
+
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
| 1230 |
+
[encoder setBytes:&scale length:sizeof(scale) atIndex:7];
|
| 1231 |
+
[encoder setBytes:&max_bias length:sizeof(max_bias) atIndex:8];
|
| 1232 |
+
[encoder setBytes:&m0 length:sizeof(m0) atIndex:9];
|
| 1233 |
+
[encoder setBytes:&m1 length:sizeof(m1) atIndex:10];
|
| 1234 |
+
[encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
|
| 1235 |
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
| 1236 |
|
| 1237 |
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
|
|
|
| 1544 |
// max size of the src1ids array in the kernel stack
|
| 1545 |
GGML_ASSERT(ne11 <= 512);
|
| 1546 |
|
|
|
|
|
|
|
| 1547 |
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
| 1548 |
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
| 1549 |
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
ggml-metal.metal
CHANGED
|
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
|
|
| 351 |
kernel void kernel_soft_max(
|
| 352 |
device const float * src0,
|
| 353 |
device const float * src1,
|
|
|
|
| 354 |
device float * dst,
|
| 355 |
constant int64_t & ne00,
|
| 356 |
constant int64_t & ne01,
|
| 357 |
constant int64_t & ne02,
|
| 358 |
constant float & scale,
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 361 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 362 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
|
|
| 368 |
|
| 369 |
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 370 |
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
|
|
|
| 371 |
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
// parallel max
|
| 374 |
float lmax = -INFINITY;
|
| 375 |
|
| 376 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 377 |
-
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
| 378 |
}
|
| 379 |
|
| 380 |
// find the max value in the block
|
|
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
|
|
| 399 |
// parallel sum
|
| 400 |
float lsum = 0.0f;
|
| 401 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 402 |
-
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
| 403 |
lsum += exp_psrc0;
|
| 404 |
pdst[i00] = exp_psrc0;
|
| 405 |
}
|
|
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
|
|
| 437 |
kernel void kernel_soft_max_4(
|
| 438 |
device const float * src0,
|
| 439 |
device const float * src1,
|
|
|
|
| 440 |
device float * dst,
|
| 441 |
constant int64_t & ne00,
|
| 442 |
constant int64_t & ne01,
|
| 443 |
constant int64_t & ne02,
|
| 444 |
constant float & scale,
|
| 445 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 447 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 448 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
|
|
| 454 |
|
| 455 |
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 456 |
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
|
|
|
| 457 |
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 458 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 459 |
// parallel max
|
| 460 |
float4 lmax4 = -INFINITY;
|
| 461 |
|
| 462 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 463 |
-
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
| 464 |
}
|
| 465 |
|
| 466 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
|
|
| 486 |
// parallel sum
|
| 487 |
float4 lsum4 = 0.0f;
|
| 488 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 489 |
-
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
|
| 490 |
lsum4 += exp_psrc4;
|
| 491 |
pdst4[i00] = exp_psrc4;
|
| 492 |
}
|
|
|
|
| 351 |
kernel void kernel_soft_max(
|
| 352 |
device const float * src0,
|
| 353 |
device const float * src1,
|
| 354 |
+
device const float * src2,
|
| 355 |
device float * dst,
|
| 356 |
constant int64_t & ne00,
|
| 357 |
constant int64_t & ne01,
|
| 358 |
constant int64_t & ne02,
|
| 359 |
constant float & scale,
|
| 360 |
+
constant float & max_bias,
|
| 361 |
+
constant float & m0,
|
| 362 |
+
constant float & m1,
|
| 363 |
+
constant uint32_t & n_head_log2,
|
| 364 |
+
threadgroup float * buf [[threadgroup(0)]],
|
| 365 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 366 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 367 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
|
|
| 373 |
|
| 374 |
device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 375 |
device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
|
| 376 |
+
device const float * ppos = src2 != src0 ? src2 : nullptr;
|
| 377 |
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
| 378 |
|
| 379 |
+
float slope = 0.0f;
|
| 380 |
+
|
| 381 |
+
// ALiBi
|
| 382 |
+
if (max_bias > 0.0f) {
|
| 383 |
+
const int64_t h = i02;
|
| 384 |
+
|
| 385 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 386 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 387 |
+
|
| 388 |
+
slope = pow(base, exp);
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
// parallel max
|
| 392 |
float lmax = -INFINITY;
|
| 393 |
|
| 394 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 395 |
+
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
| 396 |
}
|
| 397 |
|
| 398 |
// find the max value in the block
|
|
|
|
| 417 |
// parallel sum
|
| 418 |
float lsum = 0.0f;
|
| 419 |
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
| 420 |
+
const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
| 421 |
lsum += exp_psrc0;
|
| 422 |
pdst[i00] = exp_psrc0;
|
| 423 |
}
|
|
|
|
| 455 |
kernel void kernel_soft_max_4(
|
| 456 |
device const float * src0,
|
| 457 |
device const float * src1,
|
| 458 |
+
device const float * src2,
|
| 459 |
device float * dst,
|
| 460 |
constant int64_t & ne00,
|
| 461 |
constant int64_t & ne01,
|
| 462 |
constant int64_t & ne02,
|
| 463 |
constant float & scale,
|
| 464 |
+
constant float & max_bias,
|
| 465 |
+
constant float & m0,
|
| 466 |
+
constant float & m1,
|
| 467 |
+
constant uint32_t & n_head_log2,
|
| 468 |
+
threadgroup float * buf [[threadgroup(0)]],
|
| 469 |
uint tgpig[[threadgroup_position_in_grid]],
|
| 470 |
uint tpitg[[thread_position_in_threadgroup]],
|
| 471 |
uint sgitg[[simdgroup_index_in_threadgroup]],
|
|
|
|
| 477 |
|
| 478 |
device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 479 |
device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
|
| 480 |
+
device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr;
|
| 481 |
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
| 482 |
|
| 483 |
+
float slope = 0.0f;
|
| 484 |
+
|
| 485 |
+
if (max_bias > 0.0f) {
|
| 486 |
+
const int64_t h = i02;
|
| 487 |
+
|
| 488 |
+
const float base = h < n_head_log2 ? m0 : m1;
|
| 489 |
+
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
|
| 490 |
+
|
| 491 |
+
slope = pow(base, exp);
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
// parallel max
|
| 495 |
float4 lmax4 = -INFINITY;
|
| 496 |
|
| 497 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 498 |
+
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
|
| 499 |
}
|
| 500 |
|
| 501 |
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
|
|
|
| 521 |
// parallel sum
|
| 522 |
float4 lsum4 = 0.0f;
|
| 523 |
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
| 524 |
+
const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
|
| 525 |
lsum4 += exp_psrc4;
|
| 526 |
pdst4[i00] = exp_psrc4;
|
| 527 |
}
|
ggml.c
CHANGED
|
@@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5096 |
struct ggml_context * ctx,
|
| 5097 |
struct ggml_tensor * a,
|
| 5098 |
struct ggml_tensor * mask,
|
|
|
|
| 5099 |
float scale,
|
|
|
|
| 5100 |
bool inplace) {
|
| 5101 |
GGML_ASSERT(ggml_is_contiguous(a));
|
|
|
|
| 5102 |
if (mask) {
|
| 5103 |
GGML_ASSERT(ggml_is_contiguous(mask));
|
| 5104 |
-
GGML_ASSERT(mask->ne[2] == 1);
|
| 5105 |
-
GGML_ASSERT(mask->ne[3] == 1);
|
| 5106 |
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
| 5107 |
}
|
| 5108 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5109 |
bool is_node = false;
|
| 5110 |
|
| 5111 |
if (a->grad) {
|
|
@@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5114 |
|
| 5115 |
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
| 5116 |
|
| 5117 |
-
float params[] = { scale };
|
| 5118 |
ggml_set_op_params(result, params, sizeof(params));
|
| 5119 |
|
| 5120 |
result->op = GGML_OP_SOFT_MAX;
|
| 5121 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5122 |
result->src[0] = a;
|
| 5123 |
result->src[1] = mask;
|
|
|
|
| 5124 |
|
| 5125 |
return result;
|
| 5126 |
}
|
|
@@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
|
| 5128 |
struct ggml_tensor * ggml_soft_max(
|
| 5129 |
struct ggml_context * ctx,
|
| 5130 |
struct ggml_tensor * a) {
|
| 5131 |
-
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
|
| 5132 |
}
|
| 5133 |
|
| 5134 |
struct ggml_tensor * ggml_soft_max_inplace(
|
| 5135 |
struct ggml_context * ctx,
|
| 5136 |
struct ggml_tensor * a) {
|
| 5137 |
-
return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
|
| 5138 |
}
|
| 5139 |
|
| 5140 |
struct ggml_tensor * ggml_soft_max_ext(
|
| 5141 |
struct ggml_context * ctx,
|
| 5142 |
struct ggml_tensor * a,
|
| 5143 |
struct ggml_tensor * mask,
|
| 5144 |
-
|
| 5145 |
-
|
|
|
|
|
|
|
| 5146 |
}
|
| 5147 |
|
| 5148 |
// ggml_soft_max_back
|
|
@@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 11495 |
const struct ggml_compute_params * params,
|
| 11496 |
const struct ggml_tensor * src0,
|
| 11497 |
const struct ggml_tensor * src1,
|
|
|
|
| 11498 |
struct ggml_tensor * dst) {
|
| 11499 |
assert(ggml_is_contiguous(dst));
|
| 11500 |
assert(ggml_are_same_shape(src0, dst));
|
|
@@ -11503,16 +11519,29 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 11503 |
return;
|
| 11504 |
}
|
| 11505 |
|
| 11506 |
-
float scale = 1.0f;
|
| 11507 |
-
|
|
|
|
|
|
|
|
|
|
| 11508 |
|
| 11509 |
// TODO: handle transposed/permuted matrices
|
| 11510 |
|
| 11511 |
const int ith = params->ith;
|
| 11512 |
const int nth = params->nth;
|
| 11513 |
|
|
|
|
|
|
|
| 11514 |
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
| 11515 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11516 |
const int nc = src0->ne[0];
|
| 11517 |
const int nr = ggml_nrows(src0);
|
| 11518 |
|
|
@@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 11525 |
|
| 11526 |
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
| 11527 |
|
|
|
|
|
|
|
|
|
|
| 11528 |
for (int i1 = ir0; i1 < ir1; i1++) {
|
| 11529 |
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
| 11530 |
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
@@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
|
|
| 11538 |
ggml_vec_acc_f32(nc, wp, mp);
|
| 11539 |
}
|
| 11540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11541 |
#ifndef NDEBUG
|
| 11542 |
for (int i = 0; i < nc; ++i) {
|
| 11543 |
//printf("p[%d] = %f\n", i, p[i]);
|
|
@@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
|
|
| 11582 |
const struct ggml_compute_params * params,
|
| 11583 |
const struct ggml_tensor * src0,
|
| 11584 |
const struct ggml_tensor * src1,
|
|
|
|
| 11585 |
struct ggml_tensor * dst) {
|
| 11586 |
switch (src0->type) {
|
| 11587 |
case GGML_TYPE_F32:
|
| 11588 |
{
|
| 11589 |
-
ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
|
| 11590 |
} break;
|
| 11591 |
default:
|
| 11592 |
{
|
|
@@ -11730,22 +11773,20 @@ static void ggml_compute_forward_alibi_f32(
|
|
| 11730 |
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 11731 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 11732 |
|
| 11733 |
-
for (int64_t
|
| 11734 |
-
|
| 11735 |
-
|
| 11736 |
-
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 11737 |
-
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
| 11738 |
-
|
| 11739 |
-
// TODO: k*nb2 or k*nb3
|
| 11740 |
|
| 11741 |
-
|
| 11742 |
-
|
| 11743 |
-
|
| 11744 |
-
|
| 11745 |
-
|
| 11746 |
-
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 11747 |
-
}
|
| 11748 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11749 |
pdst[0] = i * m_k + src[0];
|
| 11750 |
}
|
| 11751 |
}
|
|
@@ -11790,21 +11831,20 @@ static void ggml_compute_forward_alibi_f16(
|
|
| 11790 |
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 11791 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 11792 |
|
| 11793 |
-
for (int
|
| 11794 |
-
|
| 11795 |
-
|
| 11796 |
-
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 11797 |
-
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
| 11798 |
-
|
| 11799 |
-
// TODO: k*nb2 or k*nb3
|
| 11800 |
|
| 11801 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11802 |
|
| 11803 |
-
|
| 11804 |
-
|
| 11805 |
-
|
| 11806 |
-
|
| 11807 |
-
}
|
| 11808 |
|
| 11809 |
// we return F32
|
| 11810 |
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
|
@@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|
| 15116 |
} break;
|
| 15117 |
case GGML_OP_SOFT_MAX:
|
| 15118 |
{
|
| 15119 |
-
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
|
| 15120 |
} break;
|
| 15121 |
case GGML_OP_SOFT_MAX_BACK:
|
| 15122 |
{
|
|
|
|
| 5096 |
struct ggml_context * ctx,
|
| 5097 |
struct ggml_tensor * a,
|
| 5098 |
struct ggml_tensor * mask,
|
| 5099 |
+
struct ggml_tensor * pos,
|
| 5100 |
float scale,
|
| 5101 |
+
float max_bias,
|
| 5102 |
bool inplace) {
|
| 5103 |
GGML_ASSERT(ggml_is_contiguous(a));
|
| 5104 |
+
|
| 5105 |
if (mask) {
|
| 5106 |
GGML_ASSERT(ggml_is_contiguous(mask));
|
| 5107 |
+
GGML_ASSERT(ggml_is_matrix(mask));
|
|
|
|
| 5108 |
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
| 5109 |
}
|
| 5110 |
|
| 5111 |
+
if (pos) {
|
| 5112 |
+
GGML_ASSERT(ggml_is_vector(pos));
|
| 5113 |
+
GGML_ASSERT(pos->type == GGML_TYPE_F32);
|
| 5114 |
+
GGML_ASSERT(pos->ne[0] == a->ne[0]);
|
| 5115 |
+
}
|
| 5116 |
+
|
| 5117 |
+
if (max_bias > 0.0f) {
|
| 5118 |
+
GGML_ASSERT(pos);
|
| 5119 |
+
}
|
| 5120 |
+
|
| 5121 |
bool is_node = false;
|
| 5122 |
|
| 5123 |
if (a->grad) {
|
|
|
|
| 5126 |
|
| 5127 |
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
| 5128 |
|
| 5129 |
+
float params[] = { scale, max_bias };
|
| 5130 |
ggml_set_op_params(result, params, sizeof(params));
|
| 5131 |
|
| 5132 |
result->op = GGML_OP_SOFT_MAX;
|
| 5133 |
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
| 5134 |
result->src[0] = a;
|
| 5135 |
result->src[1] = mask;
|
| 5136 |
+
result->src[2] = pos;
|
| 5137 |
|
| 5138 |
return result;
|
| 5139 |
}
|
|
|
|
| 5141 |
struct ggml_tensor * ggml_soft_max(
|
| 5142 |
struct ggml_context * ctx,
|
| 5143 |
struct ggml_tensor * a) {
|
| 5144 |
+
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
|
| 5145 |
}
|
| 5146 |
|
| 5147 |
struct ggml_tensor * ggml_soft_max_inplace(
|
| 5148 |
struct ggml_context * ctx,
|
| 5149 |
struct ggml_tensor * a) {
|
| 5150 |
+
return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
|
| 5151 |
}
|
| 5152 |
|
| 5153 |
struct ggml_tensor * ggml_soft_max_ext(
|
| 5154 |
struct ggml_context * ctx,
|
| 5155 |
struct ggml_tensor * a,
|
| 5156 |
struct ggml_tensor * mask,
|
| 5157 |
+
struct ggml_tensor * pos,
|
| 5158 |
+
float scale,
|
| 5159 |
+
float max_bias) {
|
| 5160 |
+
return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
|
| 5161 |
}
|
| 5162 |
|
| 5163 |
// ggml_soft_max_back
|
|
|
|
| 11510 |
const struct ggml_compute_params * params,
|
| 11511 |
const struct ggml_tensor * src0,
|
| 11512 |
const struct ggml_tensor * src1,
|
| 11513 |
+
const struct ggml_tensor * src2,
|
| 11514 |
struct ggml_tensor * dst) {
|
| 11515 |
assert(ggml_is_contiguous(dst));
|
| 11516 |
assert(ggml_are_same_shape(src0, dst));
|
|
|
|
| 11519 |
return;
|
| 11520 |
}
|
| 11521 |
|
| 11522 |
+
float scale = 1.0f;
|
| 11523 |
+
float max_bias = 0.0f;
|
| 11524 |
+
|
| 11525 |
+
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
|
| 11526 |
+
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
|
| 11527 |
|
| 11528 |
// TODO: handle transposed/permuted matrices
|
| 11529 |
|
| 11530 |
const int ith = params->ith;
|
| 11531 |
const int nth = params->nth;
|
| 11532 |
|
| 11533 |
+
GGML_TENSOR_UNARY_OP_LOCALS
|
| 11534 |
+
|
| 11535 |
const int64_t ne11 = src1 ? src1->ne[1] : 1;
|
| 11536 |
|
| 11537 |
+
// TODO: is this supposed to be ceil instead of floor?
|
| 11538 |
+
// https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
|
| 11539 |
+
const uint32_t n_head_kv = ne02;
|
| 11540 |
+
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
|
| 11541 |
+
|
| 11542 |
+
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
| 11543 |
+
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
| 11544 |
+
|
| 11545 |
const int nc = src0->ne[0];
|
| 11546 |
const int nr = ggml_nrows(src0);
|
| 11547 |
|
|
|
|
| 11554 |
|
| 11555 |
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
|
| 11556 |
|
| 11557 |
+
// when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
|
| 11558 |
+
float * pos = src2 ? (float *) src2->data : src0->data;
|
| 11559 |
+
|
| 11560 |
for (int i1 = ir0; i1 < ir1; i1++) {
|
| 11561 |
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
|
| 11562 |
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
|
|
|
|
| 11570 |
ggml_vec_acc_f32(nc, wp, mp);
|
| 11571 |
}
|
| 11572 |
|
| 11573 |
+
// ALiBi bias
|
| 11574 |
+
if (max_bias > 0.0f) {
|
| 11575 |
+
const uint32_t h = (i1/ne01)%ne02; // head
|
| 11576 |
+
const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
|
| 11577 |
+
|
| 11578 |
+
for (int i = 0; i < nc; i++) {
|
| 11579 |
+
wp[i] = wp[i] + slope*pos[i];
|
| 11580 |
+
}
|
| 11581 |
+
}
|
| 11582 |
+
|
| 11583 |
#ifndef NDEBUG
|
| 11584 |
for (int i = 0; i < nc; ++i) {
|
| 11585 |
//printf("p[%d] = %f\n", i, p[i]);
|
|
|
|
| 11624 |
const struct ggml_compute_params * params,
|
| 11625 |
const struct ggml_tensor * src0,
|
| 11626 |
const struct ggml_tensor * src1,
|
| 11627 |
+
const struct ggml_tensor * src2,
|
| 11628 |
struct ggml_tensor * dst) {
|
| 11629 |
switch (src0->type) {
|
| 11630 |
case GGML_TYPE_F32:
|
| 11631 |
{
|
| 11632 |
+
ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
|
| 11633 |
} break;
|
| 11634 |
default:
|
| 11635 |
{
|
|
|
|
| 11773 |
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 11774 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 11775 |
|
| 11776 |
+
for (int64_t k = 0; k < ne2_ne3; k++) {
|
| 11777 |
+
// TODO: k*nb2 or k*nb3
|
| 11778 |
+
float m_k;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11779 |
|
| 11780 |
+
if (k < n_heads_log2_floor) {
|
| 11781 |
+
m_k = powf(m0, k + 1);
|
| 11782 |
+
} else {
|
| 11783 |
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 11784 |
+
}
|
|
|
|
|
|
|
| 11785 |
|
| 11786 |
+
for (int64_t i = 0; i < ne0; i++) {
|
| 11787 |
+
for (int64_t j = 0; j < ne1; j++) {
|
| 11788 |
+
float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 11789 |
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
| 11790 |
pdst[0] = i * m_k + src[0];
|
| 11791 |
}
|
| 11792 |
}
|
|
|
|
| 11831 |
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
|
| 11832 |
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
|
| 11833 |
|
| 11834 |
+
for (int k = 0; k < ne2_ne3; k++) {
|
| 11835 |
+
// TODO: k*nb2 or k*nb3
|
| 11836 |
+
float m_k;
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11837 |
|
| 11838 |
+
if (k < n_heads_log2_floor) {
|
| 11839 |
+
m_k = powf(m0, k + 1);
|
| 11840 |
+
} else {
|
| 11841 |
+
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
|
| 11842 |
+
}
|
| 11843 |
|
| 11844 |
+
for (int i = 0; i < ne0; i++) {
|
| 11845 |
+
for (int j = 0; j < ne1; j++) {
|
| 11846 |
+
ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
|
| 11847 |
+
float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2);
|
|
|
|
| 11848 |
|
| 11849 |
// we return F32
|
| 11850 |
pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
|
|
|
|
| 15156 |
} break;
|
| 15157 |
case GGML_OP_SOFT_MAX:
|
| 15158 |
{
|
| 15159 |
+
ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
|
| 15160 |
} break;
|
| 15161 |
case GGML_OP_SOFT_MAX_BACK:
|
| 15162 |
{
|
ggml.h
CHANGED
|
@@ -1383,13 +1383,17 @@ extern "C" {
|
|
| 1383 |
struct ggml_context * ctx,
|
| 1384 |
struct ggml_tensor * a);
|
| 1385 |
|
| 1386 |
-
// fused soft_max(a*scale + mask)
|
| 1387 |
// mask is optional
|
|
|
|
|
|
|
| 1388 |
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
| 1389 |
struct ggml_context * ctx,
|
| 1390 |
struct ggml_tensor * a,
|
| 1391 |
struct ggml_tensor * mask,
|
| 1392 |
-
|
|
|
|
|
|
|
| 1393 |
|
| 1394 |
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
| 1395 |
struct ggml_context * ctx,
|
|
@@ -1491,12 +1495,13 @@ extern "C" {
|
|
| 1491 |
|
| 1492 |
// alibi position embedding
|
| 1493 |
// in-place, returns view(a)
|
| 1494 |
-
GGML_API struct ggml_tensor * ggml_alibi(
|
| 1495 |
struct ggml_context * ctx,
|
| 1496 |
struct ggml_tensor * a,
|
| 1497 |
int n_past,
|
| 1498 |
int n_head,
|
| 1499 |
-
float bias_max);
|
|
|
|
| 1500 |
|
| 1501 |
// clamp
|
| 1502 |
// in-place, returns view(a)
|
|
|
|
| 1383 |
struct ggml_context * ctx,
|
| 1384 |
struct ggml_tensor * a);
|
| 1385 |
|
| 1386 |
+
// fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
|
| 1387 |
// mask is optional
|
| 1388 |
+
// pos is required when max_bias > 0.0f
|
| 1389 |
+
// max_bias = 0.0f for no ALiBi
|
| 1390 |
GGML_API struct ggml_tensor * ggml_soft_max_ext(
|
| 1391 |
struct ggml_context * ctx,
|
| 1392 |
struct ggml_tensor * a,
|
| 1393 |
struct ggml_tensor * mask,
|
| 1394 |
+
struct ggml_tensor * pos,
|
| 1395 |
+
float scale,
|
| 1396 |
+
float max_bias);
|
| 1397 |
|
| 1398 |
GGML_API struct ggml_tensor * ggml_soft_max_back(
|
| 1399 |
struct ggml_context * ctx,
|
|
|
|
| 1495 |
|
| 1496 |
// alibi position embedding
|
| 1497 |
// in-place, returns view(a)
|
| 1498 |
+
GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
|
| 1499 |
struct ggml_context * ctx,
|
| 1500 |
struct ggml_tensor * a,
|
| 1501 |
int n_past,
|
| 1502 |
int n_head,
|
| 1503 |
+
float bias_max),
|
| 1504 |
+
"use ggml_soft_max_ext instead (will be removed in Mar 2024)");
|
| 1505 |
|
| 1506 |
// clamp
|
| 1507 |
// in-place, returns view(a)
|