slaren, ggerganov committed
Commit f0b5c67 · 1 Parent(s): 5e756db

ggml : group all experts in a single ggml_mul_mat_id (llama/6505)


* ggml : group all experts in a single ggml_mul_mat_id
* cuda : improve mmid row copy

* cuda : fix bin bcast with non-cont src0

* test-backend-ops : only run all mul mat tests for base types

* llama : disable moe offloading with SYCL

---------

Co-authored-by: Georgi Gerganov <[email protected]>

Files changed (8)
  1. ggml-cuda.cu +134 -45
  2. ggml-cuda/binbcast.cu +68 -24
  3. ggml-cuda/convert.cu +2 -0
  4. ggml-metal.m +61 -68
  5. ggml-metal.metal +400 -478
  6. ggml-sycl.cpp +1 -1
  7. ggml.c +62 -61
  8. ggml.h +2 -4
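
Background for the first bullet of the commit message: before this change the MoE graph issued one matrix multiplication per selected expert, with the expert slot passed through op_params; afterwards a single GGML_OP_MUL_MAT_ID covers all experts at once, driven by a 2D ids tensor of shape [n_expert_used, n_tokens] (ne20 x ne21 in the Metal hunk further below). The following CPU reference sketch shows what the grouped op computes; the names, container types, and the one-activation-row-per-token simplification are assumptions for illustration, not code from this commit.

// CPU reference sketch of the grouped operation (illustrative only).
#include <cstddef>
#include <cstdint>
#include <vector>

// as:  [n_expert][ne01][ne00]      expert weight matrices
// b:   [n_tokens][ne00]            one activation row per token (simplification)
// ids: [n_tokens][n_expert_used]   expert index per (token, slot)
// dst: [n_tokens][n_expert_used][ne01]
static void mul_mat_id_ref(
        const std::vector<std::vector<std::vector<float>>> & as,
        const std::vector<std::vector<float>>              & b,
        const std::vector<std::vector<int32_t>>            & ids,
        std::vector<std::vector<std::vector<float>>>       & dst) {
    for (std::size_t t = 0; t < b.size(); ++t) {
        for (std::size_t s = 0; s < ids[t].size(); ++s) {
            const int32_t e = ids[t][s];     // expert selected for this slot
            const auto &  W = as[e];         // [ne01][ne00]
            for (std::size_t r = 0; r < W.size(); ++r) {
                float acc = 0.0f;
                for (std::size_t k = 0; k < W[r].size(); ++k) {
                    acc += W[r][k] * b[t][k];
                }
                dst[t][s][r] = acc;
            }
        }
    }
}
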
ggml-cuda.cu CHANGED
@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1231
 
1232
  if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
1233
  // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
1234
- ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
1235
  if (src0->type != GGML_TYPE_F16) {
1236
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
1237
  GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1241
  }
1242
  const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
1243
 
1244
- ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
1245
  if (src1->type != GGML_TYPE_F16) {
1246
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
1247
  GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
1250
  to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
1251
  }
1252
  const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
1253
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
1254
 
1255
  const half alpha_f16 = 1.0f;
1256
  const half beta_f16 = 0.0f;
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
1960
  }
1961
  }
1962
1963
  static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
1964
  const ggml_tensor * src0 = dst->src[0];
1965
  const ggml_tensor * src1 = dst->src[1];
1966
  const ggml_tensor * ids = dst->src[2];
1967
 
 
 
1968
  GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
1969
 
1970
  cudaStream_t stream = ctx.stream();
1971
 
1972
- const size_t nb11 = src1->nb[1];
1973
- const size_t nb1 = dst->nb[1];
1974
-
1975
- const int32_t id = ((int32_t *) dst->op_params)[0];
1976
- const int32_t n_as = src0->ne[2];
1977
 
1978
  std::vector<char> ids_host(ggml_nbytes(ids));
1979
  const char * ids_dev = (const char *) ids->data;
@@ -1982,7 +2035,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
1982
 
1983
  ggml_tensor src0_row = *src0;
1984
  ggml_tensor src1_row = *src1;
1985
- ggml_tensor dst_row = *dst;
1986
 
1987
  char * src0_original = (char *) src0->data;
1988
  char * src1_original = (char *) src1->data;
@@ -1990,19 +2043,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
1990
 
1991
  src0_row.ne[2] = 1;
1992
  src0_row.ne[3] = 1;
1993
- src0_row.nb[3] = src0->nb[2];
1994
 
1995
- if (src1->ne[1] == 1) {
1996
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
1997
- const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
 
 
1998
 
1999
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
 
 
 
 
2000
 
2001
- src0_row.data = src0_original + row_id*src0->nb[2];
2002
- src1_row.data = src1_original + i01*src1->nb[1];
2003
- dst_row.data = dst_original + i01*dst->nb[1];
 
2004
 
2005
- ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2006
  }
2007
  } else {
2008
  ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
2011
  src1_row.data = src1_contiguous.get();
2012
  dst_row.data = dst_contiguous.get();
2013
 
2014
- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
2015
  int64_t num_src1_rows = 0;
2016
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
2017
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
2018
 
2019
- if (row_id_i != row_id) {
2020
- continue;
2021
- }
2022
 
2023
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
2024
 
2025
- CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
2026
- nb11, cudaMemcpyDeviceToDevice, stream));
2027
- num_src1_rows++;
 
 
 
2028
  }
2029
 
2030
  if (num_src1_rows == 0) {
2031
  continue;
2032
  }
2033
 
2034
- src0_row.data = src0_original + row_id*src0->nb[2];
 
 
2035
 
2036
- src1_row.ne[1] = num_src1_rows;
2037
- dst_row.ne[1] = num_src1_rows;
2038
2039
  src1_row.nb[1] = nb11;
2040
  src1_row.nb[2] = num_src1_rows*nb11;
2041
  src1_row.nb[3] = num_src1_rows*nb11;
2042
 
 
2043
  dst_row.nb[1] = nb1;
2044
  dst_row.nb[2] = num_src1_rows*nb1;
2045
  dst_row.nb[3] = num_src1_rows*nb1;
2046
 
2047
  ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2048
 
2049
- num_src1_rows = 0;
2050
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
2051
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
2052
-
2053
- if (row_id_i != row_id) {
2054
- continue;
2055
- }
2056
-
2057
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
2058
-
2059
- CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
2060
- nb1, cudaMemcpyDeviceToDevice, stream));
2061
- num_src1_rows++;
2062
  }
2063
  }
2064
  }
@@ -2491,7 +2579,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
2491
  GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
2492
  const int min_batch_size = 32;
2493
 
2494
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
 
2495
 
2496
  GGML_UNUSED(backend);
2497
  }
 
1231
 
1232
  if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
1233
  // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
1234
+ ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
1235
  if (src0->type != GGML_TYPE_F16) {
1236
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
1237
  GGML_ASSERT(to_fp16_cuda != nullptr);
 
1241
  }
1242
  const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
1243
 
1244
+ ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
1245
  if (src1->type != GGML_TYPE_F16) {
1246
  const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
1247
  GGML_ASSERT(to_fp16_cuda != nullptr);
 
1250
  to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
1251
  }
1252
  const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
1253
+ ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
1254
 
1255
  const half alpha_f16 = 1.0f;
1256
  const half beta_f16 = 0.0f;
 
1960
  }
1961
  }
1962
 
1963
+ struct mmid_row_mapping {
1964
+ int32_t i1;
1965
+ int32_t i2;
1966
+ };
1967
+
1968
+ static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
1969
+ int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
1970
+ const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
1971
+ int64_t ne11, int64_t ne10,
1972
+ size_t nb11, size_t nb12) {
1973
+ int32_t iid1 = blockIdx.x;
1974
+ int32_t id = blockIdx.y;
1975
+
1976
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
1977
+
1978
+ if (row_id_i != i02) {
1979
+ return;
1980
+ }
1981
+
1982
+ const int64_t i11 = id % ne11;
1983
+ const int64_t i12 = iid1;
1984
+
1985
+ __shared__ int src1_row;
1986
+ if (threadIdx.x == 0) {
1987
+ src1_row = atomicAdd(cur_src1_row, 1);
1988
+ row_mapping[src1_row] = {id, iid1};
1989
+ }
1990
+ __syncthreads();
1991
+
1992
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
1993
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
1994
+
1995
+ for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
1996
+ src1_row_contiguous[i] = src1_row_original[i];
1997
+ }
1998
+ }
1999
+
2000
+ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
2001
+ const mmid_row_mapping * __restrict__ row_mapping,
2002
+ int64_t ne0,
2003
+ size_t nb1, size_t nb2) {
2004
+ int32_t i = blockIdx.x;
2005
+
2006
+ const int32_t i1 = row_mapping[i].i1;
2007
+ const int32_t i2 = row_mapping[i].i2;
2008
+
2009
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
2010
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
2011
+
2012
+ for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
2013
+ dst_row_original[j] = dst_row_contiguous[j];
2014
+ }
2015
+ }
2016
+
2017
  static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
2018
  const ggml_tensor * src0 = dst->src[0];
2019
  const ggml_tensor * src1 = dst->src[1];
2020
  const ggml_tensor * ids = dst->src[2];
2021
 
2022
+ GGML_TENSOR_BINARY_OP_LOCALS
2023
+
2024
  GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
2025
 
2026
  cudaStream_t stream = ctx.stream();
2027
 
2028
+ const int64_t n_as = ne02;
2029
+ const int64_t n_ids = ids->ne[0];
 
 
 
2030
 
2031
  std::vector<char> ids_host(ggml_nbytes(ids));
2032
  const char * ids_dev = (const char *) ids->data;
 
2035
 
2036
  ggml_tensor src0_row = *src0;
2037
  ggml_tensor src1_row = *src1;
2038
+ ggml_tensor dst_row = *dst;
2039
 
2040
  char * src0_original = (char *) src0->data;
2041
  char * src1_original = (char *) src1->data;
 
2043
 
2044
  src0_row.ne[2] = 1;
2045
  src0_row.ne[3] = 1;
2046
+ src0_row.nb[3] = nb02;
2047
 
2048
+ src1_row.ne[1] = 1;
2049
+ src1_row.ne[2] = 1;
2050
+ src1_row.ne[3] = 1;
2051
+ src1_row.nb[2] = nb11;
2052
+ src1_row.nb[3] = nb11;
2053
 
2054
+ dst_row.ne[1] = 1;
2055
+ dst_row.ne[2] = 1;
2056
+ dst_row.ne[3] = 1;
2057
+ dst_row.nb[2] = nb1;
2058
+ dst_row.nb[3] = nb1;
2059
 
2060
+ if (ne12 == 1) {
2061
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2062
+ for (int64_t id = 0; id < n_ids; id++) {
2063
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2064
 
2065
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
2066
+
2067
+ const int64_t i11 = id % ne11;
2068
+ const int64_t i12 = iid1;
2069
+
2070
+ const int64_t i1 = id;
2071
+ const int64_t i2 = i12;
2072
+
2073
+ src0_row.data = src0_original + i02*nb02;
2074
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
2075
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
2076
+
2077
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2078
+ }
2079
  }
2080
  } else {
2081
  ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
 
2084
  src1_row.data = src1_contiguous.get();
2085
  dst_row.data = dst_contiguous.get();
2086
 
2087
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
2088
  int64_t num_src1_rows = 0;
 
 
2089
 
2090
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
2091
+ for (int64_t id = 0; id < n_ids; id++) {
2092
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
2093
 
2094
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
2095
 
2096
+ if (row_id_i != i02) {
2097
+ continue;
2098
+ }
2099
+
2100
+ num_src1_rows++;
2101
+ }
2102
  }
2103
 
2104
  if (num_src1_rows == 0) {
2105
  continue;
2106
  }
2107
 
2108
+ ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
2109
+ ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
2110
+ CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
2111
 
2112
+ {
2113
+ dim3 block_dims(std::min((unsigned int)ne10, 768u));
2114
+ dim3 grid_dims(ids->ne[1], n_ids);
2115
+ k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
2116
+ src1_original, src1_contiguous.get(),
2117
+ dev_cur_src1_row.get(), dev_row_mapping.get(),
2118
+ ids_dev, i02, ids->nb[1], ids->nb[0],
2119
+ ne11, ne10,
2120
+ nb11, nb12);
2121
+ CUDA_CHECK(cudaGetLastError());
2122
+ }
2123
+
2124
+ src0_row.data = src0_original + i02*nb02;
2125
 
2126
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
2127
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
2128
+
2129
+ src1_row.ne[1] = num_src1_rows;
2130
  src1_row.nb[1] = nb11;
2131
  src1_row.nb[2] = num_src1_rows*nb11;
2132
  src1_row.nb[3] = num_src1_rows*nb11;
2133
 
2134
+ dst_row.ne[1] = num_src1_rows;
2135
  dst_row.nb[1] = nb1;
2136
  dst_row.nb[2] = num_src1_rows*nb1;
2137
  dst_row.nb[3] = num_src1_rows*nb1;
2138
 
2139
  ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
2140
 
2141
+ {
2142
+ dim3 block_dims(std::min((unsigned int)ne0, 768u));
2143
+ dim3 grid_dims(num_src1_rows);
2144
+ k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
2145
+ dst_original, dst_contiguous.get(),
2146
+ dev_row_mapping.get(),
2147
+ ne0,
2148
+ nb1, nb2);
2149
+ CUDA_CHECK(cudaGetLastError());
2150
  }
2151
  }
2152
  }
 
2579
  GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
2580
  const int min_batch_size = 32;
2581
 
2582
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
2583
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
2584
 
2585
  GGML_UNUSED(backend);
2586
  }
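
The two new device kernels above replace the per-row cudaMemcpyAsync loops: for each expert i02, k_copy_src1_to_contiguous packs every src1 row routed to that expert into a contiguous buffer, using an atomicAdd on a global counter to claim destination slots and recording an (i1, i2) mapping so that k_copy_dst_from_contiguous can scatter the results back afterwards. A standalone sketch of the same compaction pattern, with simplified shapes and assumed names (not the ggml kernels):

// One block handles one candidate row; a single thread reserves a slot in the
// packed buffer via atomicAdd and records where the result must be scattered back.
#include <cstdint>

struct row_mapping { int32_t i1; int32_t i2; };

__global__ void gather_rows_for_expert(
        const float   * __restrict__ src,      // [n_rows][ncols]
        float         * __restrict__ packed,   // [<= n_rows][ncols]
        const int32_t * __restrict__ ids,      // ids[row] = expert assigned to this row
        int           * __restrict__ counter,  // zero-initialized before launch
        row_mapping   * __restrict__ mapping,
        int expert, int ncols) {
    const int row = blockIdx.x;
    if (ids[row] != expert) {
        return; // row routed to a different expert
    }

    __shared__ int slot;
    if (threadIdx.x == 0) {
        slot = atomicAdd(counter, 1);   // reserve the next free slot
        mapping[slot] = { row, 0 };     // remember the scatter destination
    }
    __syncthreads();

    for (int c = threadIdx.x; c < ncols; c += blockDim.x) {
        packed[slot*ncols + c] = src[row*ncols + c];
    }
}

Because every thread in a block sees the same ids[row], the early return and the __syncthreads() after the atomicAdd are uniform within the block, which keeps the barrier safe.
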
ggml-cuda/binbcast.cu CHANGED
@@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
22
  int ne0, int ne1, int ne2, int ne3,
23
  int ne10, int ne11, int ne12, int ne13,
24
  /*int s0, */ int s1, int s2, int s3,
 
25
  /*int s10,*/ int s11, int s12, int s13) {
26
  const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
27
  const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
36
  const int i12 = i2 % ne12;
37
  const int i13 = i3 % ne13;
38
 
39
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
40
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
41
- const size_t i_dst = i_src0;
42
 
43
  const src0_t * src0_row = src0 + i_src0;
44
  const src1_t * src1_row = src1 + i_src1;
@@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
55
  int ne0, int ne1, int ne2, int ne3,
56
  int ne10, int ne11, int ne12, int ne13,
57
  /*int s0, */ int s1, int s2, int s3,
 
58
  /*int s10,*/ int s11, int s12, int s13) {
59
 
60
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
72
  const int i12 = i2 % ne12;
73
  const int i13 = i3 % ne13;
74
 
75
- const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
76
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
77
- const size_t i_dst = i_src0;
78
 
79
  const src0_t * src0_row = src0 + i_src0;
80
  const src1_t * src1_row = src1 + i_src1;
@@ -101,10 +103,14 @@ struct bin_bcast_cuda {
101
  int nr[4] = { nr0, nr1, nr2, nr3 };
102
 
103
  // collapse dimensions until first broadcast dimension
104
- int64_t cne0[] = {ne0, ne1, ne2, ne3};
 
105
  int64_t cne1[] = {ne10, ne11, ne12, ne13};
106
- size_t cnb0[] = {nb0, nb1, nb2, nb3};
 
 
107
  size_t cnb1[] = {nb10, nb11, nb12, nb13};
 
108
  auto collapse = [](int64_t cne[]) {
109
  cne[0] *= cne[1];
110
  cne[1] = cne[2];
@@ -118,32 +124,47 @@ struct bin_bcast_cuda {
118
  cnb[3] *= cne[3];
119
  };
120
 
121
- for (int i = 0; i < 4; i++) {
122
- if (nr[i] != 1) {
123
- break;
124
- }
125
- if (i > 0) {
126
- collapse_nb(cnb0, cne0);
127
- collapse_nb(cnb1, cne1);
128
- collapse(cne0);
129
- collapse(cne1);
 
 
 
 
130
  }
131
  }
 
132
  {
133
- int64_t ne0 = cne0[0];
134
- int64_t ne1 = cne0[1];
135
- int64_t ne2 = cne0[2];
136
- int64_t ne3 = cne0[3];
137
 
138
  int64_t ne10 = cne1[0];
139
  int64_t ne11 = cne1[1];
140
  int64_t ne12 = cne1[2];
141
  int64_t ne13 = cne1[3];
142
 
143
- size_t nb0 = cnb0[0];
144
- size_t nb1 = cnb0[1];
145
- size_t nb2 = cnb0[2];
146
- size_t nb3 = cnb0[3];
147
 
148
  size_t nb10 = cnb1[0];
149
  size_t nb11 = cnb1[1];
@@ -160,7 +181,28 @@ struct bin_bcast_cuda {
160
  size_t s12 = nb12 / sizeof(src1_t);
161
  size_t s13 = nb13 / sizeof(src1_t);
162
163
  GGML_ASSERT(s0 == 1);
 
164
  GGML_ASSERT(s10 == 1);
165
 
166
  const int block_size = 128;
@@ -179,13 +221,14 @@ struct bin_bcast_cuda {
179
  );
180
 
181
  if (block_nums.z > 65535) {
182
- // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
183
  int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
184
  k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
185
  src0_dd, src1_dd, dst_dd,
186
  ne0, ne1, ne2, ne3,
187
  ne10, ne11, ne12, ne13,
188
  /* s0, */ s1, s2, s3,
 
189
  /* s10, */ s11, s12, s13);
190
  } else {
191
  k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
@@ -193,6 +236,7 @@ struct bin_bcast_cuda {
193
  ne0, ne1, ne2, ne3,
194
  ne10, ne11, ne12, ne13,
195
  /* s0, */ s1, s2, s3,
 
196
  /* s10, */ s11, s12, s13);
197
  }
198
  }
 
22
  int ne0, int ne1, int ne2, int ne3,
23
  int ne10, int ne11, int ne12, int ne13,
24
  /*int s0, */ int s1, int s2, int s3,
25
+ /*int s00,*/ int s01, int s02, int s03,
26
  /*int s10,*/ int s11, int s12, int s13) {
27
  const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
28
  const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
 
37
  const int i12 = i2 % ne12;
38
  const int i13 = i3 % ne13;
39
 
40
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
41
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
42
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
43
 
44
  const src0_t * src0_row = src0 + i_src0;
45
  const src1_t * src1_row = src1 + i_src1;
 
56
  int ne0, int ne1, int ne2, int ne3,
57
  int ne10, int ne11, int ne12, int ne13,
58
  /*int s0, */ int s1, int s2, int s3,
59
+ /*int s00,*/ int s01, int s02, int s03,
60
  /*int s10,*/ int s11, int s12, int s13) {
61
 
62
  const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
74
  const int i12 = i2 % ne12;
75
  const int i13 = i3 % ne13;
76
 
77
+ const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
78
  const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
79
+ const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
80
 
81
  const src0_t * src0_row = src0 + i_src0;
82
  const src1_t * src1_row = src1 + i_src1;
 
103
  int nr[4] = { nr0, nr1, nr2, nr3 };
104
 
105
  // collapse dimensions until first broadcast dimension
106
+ int64_t cne[] = {ne0, ne1, ne2, ne3};
107
+ int64_t cne0[] = {ne00, ne01, ne02, ne03};
108
  int64_t cne1[] = {ne10, ne11, ne12, ne13};
109
+
110
+ size_t cnb[] = {nb0, nb1, nb2, nb3};
111
+ size_t cnb0[] = {nb00, nb01, nb02, nb03};
112
  size_t cnb1[] = {nb10, nb11, nb12, nb13};
113
+
114
  auto collapse = [](int64_t cne[]) {
115
  cne[0] *= cne[1];
116
  cne[1] = cne[2];
 
124
  cnb[3] *= cne[3];
125
  };
126
 
127
+ if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
128
+ for (int i = 0; i < 4; i++) {
129
+ if (nr[i] != 1) {
130
+ break;
131
+ }
132
+ if (i > 0) {
133
+ collapse_nb(cnb, cne);
134
+ collapse_nb(cnb0, cne0);
135
+ collapse_nb(cnb1, cne1);
136
+ collapse(cne);
137
+ collapse(cne0);
138
+ collapse(cne1);
139
+ }
140
  }
141
  }
142
+
143
  {
144
+ int64_t ne0 = cne[0];
145
+ int64_t ne1 = cne[1];
146
+ int64_t ne2 = cne[2];
147
+ int64_t ne3 = cne[3];
148
+
149
+ //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
150
+ //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
151
+ //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
152
+ //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
153
 
154
  int64_t ne10 = cne1[0];
155
  int64_t ne11 = cne1[1];
156
  int64_t ne12 = cne1[2];
157
  int64_t ne13 = cne1[3];
158
 
159
+ size_t nb0 = cnb[0];
160
+ size_t nb1 = cnb[1];
161
+ size_t nb2 = cnb[2];
162
+ size_t nb3 = cnb[3];
163
+
164
+ size_t nb00 = cnb0[0];
165
+ size_t nb01 = cnb0[1];
166
+ size_t nb02 = cnb0[2];
167
+ size_t nb03 = cnb0[3];
168
 
169
  size_t nb10 = cnb1[0];
170
  size_t nb11 = cnb1[1];
 
181
  size_t s12 = nb12 / sizeof(src1_t);
182
  size_t s13 = nb13 / sizeof(src1_t);
183
 
184
+ size_t s00 = nb00 / sizeof(src0_t);
185
+ size_t s01 = nb01 / sizeof(src0_t);
186
+ size_t s02 = nb02 / sizeof(src0_t);
187
+ size_t s03 = nb03 / sizeof(src0_t);
188
+
189
+ GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
190
+ GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
191
+ GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
192
+ GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
193
+
194
+ GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
195
+ GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
196
+ GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
197
+ GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
198
+
199
+ GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
200
+ GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
201
+ GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
202
+ GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
203
+
204
  GGML_ASSERT(s0 == 1);
205
+ GGML_ASSERT(s00 == 1);
206
  GGML_ASSERT(s10 == 1);
207
 
208
  const int block_size = 128;
 
221
  );
222
 
223
  if (block_nums.z > 65535) {
224
+ // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
225
  int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
226
  k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
227
  src0_dd, src1_dd, dst_dd,
228
  ne0, ne1, ne2, ne3,
229
  ne10, ne11, ne12, ne13,
230
  /* s0, */ s1, s2, s3,
231
+ /* s00, */ s01, s02, s03,
232
  /* s10, */ s11, s12, s13);
233
  } else {
234
  k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
 
236
  ne0, ne1, ne2, ne3,
237
  ne10, ne11, ne12, ne13,
238
  /* s0, */ s1, s2, s3,
239
+ /* s00, */ s01, s02, s03,
240
  /* s10, */ s11, s12, s13);
241
  }
242
  }
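
This file implements the "cuda : fix bin bcast with non-cont src0" bullet: the kernels previously derived the src0 offset from the dst strides (i_src0 == i_dst), which only holds when src0 is contiguous and laid out like dst. The new s01/s02/s03 parameters carry src0's own element strides, the dst offset is computed from s1/s2/s3, and the dimension-collapse optimization is applied only when src0, src1, and dst are all contiguous. A minimal sketch of the corrected indexing, with an assumed helper name:

// Illustrative device helper, not the full kernel: src0, src1, and dst offsets
// are each derived from their own strides, so src0 no longer has to share dst's layout.
__device__ void bcast_offsets(
        int i1, int i2, int i3,
        int ne11, int ne12, int ne13,          // src1 broadcast extents
        size_t s1,  size_t s2,  size_t s3,     // dst element strides
        size_t s01, size_t s02, size_t s03,    // src0 element strides (may differ from dst)
        size_t s11, size_t s12, size_t s13,    // src1 element strides
        size_t * i_dst, size_t * i_src0, size_t * i_src1) {
    const int i11 = i1 % ne11;
    const int i12 = i2 % ne12;
    const int i13 = i3 % ne13;

    *i_src0 = i3*s03 + i2*s02 + i1*s01;        // previously reused the dst offset
    *i_src1 = i13*s13 + i12*s12 + i11*s11;
    *i_dst  = i3*s3  + i2*s2  + i1*s1;
}
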
ggml-cuda/convert.cu CHANGED
@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
45
  vals[ix] = x0[ix];
46
  }
47
 
 
 
48
  #pragma unroll
49
  for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
50
  if (need_check && i0 + iy + 2*threadIdx.x >= k) {
 
45
  vals[ix] = x0[ix];
46
  }
47
 
48
+ __syncthreads();
49
+
50
  #pragma unroll
51
  for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
52
  if (need_check && i0 + iy + 2*threadIdx.x >= k) {
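
The single functional change in convert.cu is the added __syncthreads(): in dequantize_block_q8_0_f16 the threads of a block first cooperatively stage values into shared memory and then read entries staged by other threads, so a barrier is required between the two phases. A generic sketch of the pattern (assumed kernel, not the ggml dequantizer; assumes blockDim.x == 256 and 1 <= n <= 256):

__global__ void staged_shift(const float * __restrict__ in, float * __restrict__ out, int n) {
    __shared__ float vals[256];

    const int i = threadIdx.x;
    if (i < n) {
        vals[i] = in[blockIdx.x*n + i];            // each thread stages one element
    }

    __syncthreads();                               // without this, the read below can race

    if (i < n) {
        out[blockIdx.x*n + i] = vals[(i + 1) % n]; // element staged by another thread
    }
}
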
ggml-metal.m CHANGED
@@ -1747,15 +1747,10 @@ static enum ggml_status ggml_metal_graph_compute(
1747
  } break;
1748
  case GGML_OP_MUL_MAT_ID:
1749
  {
1750
- //GGML_ASSERT(ne00 == ne10);
1751
- //GGML_ASSERT(ne03 == ne13);
1752
  const int n_as = src0->ne[2];
1753
 
1754
- // max size of the src1ids array in the kernel shared buffer
1755
- GGML_ASSERT(ne11 <= 4096);
1756
-
1757
  // src2 = ids
1758
- const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
1759
  const int64_t ne21 = src2->ne[1];
1760
  const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1761
  const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1776,15 +1771,13 @@ static enum ggml_status ggml_metal_graph_compute(
1776
 
1777
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1778
  // to the matrix-vector kernel
1779
- int ne11_mm_min = n_as;
1780
-
1781
- const int idx = ((int32_t *) dst->op_params)[0];
 
1782
 
1783
- // batch size
1784
- GGML_ASSERT(ne21 == ne11); // ?
1785
- GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
1786
- const uint r2 = 1;
1787
- const uint r3 = 1;
1788
 
1789
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1790
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1794,7 +1787,7 @@ static enum ggml_status ggml_metal_graph_compute(
1794
  // !!!
1795
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1796
  ne00 % 32 == 0 && ne00 >= 64 &&
1797
- ne11 > ne11_mm_min) {
1798
 
1799
  // some Metal matrix data types require aligned pointers
1800
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1836,26 +1829,26 @@ static enum ggml_status ggml_metal_graph_compute(
1836
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1837
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1838
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1839
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
1840
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
1841
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
1842
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
1843
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
1844
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
1845
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
1846
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
1847
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
1848
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
1849
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
1850
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
1851
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
1852
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
1853
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
1854
- [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
1855
-
1856
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
1857
-
1858
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1859
  } else {
1860
  int nth0 = 32;
1861
  int nth1 = 1;
@@ -2008,72 +2001,72 @@ static enum ggml_status ggml_metal_graph_compute(
2008
  GGML_ASSERT(ne00 >= nth0*nth1);
2009
  }
2010
 
2011
- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
2012
-
2013
  [encoder setComputePipelineState:pipeline];
2014
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2015
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2016
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2017
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2018
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
2019
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
2020
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
2021
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
2022
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
2023
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
2024
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
2025
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
2026
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
2027
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
2028
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
2029
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
2030
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
2031
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
2032
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
2033
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
2034
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
2035
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
2036
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
2037
- [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
 
 
2038
 
2039
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2040
  src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2041
  src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2042
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2043
  }
2044
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2045
  const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2046
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2047
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2048
  }
2049
  else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2050
  const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2051
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2052
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2053
  }
2054
  else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2055
  const int mem_size = 32*sizeof(float);
2056
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2057
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2058
  }
2059
  else if (src0t == GGML_TYPE_Q4_K) {
2060
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2061
  }
2062
  else if (src0t == GGML_TYPE_Q3_K) {
2063
  #ifdef GGML_QKK_64
2064
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2065
  #else
2066
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2067
  #endif
2068
  }
2069
  else if (src0t == GGML_TYPE_Q5_K) {
2070
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2071
  }
2072
  else if (src0t == GGML_TYPE_Q6_K) {
2073
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2074
  } else {
2075
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
2076
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2077
  }
2078
  }
2079
  } break;
 
1747
  } break;
1748
  case GGML_OP_MUL_MAT_ID:
1749
  {
 
 
1750
  const int n_as = src0->ne[2];
1751
 
 
 
 
1752
  // src2 = ids
1753
+ const int64_t ne20 = src2->ne[0];
1754
  const int64_t ne21 = src2->ne[1];
1755
  const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
1756
  const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
 
1771
 
1772
  // find the break-even point where the matrix-matrix kernel becomes more efficient compared
1773
  // to the matrix-vector kernel
1774
+ // ne20 = n_used_experts
1775
+ // ne21 = n_rows
1776
+ const int dst_rows = ne20*ne21;
1777
+ const int dst_rows_min = n_as;
1778
 
1779
+ // max size of the rowids array in the kernel shared buffer
1780
+ GGML_ASSERT(dst_rows <= 2048);
 
 
 
1781
 
1782
  // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
1783
  // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
 
1787
  // !!!
1788
  if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1789
  ne00 % 32 == 0 && ne00 >= 64 &&
1790
+ dst_rows > dst_rows_min) {
1791
 
1792
  // some Metal matrix data types require aligned pointers
1793
  // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
 
1829
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
1830
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
1831
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
1832
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
1833
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
1834
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
1835
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
1836
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
1837
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
1838
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
1839
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
1840
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
1841
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
1842
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
1843
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
1844
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
1845
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
1846
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
1847
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
1848
+
1849
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
1850
+
1851
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1852
  } else {
1853
  int nth0 = 32;
1854
  int nth1 = 1;
 
2001
  GGML_ASSERT(ne00 >= nth0*nth1);
2002
  }
2003
 
 
 
2004
  [encoder setComputePipelineState:pipeline];
2005
  [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
2006
  [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
2007
  [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
2008
  [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
2009
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
2010
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
2011
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
2012
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
2013
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
2014
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
2015
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
2016
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
2017
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
2018
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
2019
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
2020
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
2021
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
2022
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
2023
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
2024
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
2025
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
2026
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
2027
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
2028
+
2029
+ const int64_t _ne1 = 1;
2030
+ const int tgz = dst_rows;
2031
 
2032
  if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
2033
  src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
2034
  src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
2035
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2036
  }
2037
  else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
2038
  const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
2039
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2040
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2041
  }
2042
  else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
2043
  const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
2044
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2045
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2046
  }
2047
  else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
2048
  const int mem_size = 32*sizeof(float);
2049
  [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
2050
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2051
  }
2052
  else if (src0t == GGML_TYPE_Q4_K) {
2053
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2054
  }
2055
  else if (src0t == GGML_TYPE_Q3_K) {
2056
  #ifdef GGML_QKK_64
2057
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2058
  #else
2059
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2060
  #endif
2061
  }
2062
  else if (src0t == GGML_TYPE_Q5_K) {
2063
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2064
  }
2065
  else if (src0t == GGML_TYPE_Q6_K) {
2066
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2067
  } else {
2068
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
2069
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
2070
  }
2071
  }
2072
  } break;
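
In the Metal backend the matrix-matrix path for MUL_MAT_ID is now selected from the total number of dst rows produced, dst_rows = ne20*ne21 (used experts times rows), rather than from ne11, and the threadgroup grid's z extent becomes n_as instead of n_as*ne12*ne13. A host-side sketch of the heuristic follows; the helper name is an assumption, and folding the dst_rows <= 2048 bound (a GGML_ASSERT in the diff) into the predicate is a simplification:

// C++ pseudocode of the selection logic above, not the Objective-C source.
#include <cstdint>

static bool use_mm_kernel_for_mul_mat_id(
        int64_t ne20 /* n_used_experts */, int64_t ne21 /* n_rows */,
        int64_t ne00, int64_t n_as, bool supports_apple7_family) {
    const int64_t dst_rows     = ne20*ne21;
    const int64_t dst_rows_min = n_as;   // break-even guess: roughly one row per expert

    return supports_apple7_family &&
           ne00 % 32 == 0 && ne00 >= 64 &&
           dst_rows > dst_rows_min &&
           dst_rows <= 2048;             // rowids shared-memory limit (asserted in the diff)
}
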
ggml-metal.metal CHANGED
@@ -899,16 +899,16 @@ void mul_vec_q_n_f32_impl(
899
  device const void * src0,
900
  device const float * src1,
901
  device float * dst,
902
- constant int64_t & ne00,
903
- constant int64_t & ne01,
904
- constant int64_t & ne02,
905
- constant int64_t & ne10,
906
- constant int64_t & ne12,
907
- constant int64_t & ne0,
908
- constant int64_t & ne1,
909
- constant uint & r2,
910
- constant uint & r3,
911
- threadgroup int8_t * shared_values,
912
  uint3 tgpig, uint tiisg, uint sgitg) {
913
  const int nb = ne00/QK4_0;
914
 
@@ -1073,19 +1073,19 @@ void kernel_mul_mv_q8_0_f32_impl(
1073
  device const void * src0,
1074
  device const float * src1,
1075
  device float * dst,
1076
- constant int64_t & ne00,
1077
- constant int64_t & ne01,
1078
- constant int64_t & ne02,
1079
- constant int64_t & ne10,
1080
- constant int64_t & ne12,
1081
- constant int64_t & ne0,
1082
- constant int64_t & ne1,
1083
- constant uint & r2,
1084
- constant uint & r3,
1085
- threadgroup int8_t * shared_values [[threadgroup(0)]],
1086
- uint3 tgpig[[threadgroup_position_in_grid]],
1087
- uint tiisg[[thread_index_in_simdgroup]],
1088
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
1089
  const int nr = N_DST;
1090
  const int nsg = N_SIMDGROUP;
1091
  const int nw = N_SIMDWIDTH;
@@ -1172,24 +1172,24 @@ void kernel_mul_mv_f32_f32_impl(
1172
  device const char * src0,
1173
  device const char * src1,
1174
  device float * dst,
1175
- constant int64_t & ne00,
1176
- constant int64_t & ne01,
1177
- constant int64_t & ne02,
1178
- constant uint64_t & nb00,
1179
- constant uint64_t & nb01,
1180
- constant uint64_t & nb02,
1181
- constant int64_t & ne10,
1182
- constant int64_t & ne11,
1183
- constant int64_t & ne12,
1184
- constant uint64_t & nb10,
1185
- constant uint64_t & nb11,
1186
- constant uint64_t & nb12,
1187
- constant int64_t & ne0,
1188
- constant int64_t & ne1,
1189
- constant uint & r2,
1190
- constant uint & r3,
1191
- uint3 tgpig[[threadgroup_position_in_grid]],
1192
- uint tiisg[[thread_index_in_simdgroup]]) {
1193
 
1194
  const int64_t r0 = tgpig.x;
1195
  const int64_t rb = tgpig.y*N_F32_F32;
@@ -1442,24 +1442,24 @@ void kernel_mul_mv_f16_f32_impl(
1442
  device const char * src0,
1443
  device const char * src1,
1444
  device float * dst,
1445
- constant int64_t & ne00,
1446
- constant int64_t & ne01,
1447
- constant int64_t & ne02,
1448
- constant uint64_t & nb00,
1449
- constant uint64_t & nb01,
1450
- constant uint64_t & nb02,
1451
- constant int64_t & ne10,
1452
- constant int64_t & ne11,
1453
- constant int64_t & ne12,
1454
- constant uint64_t & nb10,
1455
- constant uint64_t & nb11,
1456
- constant uint64_t & nb12,
1457
- constant int64_t & ne0,
1458
- constant int64_t & ne1,
1459
- constant uint & r2,
1460
- constant uint & r3,
1461
- uint3 tgpig[[threadgroup_position_in_grid]],
1462
- uint tiisg[[thread_index_in_simdgroup]]) {
1463
 
1464
  const int64_t r0 = tgpig.x;
1465
  const int64_t rb = tgpig.y*N_F16_F32;
@@ -2744,19 +2744,19 @@ void kernel_mul_mv_q2_K_f32_impl(
2744
  device const void * src0,
2745
  device const float * src1,
2746
  device float * dst,
2747
- constant int64_t & ne00,
2748
- constant int64_t & ne01,
2749
- constant int64_t & ne02,
2750
- constant int64_t & ne10,
2751
- constant int64_t & ne12,
2752
- constant int64_t & ne0,
2753
- constant int64_t & ne1,
2754
- constant uint & r2,
2755
- constant uint & r3,
2756
- threadgroup int8_t * shared_values [[threadgroup(0)]],
2757
- uint3 tgpig[[threadgroup_position_in_grid]],
2758
- uint tiisg[[thread_index_in_simdgroup]],
2759
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2760
 
2761
  const int nb = ne00/QK_K;
2762
  const int r0 = tgpig.x;
@@ -2924,19 +2924,19 @@ void kernel_mul_mv_q3_K_f32_impl(
2924
  device const void * src0,
2925
  device const float * src1,
2926
  device float * dst,
2927
- constant int64_t & ne00,
2928
- constant int64_t & ne01,
2929
- constant int64_t & ne02,
2930
- constant int64_t & ne10,
2931
- constant int64_t & ne12,
2932
- constant int64_t & ne0,
2933
- constant int64_t & ne1,
2934
- constant uint & r2,
2935
- constant uint & r3,
2936
- threadgroup int8_t * shared_values [[threadgroup(0)]],
2937
- uint3 tgpig[[threadgroup_position_in_grid]],
2938
- uint tiisg[[thread_index_in_simdgroup]],
2939
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
2940
 
2941
  const int nb = ne00/QK_K;
2942
 
@@ -3190,19 +3190,19 @@ void kernel_mul_mv_q4_K_f32_impl(
3190
  device const void * src0,
3191
  device const float * src1,
3192
  device float * dst,
3193
- constant int64_t & ne00,
3194
- constant int64_t & ne01,
3195
- constant int64_t & ne02,
3196
- constant int64_t & ne10,
3197
- constant int64_t & ne12,
3198
- constant int64_t & ne0,
3199
- constant int64_t & ne1,
3200
- constant uint & r2,
3201
- constant uint & r3,
3202
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3203
- uint3 tgpig[[threadgroup_position_in_grid]],
3204
- uint tiisg[[thread_index_in_simdgroup]],
3205
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3206
 
3207
  const uint16_t kmask1 = 0x3f3f;
3208
  const uint16_t kmask2 = 0x0f0f;
@@ -3429,19 +3429,19 @@ void kernel_mul_mv_q5_K_f32_impl(
3429
  device const void * src0,
3430
  device const float * src1,
3431
  device float * dst,
3432
- constant int64_t & ne00,
3433
- constant int64_t & ne01,
3434
- constant int64_t & ne02,
3435
- constant int64_t & ne10,
3436
- constant int64_t & ne12,
3437
- constant int64_t & ne0,
3438
- constant int64_t & ne1,
3439
- constant uint & r2,
3440
- constant uint & r3,
3441
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3442
- uint3 tgpig[[threadgroup_position_in_grid]],
3443
- uint tiisg[[thread_index_in_simdgroup]],
3444
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3445
 
3446
  const int nb = ne00/QK_K;
3447
 
@@ -3636,19 +3636,19 @@ void kernel_mul_mv_q6_K_f32_impl(
3636
  device const void * src0,
3637
  device const float * src1,
3638
  device float * dst,
3639
- constant int64_t & ne00,
3640
- constant int64_t & ne01,
3641
- constant int64_t & ne02,
3642
- constant int64_t & ne10,
3643
- constant int64_t & ne12,
3644
- constant int64_t & ne0,
3645
- constant int64_t & ne1,
3646
- constant uint & r2,
3647
- constant uint & r3,
3648
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3649
- uint3 tgpig[[threadgroup_position_in_grid]],
3650
- uint tiisg[[thread_index_in_simdgroup]],
3651
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3652
 
3653
  const uint8_t kmask1 = 0x03;
3654
  const uint8_t kmask2 = 0x0C;
@@ -3773,19 +3773,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
3773
  device const void * src0,
3774
  device const float * src1,
3775
  device float * dst,
3776
- constant int64_t & ne00,
3777
- constant int64_t & ne01,
3778
- constant int64_t & ne02,
3779
- constant int64_t & ne10,
3780
- constant int64_t & ne12,
3781
- constant int64_t & ne0,
3782
- constant int64_t & ne1,
3783
- constant uint & r2,
3784
- constant uint & r3,
3785
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3786
- uint3 tgpig[[threadgroup_position_in_grid]],
3787
- uint tiisg[[thread_index_in_simdgroup]],
3788
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3789
 
3790
  const int nb = ne00/QK_K;
3791
  const int r0 = tgpig.x;
@@ -3902,19 +3902,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
3902
  device const void * src0,
3903
  device const float * src1,
3904
  device float * dst,
3905
- constant int64_t & ne00,
3906
- constant int64_t & ne01,
3907
- constant int64_t & ne02,
3908
- constant int64_t & ne10,
3909
- constant int64_t & ne12,
3910
- constant int64_t & ne0,
3911
- constant int64_t & ne1,
3912
- constant uint & r2,
3913
- constant uint & r3,
3914
- threadgroup int8_t * shared_values [[threadgroup(0)]],
3915
- uint3 tgpig[[threadgroup_position_in_grid]],
3916
- uint tiisg[[thread_index_in_simdgroup]],
3917
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
3918
 
3919
  const int nb = ne00/QK_K;
3920
  const int r0 = tgpig.x;
@@ -4041,19 +4041,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
4041
  device const void * src0,
4042
  device const float * src1,
4043
  device float * dst,
4044
- constant int64_t & ne00,
4045
- constant int64_t & ne01,
4046
- constant int64_t & ne02,
4047
- constant int64_t & ne10,
4048
- constant int64_t & ne12,
4049
- constant int64_t & ne0,
4050
- constant int64_t & ne1,
4051
- constant uint & r2,
4052
- constant uint & r3,
4053
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4054
- uint3 tgpig[[threadgroup_position_in_grid]],
4055
- uint tiisg[[thread_index_in_simdgroup]],
4056
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4057
 
4058
  const int nb = ne00/QK_K;
4059
  const int r0 = tgpig.x;
@@ -4173,19 +4173,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
4173
  device const void * src0,
4174
  device const float * src1,
4175
  device float * dst,
4176
- constant int64_t & ne00,
4177
- constant int64_t & ne01,
4178
- constant int64_t & ne02,
4179
- constant int64_t & ne10,
4180
- constant int64_t & ne12,
4181
- constant int64_t & ne0,
4182
- constant int64_t & ne1,
4183
- constant uint & r2,
4184
- constant uint & r3,
4185
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4186
- uint3 tgpig[[threadgroup_position_in_grid]],
4187
- uint tiisg[[thread_index_in_simdgroup]],
4188
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4189
 
4190
  const int nb = ne00/QK_K;
4191
  const int r0 = tgpig.x;
@@ -4305,19 +4305,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
4305
  device const void * src0,
4306
  device const float * src1,
4307
  device float * dst,
4308
- constant int64_t & ne00,
4309
- constant int64_t & ne01,
4310
- constant int64_t & ne02,
4311
- constant int64_t & ne10,
4312
- constant int64_t & ne12,
4313
- constant int64_t & ne0,
4314
- constant int64_t & ne1,
4315
- constant uint & r2,
4316
- constant uint & r3,
4317
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4318
- uint3 tgpig[[threadgroup_position_in_grid]],
4319
- uint tiisg[[thread_index_in_simdgroup]],
4320
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4321
 
4322
  const int nb = ne00/QK_K;
4323
  const int r0 = tgpig.x;
@@ -4438,19 +4438,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
4438
  device const void * src0,
4439
  device const float * src1,
4440
  device float * dst,
4441
- constant int64_t & ne00,
4442
- constant int64_t & ne01,
4443
- constant int64_t & ne02,
4444
- constant int64_t & ne10,
4445
- constant int64_t & ne12,
4446
- constant int64_t & ne0,
4447
- constant int64_t & ne1,
4448
- constant uint & r2,
4449
- constant uint & r3,
4450
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4451
- uint3 tgpig[[threadgroup_position_in_grid]],
4452
- uint tiisg[[thread_index_in_simdgroup]],
4453
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4454
 
4455
  const int nb = ne00/QK_K;
4456
  const int r0 = tgpig.x;
@@ -4528,19 +4528,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
4528
  device const void * src0,
4529
  device const float * src1,
4530
  device float * dst,
4531
- constant int64_t & ne00,
4532
- constant int64_t & ne01,
4533
- constant int64_t & ne02,
4534
- constant int64_t & ne10,
4535
- constant int64_t & ne12,
4536
- constant int64_t & ne0,
4537
- constant int64_t & ne1,
4538
- constant uint & r2,
4539
- constant uint & r3,
4540
- threadgroup int8_t * shared_values [[threadgroup(0)]],
4541
- uint3 tgpig[[threadgroup_position_in_grid]],
4542
- uint tiisg[[thread_index_in_simdgroup]],
4543
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4544
 
4545
  const int nb = ne00/QK_K;
4546
  const int r0 = tgpig.x;
@@ -4637,19 +4637,19 @@ void kernel_mul_mv_iq4_nl_f32_impl(
4637
  device const void * src0,
4638
  device const float * src1,
4639
  device float * dst,
4640
- constant int64_t & ne00,
4641
- constant int64_t & ne01,
4642
- constant int64_t & ne02,
4643
- constant int64_t & ne10,
4644
- constant int64_t & ne12,
4645
- constant int64_t & ne0,
4646
- constant int64_t & ne1,
4647
- constant uint & r2,
4648
- constant uint & r3,
4649
- threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
4650
- uint3 tgpig[[threadgroup_position_in_grid]],
4651
- uint tiisg[[thread_index_in_simdgroup]],
4652
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
4653
 
4654
  threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4655
  const int nb = ne00/QK4_NL;
@@ -4732,19 +4732,20 @@ void kernel_mul_mv_iq4_xs_f32_impl(
4732
  device const void * src0,
4733
  device const float * src1,
4734
  device float * dst,
4735
- constant int64_t & ne00,
4736
- constant int64_t & ne01,
4737
- constant int64_t & ne02,
4738
- constant int64_t & ne10,
4739
- constant int64_t & ne12,
4740
- constant int64_t & ne0,
4741
- constant int64_t & ne1,
4742
- constant uint & r2,
4743
- constant uint & r3,
4744
- threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
4745
- uint3 tgpig[[threadgroup_position_in_grid]],
4746
- uint tiisg[[thread_index_in_simdgroup]],
4747
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
4748
  threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4749
  const int nb = ne00/QK_K;
4750
  const int r0 = tgpig.x;
@@ -5686,25 +5687,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
5686
  }
5687
  }
5688
 
5689
- // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
5690
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5691
  void kernel_mul_mm_id_impl(
5692
  device const uchar * src0,
5693
  device const uchar * src1,
5694
- threadgroup short * src1ids,
5695
  device float * dst,
5696
  constant int64_t & ne00,
5697
  constant int64_t & ne02,
5698
  constant uint64_t & nb01,
5699
  constant uint64_t & nb02,
 
5700
  constant int64_t & ne12,
5701
  constant uint64_t & nb10,
5702
  constant uint64_t & nb11,
5703
  constant uint64_t & nb12,
5704
  constant int64_t & ne0,
5705
  int64_t ne1,
5706
- constant uint & r2,
5707
- constant uint & r3,
5708
  threadgroup uchar * shared_memory,
5709
  uint3 tgpig[[threadgroup_position_in_grid]],
5710
  uint tiitg[[thread_index_in_threadgroup]],
@@ -5715,7 +5716,6 @@ void kernel_mul_mm_id_impl(
5715
 
5716
  const uint r0 = tgpig.y;
5717
  const uint r1 = tgpig.x;
5718
- const uint im = tgpig.z;
5719
 
5720
  if (r1 * BLOCK_SIZE_N >= ne1) return;
5721
 
@@ -5733,19 +5733,16 @@ void kernel_mul_mm_id_impl(
5733
  for (int i = 0; i < 8; i++){
5734
  c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5735
  }
5736
-
5737
  short il = (tiitg % THREAD_PER_ROW);
5738
 
5739
- const uint i12 = im%ne12;
5740
- const uint i13 = im/ne12;
5741
-
5742
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
5743
  ushort offset1 = il/nl;
5744
 
5745
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
 
 
5746
  device const float * y = (device const float *)(src1
5747
- + nb12 * im
5748
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
5749
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5750
 
5751
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -5774,11 +5771,11 @@ void kernel_mul_mm_id_impl(
5774
 
5775
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
5776
  for (int i = 0; i < 4; i++) {
5777
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
5778
  }
5779
  simdgroup_barrier(mem_flags::mem_none);
5780
  for (int i = 0; i < 2; i++) {
5781
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
5782
  }
5783
 
5784
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
@@ -5800,11 +5797,13 @@ void kernel_mul_mm_id_impl(
5800
 
5801
  threadgroup_barrier(mem_flags::mem_threadgroup);
5802
 
5803
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
5804
  if (sgitg == 0) {
5805
- for (int i = 0; i < n_rows; i++) {
5806
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5807
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
 
 
5808
  }
5809
  }
5810
  }
@@ -5859,11 +5858,14 @@ kernel void kernel_mul_mm_id(
5859
  device const uchar * src1,
5860
  device float * dst,
5861
  device const uchar * ids,
 
 
5862
  constant uint64_t & nbi1,
5863
  constant int64_t & ne00,
5864
  constant int64_t & ne02,
5865
  constant uint64_t & nb01,
5866
  constant uint64_t & nb02,
 
5867
  constant int64_t & ne12,
5868
  constant int64_t & ne13,
5869
  constant uint64_t & nb10,
@@ -5872,47 +5874,52 @@ kernel void kernel_mul_mm_id(
5872
  constant int64_t & ne0,
5873
  constant int64_t & ne1,
5874
  constant uint64_t & nb1,
5875
- constant uint & r2,
5876
- constant uint & r3,
5877
- constant int & idx,
5878
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5879
  uint3 tgpig[[threadgroup_position_in_grid]],
5880
  uint tiitg[[thread_index_in_threadgroup]],
5881
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5882
 
5883
- // expert id
5884
- const int32_t id = tgpig.z/(ne12*ne13);
5885
- device const uchar * src0 = src0s + id*nb02;
5886
 
5887
- tgpig.z = tgpig.z%(ne12*ne13);
5888
 
5889
- // row indices of src1 for expert id
5890
- threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
5891
 
 
5892
  int64_t _ne1 = 0;
5893
- for (int64_t i1 = 0; i1 < ne1; i1++) {
5894
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
5895
- src1ids[_ne1++] = i1;
 
 
 
 
 
 
5896
  }
5897
  }
5898
 
 
 
5899
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5900
  src0,
5901
  src1,
5902
- src1ids,
5903
  dst,
5904
  ne00,
5905
  ne02,
5906
  nb01,
5907
  nb02,
 
5908
  ne12,
5909
  nb10,
5910
  nb11,
5911
  nb12,
5912
  ne0,
5913
  _ne1,
5914
- r2,
5915
- r3,
5916
  shared_memory,
5917
  tgpig,
5918
  tiitg,
@@ -5973,24 +5980,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
5973
  // matrix-matrix multiplication
5974
  //
5975
 
5976
- typedef void (mat_mm_t)(
5977
- device const uchar * src0,
5978
- device const uchar * src1,
5979
- device float * dst,
5980
- constant int64_t & ne00,
5981
- constant int64_t & ne02,
5982
- constant uint64_t & nb01,
5983
- constant uint64_t & nb02,
5984
- constant int64_t & ne12,
5985
- constant uint64_t & nb10,
5986
- constant uint64_t & nb11,
5987
- constant uint64_t & nb12,
5988
- constant int64_t & ne0,
5989
- constant int64_t & ne1,
5990
- constant uint & r2,
5991
- constant uint & r3,
5992
- threadgroup uchar *,
5993
- uint3, uint, uint);
5994
 
5995
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5996
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
@@ -6022,29 +6012,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
6022
  // indirect matrix-matrix multiplication
6023
  //
6024
 
6025
- typedef void (mat_mm_id_t)(
6026
- device const uchar * src0s,
6027
- device const uchar * src1,
6028
- device float * dst,
6029
- device const uchar * ids,
6030
- constant uint64_t & nbi1,
6031
- constant int64_t & ne00,
6032
- constant int64_t & ne02,
6033
- constant uint64_t & nb01,
6034
- constant uint64_t & nb02,
6035
- constant int64_t & ne12,
6036
- constant int64_t & ne13,
6037
- constant uint64_t & nb10,
6038
- constant uint64_t & nb11,
6039
- constant uint64_t & nb12,
6040
- constant int64_t & ne0,
6041
- constant int64_t & ne1,
6042
- constant uint64_t & nb1,
6043
- constant uint & r2,
6044
- constant uint & r3,
6045
- constant int & idx,
6046
- threadgroup uchar *,
6047
- uint3, uint, uint);
6048
 
6049
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6050
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
@@ -6080,71 +6048,71 @@ typedef void (kernel_mul_mv_impl_t)(
6080
  device const char * src0,
6081
  device const char * src1,
6082
  device float * dst,
6083
- constant int64_t & ne00,
6084
- constant int64_t & ne01,
6085
- constant int64_t & ne02,
6086
- constant uint64_t & nb00,
6087
- constant uint64_t & nb01,
6088
- constant uint64_t & nb02,
6089
- constant int64_t & ne10,
6090
- constant int64_t & ne11,
6091
- constant int64_t & ne12,
6092
- constant uint64_t & nb10,
6093
- constant uint64_t & nb11,
6094
- constant uint64_t & nb12,
6095
- constant int64_t & ne0,
6096
- constant int64_t & ne1,
6097
- constant uint & r2,
6098
- constant uint & r3,
6099
- uint3 tgpig[[threadgroup_position_in_grid]],
6100
- uint tiisg[[thread_index_in_simdgroup]]);
6101
 
6102
  typedef void (kernel_mul_mv2_impl_t)(
6103
  device const void * src0,
6104
  device const float * src1,
6105
  device float * dst,
6106
- constant int64_t & ne00,
6107
- constant int64_t & ne01,
6108
- constant int64_t & ne02,
6109
- constant int64_t & ne10,
6110
- constant int64_t & ne12,
6111
- constant int64_t & ne0,
6112
- constant int64_t & ne1,
6113
- constant uint & r2,
6114
- constant uint & r3,
6115
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6116
- uint3 tgpig[[threadgroup_position_in_grid]],
6117
- uint tiisg[[thread_index_in_simdgroup]],
6118
- uint sgitg[[simdgroup_index_in_threadgroup]]);
6119
 
6120
  template<kernel_mul_mv_impl_t impl_fn>
6121
  void mmv_fn(
6122
  device const char * src0,
6123
  device const char * src1,
6124
  device float * dst,
6125
- constant int64_t & ne00,
6126
- constant int64_t & ne01,
6127
- constant int64_t & ne02,
6128
- constant uint64_t & nb00,
6129
- constant uint64_t & nb01,
6130
- constant uint64_t & nb02,
6131
- constant int64_t & ne10,
6132
- constant int64_t & ne11,
6133
- constant int64_t & ne12,
6134
- constant int64_t & ne13,
6135
- constant uint64_t & nb10,
6136
- constant uint64_t & nb11,
6137
- constant uint64_t & nb12,
6138
- constant int64_t & ne0,
6139
- constant int64_t & ne1,
6140
- constant uint64_t & nb1,
6141
- constant uint & r2,
6142
- constant uint & r3,
6143
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6144
- uint3 tgpig[[threadgroup_position_in_grid]],
6145
- uint tiitg[[thread_index_in_threadgroup]],
6146
- uint tiisg[[thread_index_in_simdgroup]],
6147
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6148
  impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
6149
  }
6150
 
@@ -6153,59 +6121,33 @@ void mmv_fn(
6153
  device const char * src0,
6154
  device const char * src1,
6155
  device float * dst,
6156
- constant int64_t & ne00,
6157
- constant int64_t & ne01,
6158
- constant int64_t & ne02,
6159
- constant uint64_t & nb00,
6160
- constant uint64_t & nb01,
6161
- constant uint64_t & nb02,
6162
- constant int64_t & ne10,
6163
- constant int64_t & ne11,
6164
- constant int64_t & ne12,
6165
- constant int64_t & ne13,
6166
- constant uint64_t & nb10,
6167
- constant uint64_t & nb11,
6168
- constant uint64_t & nb12,
6169
- constant int64_t & ne0,
6170
- constant int64_t & ne1,
6171
- constant uint64_t & nb1,
6172
- constant uint & r2,
6173
- constant uint & r3,
6174
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6175
- uint3 tgpig[[threadgroup_position_in_grid]],
6176
- uint tiitg[[thread_index_in_threadgroup]],
6177
- uint tiisg[[thread_index_in_simdgroup]],
6178
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
6179
  impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6180
  }
6181
 
6182
- typedef void (mul_mv_impl_fn_t)(
6183
- device const char * src0,
6184
- device const char * src1,
6185
- device float * dst,
6186
- constant int64_t & ne00,
6187
- constant int64_t & ne01,
6188
- constant int64_t & ne02,
6189
- constant uint64_t & nb00,
6190
- constant uint64_t & nb01,
6191
- constant uint64_t & nb02,
6192
- constant int64_t & ne10,
6193
- constant int64_t & ne11,
6194
- constant int64_t & ne12,
6195
- constant int64_t & ne13,
6196
- constant uint64_t & nb10,
6197
- constant uint64_t & nb11,
6198
- constant uint64_t & nb12,
6199
- constant int64_t & ne0,
6200
- constant int64_t & ne1,
6201
- constant uint64_t & nb1,
6202
- constant uint & r2,
6203
- constant uint & r3,
6204
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6205
- uint3 tgpig[[threadgroup_position_in_grid]],
6206
- uint tiitg[[thread_index_in_threadgroup]],
6207
- uint tiisg[[thread_index_in_simdgroup]],
6208
- uint sgitg[[simdgroup_index_in_threadgroup]]);
6209
 
6210
  template<mul_mv_impl_fn_t impl_fn>
6211
  kernel void kernel_mul_mv_id(
@@ -6213,6 +6155,8 @@ kernel void kernel_mul_mv_id(
6213
  device const char * src1,
6214
  device float * dst,
6215
  device const char * ids,
 
 
6216
  constant uint64_t & nbi1,
6217
  constant int64_t & ne00,
6218
  constant int64_t & ne01,
@@ -6230,43 +6174,50 @@ kernel void kernel_mul_mv_id(
6230
  constant int64_t & ne0,
6231
  constant int64_t & ne1,
6232
  constant uint64_t & nb1,
6233
- constant uint & r2,
6234
- constant uint & r3,
6235
- constant int & idx,
6236
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6237
  uint3 tgpig[[threadgroup_position_in_grid]],
6238
  uint tiitg[[thread_index_in_threadgroup]],
6239
  uint tiisg[[thread_index_in_simdgroup]],
6240
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6241
- const int64_t bid = tgpig.z/(ne12*ne13);
 
 
 
6242
 
6243
- tgpig.z = tgpig.z%(ne12*ne13);
6244
 
6245
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
6246
- device const char * src0 = src0s + id*nb02;
 
 
 
 
 
 
 
6247
 
6248
  impl_fn(
6249
- src0,
6250
- src1 + bid*nb11,
6251
- dst + bid*ne0,
6252
- ne00,
6253
- ne01,
6254
- ne02,
6255
- nb00,
6256
- nb01,
6257
- nb02,
6258
- ne10,
6259
- ne11,
6260
- ne12,
6261
- ne13,
6262
- nb10,
6263
- nb11,
6264
- nb12,
6265
- ne0,
6266
- ne1,
6267
- nb1,
6268
- r2,
6269
- r3,
6270
  shared_values,
6271
  tgpig,
6272
  tiitg,
@@ -6274,36 +6225,7 @@ kernel void kernel_mul_mv_id(
6274
  sgitg);
6275
  }
6276
 
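For reference, the dispatch change in kernel_mul_mv_id can be summarized as follows (paraphrasing the removed lines above and the replacement shown further below); this comment block is an editorial aid, not part of the commit:

// Old: the grid's z dimension folded the src1 batch index into tgpig.z and the
// expert slot `idx` was a fixed op parameter:
//   bid = tgpig.z / (ne12*ne13);                  // src1 row (token)
//   id  = ((int32_t *)(ids + bid*nbi1))[idx];     // expert for that token
// New: the grid's z dimension enumerates every (expert slot, token) pair of ids,
// so all experts used by the batch are covered in a single kernel launch:
//   iid1 = tgpig.z / nei0;                        // token
//   idx  = tgpig.z % nei0;                        // expert slot
//   i02  = ((int32_t *)(ids + iid1*nbi1))[idx];   // expert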
6277
- typedef void (kernel_mul_mv_id_t)(
6278
- device const char * src0s,
6279
- device const char * src1,
6280
- device float * dst,
6281
- device const char * ids,
6282
- constant uint64_t & nbi1,
6283
- constant int64_t & ne00,
6284
- constant int64_t & ne01,
6285
- constant int64_t & ne02,
6286
- constant uint64_t & nb00,
6287
- constant uint64_t & nb01,
6288
- constant uint64_t & nb02,
6289
- constant int64_t & ne10,
6290
- constant int64_t & ne11,
6291
- constant int64_t & ne12,
6292
- constant int64_t & ne13,
6293
- constant uint64_t & nb10,
6294
- constant uint64_t & nb11,
6295
- constant uint64_t & nb12,
6296
- constant int64_t & ne0,
6297
- constant int64_t & ne1,
6298
- constant uint64_t & nb1,
6299
- constant uint & r2,
6300
- constant uint & r3,
6301
- constant int & idx,
6302
- threadgroup int8_t * shared_values [[threadgroup(0)]],
6303
- uint3 tgpig[[threadgroup_position_in_grid]],
6304
- uint tiitg[[thread_index_in_threadgroup]],
6305
- uint tiisg[[thread_index_in_simdgroup]],
6306
- uint sgitg[[simdgroup_index_in_threadgroup]]);
6307
 
6308
  template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6309
  template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
 
899
  device const void * src0,
900
  device const float * src1,
901
  device float * dst,
902
+ int64_t ne00,
903
+ int64_t ne01,
904
+ int64_t ne02,
905
+ int64_t ne10,
906
+ int64_t ne12,
907
+ int64_t ne0,
908
+ int64_t ne1,
909
+ uint r2,
910
+ uint r3,
911
+ threadgroup int8_t * shared_values,
912
  uint3 tgpig, uint tiisg, uint sgitg) {
913
  const int nb = ne00/QK4_0;
914
 
 
1073
  device const void * src0,
1074
  device const float * src1,
1075
  device float * dst,
1076
+ int64_t ne00,
1077
+ int64_t ne01,
1078
+ int64_t ne02,
1079
+ int64_t ne10,
1080
+ int64_t ne12,
1081
+ int64_t ne0,
1082
+ int64_t ne1,
1083
+ uint r2,
1084
+ uint r3,
1085
+ threadgroup int8_t * shared_values,
1086
+ uint3 tgpig,
1087
+ uint tiisg,
1088
+ uint sgitg) {
1089
  const int nr = N_DST;
1090
  const int nsg = N_SIMDGROUP;
1091
  const int nw = N_SIMDWIDTH;
 
1172
  device const char * src0,
1173
  device const char * src1,
1174
  device float * dst,
1175
+ int64_t ne00,
1176
+ int64_t ne01,
1177
+ int64_t ne02,
1178
+ uint64_t nb00,
1179
+ uint64_t nb01,
1180
+ uint64_t nb02,
1181
+ int64_t ne10,
1182
+ int64_t ne11,
1183
+ int64_t ne12,
1184
+ uint64_t nb10,
1185
+ uint64_t nb11,
1186
+ uint64_t nb12,
1187
+ int64_t ne0,
1188
+ int64_t ne1,
1189
+ uint r2,
1190
+ uint r3,
1191
+ uint3 tgpig,
1192
+ uint tiisg) {
1193
 
1194
  const int64_t r0 = tgpig.x;
1195
  const int64_t rb = tgpig.y*N_F32_F32;
 
1442
  device const char * src0,
1443
  device const char * src1,
1444
  device float * dst,
1445
+ int64_t ne00,
1446
+ int64_t ne01,
1447
+ int64_t ne02,
1448
+ uint64_t nb00,
1449
+ uint64_t nb01,
1450
+ uint64_t nb02,
1451
+ int64_t ne10,
1452
+ int64_t ne11,
1453
+ int64_t ne12,
1454
+ uint64_t nb10,
1455
+ uint64_t nb11,
1456
+ uint64_t nb12,
1457
+ int64_t ne0,
1458
+ int64_t ne1,
1459
+ uint r2,
1460
+ uint r3,
1461
+ uint3 tgpig,
1462
+ uint tiisg) {
1463
 
1464
  const int64_t r0 = tgpig.x;
1465
  const int64_t rb = tgpig.y*N_F16_F32;
 
2744
  device const void * src0,
2745
  device const float * src1,
2746
  device float * dst,
2747
+ int64_t ne00,
2748
+ int64_t ne01,
2749
+ int64_t ne02,
2750
+ int64_t ne10,
2751
+ int64_t ne12,
2752
+ int64_t ne0,
2753
+ int64_t ne1,
2754
+ uint r2,
2755
+ uint r3,
2756
+ threadgroup int8_t * shared_values,
2757
+ uint3 tgpig,
2758
+ uint tiisg,
2759
+ uint sgitg) {
2760
 
2761
  const int nb = ne00/QK_K;
2762
  const int r0 = tgpig.x;
 
2924
  device const void * src0,
2925
  device const float * src1,
2926
  device float * dst,
2927
+ int64_t ne00,
2928
+ int64_t ne01,
2929
+ int64_t ne02,
2930
+ int64_t ne10,
2931
+ int64_t ne12,
2932
+ int64_t ne0,
2933
+ int64_t ne1,
2934
+ uint r2,
2935
+ uint r3,
2936
+ threadgroup int8_t * shared_values,
2937
+ uint3 tgpig,
2938
+ uint tiisg,
2939
+ uint sgitg) {
2940
 
2941
  const int nb = ne00/QK_K;
2942
 
 
3190
  device const void * src0,
3191
  device const float * src1,
3192
  device float * dst,
3193
+ int64_t ne00,
3194
+ int64_t ne01,
3195
+ int64_t ne02,
3196
+ int64_t ne10,
3197
+ int64_t ne12,
3198
+ int64_t ne0,
3199
+ int64_t ne1,
3200
+ uint r2,
3201
+ uint r3,
3202
+ threadgroup int8_t * shared_values,
3203
+ uint3 tgpig,
3204
+ uint tiisg,
3205
+ uint sgitg) {
3206
 
3207
  const uint16_t kmask1 = 0x3f3f;
3208
  const uint16_t kmask2 = 0x0f0f;
 
3429
  device const void * src0,
3430
  device const float * src1,
3431
  device float * dst,
3432
+ int64_t ne00,
3433
+ int64_t ne01,
3434
+ int64_t ne02,
3435
+ int64_t ne10,
3436
+ int64_t ne12,
3437
+ int64_t ne0,
3438
+ int64_t ne1,
3439
+ uint r2,
3440
+ uint r3,
3441
+ threadgroup int8_t * shared_values,
3442
+ uint3 tgpig,
3443
+ uint tiisg,
3444
+ uint sgitg) {
3445
 
3446
  const int nb = ne00/QK_K;
3447
 
 
3636
  device const void * src0,
3637
  device const float * src1,
3638
  device float * dst,
3639
+ int64_t ne00,
3640
+ int64_t ne01,
3641
+ int64_t ne02,
3642
+ int64_t ne10,
3643
+ int64_t ne12,
3644
+ int64_t ne0,
3645
+ int64_t ne1,
3646
+ uint r2,
3647
+ uint r3,
3648
+ threadgroup int8_t * shared_values,
3649
+ uint3 tgpig,
3650
+ uint tiisg,
3651
+ uint sgitg) {
3652
 
3653
  const uint8_t kmask1 = 0x03;
3654
  const uint8_t kmask2 = 0x0C;
 
3773
  device const void * src0,
3774
  device const float * src1,
3775
  device float * dst,
3776
+ int64_t ne00,
3777
+ int64_t ne01,
3778
+ int64_t ne02,
3779
+ int64_t ne10,
3780
+ int64_t ne12,
3781
+ int64_t ne0,
3782
+ int64_t ne1,
3783
+ uint r2,
3784
+ uint r3,
3785
+ threadgroup int8_t * shared_values,
3786
+ uint3 tgpig,
3787
+ uint tiisg,
3788
+ uint sgitg) {
3789
 
3790
  const int nb = ne00/QK_K;
3791
  const int r0 = tgpig.x;
 
3902
  device const void * src0,
3903
  device const float * src1,
3904
  device float * dst,
3905
+ int64_t ne00,
3906
+ int64_t ne01,
3907
+ int64_t ne02,
3908
+ int64_t ne10,
3909
+ int64_t ne12,
3910
+ int64_t ne0,
3911
+ int64_t ne1,
3912
+ uint r2,
3913
+ uint r3,
3914
+ threadgroup int8_t * shared_values,
3915
+ uint3 tgpig,
3916
+ uint tiisg,
3917
+ uint sgitg) {
3918
 
3919
  const int nb = ne00/QK_K;
3920
  const int r0 = tgpig.x;
 
4041
  device const void * src0,
4042
  device const float * src1,
4043
  device float * dst,
4044
+ int64_t ne00,
4045
+ int64_t ne01,
4046
+ int64_t ne02,
4047
+ int64_t ne10,
4048
+ int64_t ne12,
4049
+ int64_t ne0,
4050
+ int64_t ne1,
4051
+ uint r2,
4052
+ uint r3,
4053
+ threadgroup int8_t * shared_values,
4054
+ uint3 tgpig,
4055
+ uint tiisg,
4056
+ uint sgitg) {
4057
 
4058
  const int nb = ne00/QK_K;
4059
  const int r0 = tgpig.x;
 
4173
  device const void * src0,
4174
  device const float * src1,
4175
  device float * dst,
4176
+ int64_t ne00,
4177
+ int64_t ne01,
4178
+ int64_t ne02,
4179
+ int64_t ne10,
4180
+ int64_t ne12,
4181
+ int64_t ne0,
4182
+ int64_t ne1,
4183
+ uint r2,
4184
+ uint r3,
4185
+ threadgroup int8_t * shared_values,
4186
+ uint3 tgpig,
4187
+ uint tiisg,
4188
+ uint sgitg) {
4189
 
4190
  const int nb = ne00/QK_K;
4191
  const int r0 = tgpig.x;
 
4305
  device const void * src0,
4306
  device const float * src1,
4307
  device float * dst,
4308
+ int64_t ne00,
4309
+ int64_t ne01,
4310
+ int64_t ne02,
4311
+ int64_t ne10,
4312
+ int64_t ne12,
4313
+ int64_t ne0,
4314
+ int64_t ne1,
4315
+ uint r2,
4316
+ uint r3,
4317
+ threadgroup int8_t * shared_values,
4318
+ uint3 tgpig,
4319
+ uint tiisg,
4320
+ uint sgitg) {
4321
 
4322
  const int nb = ne00/QK_K;
4323
  const int r0 = tgpig.x;
 
4438
  device const void * src0,
4439
  device const float * src1,
4440
  device float * dst,
4441
+ int64_t ne00,
4442
+ int64_t ne01,
4443
+ int64_t ne02,
4444
+ int64_t ne10,
4445
+ int64_t ne12,
4446
+ int64_t ne0,
4447
+ int64_t ne1,
4448
+ uint r2,
4449
+ uint r3,
4450
+ threadgroup int8_t * shared_value,
4451
+ uint3 tgpig,
4452
+ uint tiisg,
4453
+ uint sgitg) {
4454
 
4455
  const int nb = ne00/QK_K;
4456
  const int r0 = tgpig.x;
 
4528
  device const void * src0,
4529
  device const float * src1,
4530
  device float * dst,
4531
+ int64_t ne00,
4532
+ int64_t ne01,
4533
+ int64_t ne02,
4534
+ int64_t ne10,
4535
+ int64_t ne12,
4536
+ int64_t ne0,
4537
+ int64_t ne1,
4538
+ uint r2,
4539
+ uint r3,
4540
+ threadgroup int8_t * shared_value,
4541
+ uint3 tgpig,
4542
+ uint tiisg,
4543
+ uint sgitg) {
4544
 
4545
  const int nb = ne00/QK_K;
4546
  const int r0 = tgpig.x;
 
4637
  device const void * src0,
4638
  device const float * src1,
4639
  device float * dst,
4640
+ int64_t ne00,
4641
+ int64_t ne01,
4642
+ int64_t ne02,
4643
+ int64_t ne10,
4644
+ int64_t ne12,
4645
+ int64_t ne0,
4646
+ int64_t ne1,
4647
+ uint r2,
4648
+ uint r3,
4649
+ threadgroup int8_t * shared_values_i8,
4650
+ uint3 tgpig,
4651
+ uint tiisg,
4652
+ uint sgitg) {
4653
 
4654
  threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4655
  const int nb = ne00/QK4_NL;
 
4732
  device const void * src0,
4733
  device const float * src1,
4734
  device float * dst,
4735
+ int64_t ne00,
4736
+ int64_t ne01,
4737
+ int64_t ne02,
4738
+ int64_t ne10,
4739
+ int64_t ne12,
4740
+ int64_t ne0,
4741
+ int64_t ne1,
4742
+ uint r2,
4743
+ uint r3,
4744
+ threadgroup int8_t * shared_values_i8,
4745
+ uint3 tgpig,
4746
+ uint tiisg,
4747
+ uint sgitg) {
4748
+
4749
  threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
4750
  const int nb = ne00/QK_K;
4751
  const int r0 = tgpig.x;
 
5687
  }
5688
  }
5689
 
5690
+ // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
5691
  template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
5692
  void kernel_mul_mm_id_impl(
5693
  device const uchar * src0,
5694
  device const uchar * src1,
5695
+ threadgroup ushort2 * rowids,
5696
  device float * dst,
5697
  constant int64_t & ne00,
5698
  constant int64_t & ne02,
5699
  constant uint64_t & nb01,
5700
  constant uint64_t & nb02,
5701
+ constant int64_t & ne11,
5702
  constant int64_t & ne12,
5703
  constant uint64_t & nb10,
5704
  constant uint64_t & nb11,
5705
  constant uint64_t & nb12,
5706
  constant int64_t & ne0,
5707
  int64_t ne1,
5708
+ int64_t ne0ne1,
 
5709
  threadgroup uchar * shared_memory,
5710
  uint3 tgpig[[threadgroup_position_in_grid]],
5711
  uint tiitg[[thread_index_in_threadgroup]],
 
5716
 
5717
  const uint r0 = tgpig.y;
5718
  const uint r1 = tgpig.x;
 
5719
 
5720
  if (r1 * BLOCK_SIZE_N >= ne1) return;
5721
 
 
5733
  for (int i = 0; i < 8; i++){
5734
  c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
5735
  }
 
5736
  short il = (tiitg % THREAD_PER_ROW);
5737
 
 
 
 
 
5738
  ushort offset1 = il/nl;
5739
 
5740
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
5741
+
5742
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
5743
  device const float * y = (device const float *)(src1
5744
+ + nb12 * id[1]
5745
+ + nb11 * (id[0] % ne11)
5746
  + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
5747
 
5748
  for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
 
5771
 
5772
  for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
5773
  for (int i = 0; i < 4; i++) {
5774
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
5775
  }
5776
  simdgroup_barrier(mem_flags::mem_none);
5777
  for (int i = 0; i < 2; i++) {
5778
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
5779
  }
5780
 
5781
  lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
 
5797
 
5798
  threadgroup_barrier(mem_flags::mem_threadgroup);
5799
 
5800
+ device float * C = dst + (BLOCK_SIZE_M * r0);
5801
  if (sgitg == 0) {
5802
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
5803
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
5804
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
5805
+ for (int i = 0; i < n_rows; i++) {
5806
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
5807
  }
5808
  }
5809
  }
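The scatter at the end of kernel_mul_mm_id_impl is easier to follow as plain index math: each rowids entry carries the (expert slot, token) pair that a computed column belongs to, and the destination offset is rebuilt from it. Below is a minimal CPU-side sketch of the same addressing in C; the function name and parameters are illustrative only:

#include <stdint.h>

// C corresponds to dst + BLOCK_SIZE_M*r0 in the kernel; i1/i2 are jid[0]/jid[1].
static inline void scatter_column(float * C, const float * col, int n_rows,
                                  int64_t ne0, int64_t ne1,
                                  uint16_t i1 /* expert slot */, uint16_t i2 /* token */) {
    const int64_t joff = (int64_t) i1*ne0 + (int64_t) i2*ne0*ne1; // jid[0]*ne0 + jid[1]*ne0ne1
    for (int i = 0; i < n_rows; ++i) {
        C[joff + i] = col[i]; // element i of the output tile lands in dst[.., i1, i2]
    }
}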
 
5858
  device const uchar * src1,
5859
  device float * dst,
5860
  device const uchar * ids,
5861
+ constant int64_t & nei0,
5862
+ constant int64_t & nei1,
5863
  constant uint64_t & nbi1,
5864
  constant int64_t & ne00,
5865
  constant int64_t & ne02,
5866
  constant uint64_t & nb01,
5867
  constant uint64_t & nb02,
5868
+ constant int64_t & ne11,
5869
  constant int64_t & ne12,
5870
  constant int64_t & ne13,
5871
  constant uint64_t & nb10,
 
5874
  constant int64_t & ne0,
5875
  constant int64_t & ne1,
5876
  constant uint64_t & nb1,
 
 
 
5877
  threadgroup uchar * shared_memory [[threadgroup(0)]],
5878
  uint3 tgpig[[threadgroup_position_in_grid]],
5879
  uint tiitg[[thread_index_in_threadgroup]],
5880
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
5881
 
5882
+ const int32_t i02 = tgpig.z;
5883
+ tgpig.z = 0;
 
5884
 
5885
+ device const uchar * src0 = src0s + i02*nb02;
5886
 
5887
+ // row indices
5888
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
5889
 
5890
+ // TODO: parallelize this loop
5891
  int64_t _ne1 = 0;
5892
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
5893
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
5894
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
5895
+ if (id == i02) {
5896
+ //if (tiitg == 0) {
5897
+ rowids[_ne1] = ushort2(ii0, ii1);
5898
+ //}
5899
+ _ne1++;
5900
+ }
5901
  }
5902
  }
5903
 
5904
+ threadgroup_barrier(mem_flags::mem_threadgroup);
5905
+
5906
  kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
5907
  src0,
5908
  src1,
5909
+ rowids,
5910
  dst,
5911
  ne00,
5912
  ne02,
5913
  nb01,
5914
  nb02,
5915
+ ne11,
5916
  ne12,
5917
  nb10,
5918
  nb11,
5919
  nb12,
5920
  ne0,
5921
  _ne1,
5922
+ ne0*ne1,
 
5923
  shared_memory,
5924
  tgpig,
5925
  tiitg,
 
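With this change the Metal matrix-matrix kernel is launched once per expert (tgpig.z is the expert index i02), instead of decoding the expert from the grid and reading a fixed ids column. Before calling the impl it gathers, into threadgroup memory at shared_memory + 8192, the (expert slot, token) coordinates of every ids entry that selects this expert. The C function below is an illustrative host-side analogue of that gather loop, not kernel code:

#include <stdint.h>

// Collect the (expert slot, token) coordinates of every ids entry equal to i02.
// ids is int32 with nei0 entries per token and a byte stride of nbi1 between tokens.
static int64_t gather_rows_for_expert(const char * ids, int64_t nei0, int64_t nei1,
                                      uint64_t nbi1, int32_t i02,
                                      uint16_t rowids[][2]) {
    int64_t ne1 = 0;
    for (int64_t ii1 = 0; ii1 < nei1; ++ii1) {
        const int32_t * row = (const int32_t *)(ids + ii1*nbi1);
        for (int64_t ii0 = 0; ii0 < nei0; ++ii0) {
            if (row[ii0] == i02) {
                rowids[ne1][0] = (uint16_t) ii0; // expert slot within the token
                rowids[ne1][1] = (uint16_t) ii1; // token index
                ne1++;
            }
        }
    }
    return ne1; // _ne1: number of src1 rows this expert has to process
}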
5980
  // matrix-matrix multiplication
5981
  //
5982
 
5983
+ typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
5984
 
5985
  template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
5986
  template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 
6012
  // indirect matrix-matrix multiplication
6013
  //
6014
 
6015
+ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
6016
 
6017
  template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
6018
  template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
 
6048
  device const char * src0,
6049
  device const char * src1,
6050
  device float * dst,
6051
+ int64_t ne00,
6052
+ int64_t ne01,
6053
+ int64_t ne02,
6054
+ uint64_t nb00,
6055
+ uint64_t nb01,
6056
+ uint64_t nb02,
6057
+ int64_t ne10,
6058
+ int64_t ne11,
6059
+ int64_t ne12,
6060
+ uint64_t nb10,
6061
+ uint64_t nb11,
6062
+ uint64_t nb12,
6063
+ int64_t ne0,
6064
+ int64_t ne1,
6065
+ uint r2,
6066
+ uint r3,
6067
+ uint3 tgpig,
6068
+ uint tiisg);
6069
 
6070
  typedef void (kernel_mul_mv2_impl_t)(
6071
  device const void * src0,
6072
  device const float * src1,
6073
  device float * dst,
6074
+ int64_t ne00,
6075
+ int64_t ne01,
6076
+ int64_t ne02,
6077
+ int64_t ne10,
6078
+ int64_t ne12,
6079
+ int64_t ne0,
6080
+ int64_t ne1,
6081
+ uint r2,
6082
+ uint r3,
6083
+ threadgroup int8_t * shared_values,
6084
+ uint3 tgpig,
6085
+ uint tiisg,
6086
+ uint sgitg);
6087
 
6088
  template<kernel_mul_mv_impl_t impl_fn>
6089
  void mmv_fn(
6090
  device const char * src0,
6091
  device const char * src1,
6092
  device float * dst,
6093
+ int64_t ne00,
6094
+ int64_t ne01,
6095
+ int64_t ne02,
6096
+ uint64_t nb00,
6097
+ uint64_t nb01,
6098
+ uint64_t nb02,
6099
+ int64_t ne10,
6100
+ int64_t ne11,
6101
+ int64_t ne12,
6102
+ int64_t ne13,
6103
+ uint64_t nb10,
6104
+ uint64_t nb11,
6105
+ uint64_t nb12,
6106
+ int64_t ne0,
6107
+ int64_t ne1,
6108
+ uint64_t nb1,
6109
+ uint r2,
6110
+ uint r3,
6111
+ threadgroup int8_t * shared_values,
6112
+ uint3 tgpig,
6113
+ uint tiitg,
6114
+ uint tiisg,
6115
+ uint sgitg) {
6116
  impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
6117
  }
6118
 
 
6121
  device const char * src0,
6122
  device const char * src1,
6123
  device float * dst,
6124
+ int64_t ne00,
6125
+ int64_t ne01,
6126
+ int64_t ne02,
6127
+ uint64_t nb00,
6128
+ uint64_t nb01,
6129
+ uint64_t nb02,
6130
+ int64_t ne10,
6131
+ int64_t ne11,
6132
+ int64_t ne12,
6133
+ int64_t ne13,
6134
+ uint64_t nb10,
6135
+ uint64_t nb11,
6136
+ uint64_t nb12,
6137
+ int64_t ne0,
6138
+ int64_t ne1,
6139
+ uint64_t nb1,
6140
+ uint r2,
6141
+ uint r3,
6142
+ threadgroup int8_t * shared_values,
6143
+ uint3 tgpig,
6144
+ uint tiitg,
6145
+ uint tiisg,
6146
+ uint sgitg) {
6147
  impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
6148
  }
6149
 
6150
+ typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
6151
 
6152
  template<mul_mv_impl_fn_t impl_fn>
6153
  kernel void kernel_mul_mv_id(
 
6155
  device const char * src1,
6156
  device float * dst,
6157
  device const char * ids,
6158
+ constant int64_t & nei0,
6159
+ constant int64_t & nei1,
6160
  constant uint64_t & nbi1,
6161
  constant int64_t & ne00,
6162
  constant int64_t & ne01,
 
6174
  constant int64_t & ne0,
6175
  constant int64_t & ne1,
6176
  constant uint64_t & nb1,
 
 
 
6177
  threadgroup int8_t * shared_values [[threadgroup(0)]],
6178
  uint3 tgpig[[threadgroup_position_in_grid]],
6179
  uint tiitg[[thread_index_in_threadgroup]],
6180
  uint tiisg[[thread_index_in_simdgroup]],
6181
  uint sgitg[[simdgroup_index_in_threadgroup]]) {
6182
+ const int iid1 = tgpig.z/nei0;
6183
+ const int idx = tgpig.z%nei0;
6184
+
6185
+ tgpig.z = 0;
6186
 
6187
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
6188
 
6189
+ const int64_t i11 = idx % ne11;
6190
+ const int64_t i12 = iid1;
6191
+
6192
+ const int64_t i1 = idx;
6193
+ const int64_t i2 = i12;
6194
+
6195
+ device const char * src0_cur = src0s + i02*nb02;
6196
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
6197
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
6198
 
6199
  impl_fn(
6200
+ /* src0 */ src0_cur,
6201
+ /* src1 */ src1_cur,
6202
+ /* dst */ dst_cur,
6203
+ /* ne00 */ ne00,
6204
+ /* ne01 */ ne01,
6205
+ /* ne02 */ 1,//ne02,
6206
+ /* nb00 */ nb00,
6207
+ /* nb01 */ nb01,
6208
+ /* nb02 */ nb02,
6209
+ /* ne10 */ ne10,
6210
+ /* ne11 */ 1,//ne11,
6211
+ /* ne12 */ 1,//ne12,
6212
+ /* ne13 */ 1,//ne13,
6213
+ /* nb10 */ nb10,
6214
+ /* nb11 */ nb11,
6215
+ /* nb12 */ nb12,
6216
+ /* ne0 */ ne0,
6217
+ /* ne1 */ 1,//ne1,
6218
+ /* nb1 */ nb1,
6219
+ /* r2 */ 1,
6220
+ /* r3 */ 1,
6221
  shared_values,
6222
  tgpig,
6223
  tiitg,
 
6225
  sgitg);
6226
  }
6227
 
6228
+ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
6229
 
6230
  template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
6231
  template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
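The new kernel_mul_mv_id launches one workgroup column per (expert slot, token) pair and forwards a one-row view of src1 and dst to the underlying mat-vec implementation, with the broadcast dimensions pinned to 1. A sketch of the index decomposition in C, assuming a contiguous int32 ids tensor (the kernel itself uses the byte stride nbi1); the struct and function are illustrative only:

#include <stdint.h>

typedef struct {
    int32_t i02;       // selected expert (index into src0s)
    int64_t i11, i12;  // src1 row and slice
    int64_t i1,  i2;   // dst row and slice
} mmv_id_coords;

static mmv_id_coords decode_mmv_id(int64_t z, int64_t nei0, int64_t ne11, const int32_t * ids) {
    mmv_id_coords c;
    const int64_t iid1 = z / nei0;     // token
    const int64_t idx  = z % nei0;     // expert slot within the token
    c.i02 = ids[iid1*nei0 + idx];      // expert id read from ids
    c.i11 = idx % ne11;                // src1 row, broadcast when ne11 == 1
    c.i12 = iid1;
    c.i1  = idx;
    c.i2  = iid1;
    return c;
}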
ggml-sycl.cpp CHANGED
@@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
17752
 
17753
  GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
17754
  const int min_batch_size = 32;
17755
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
17756
  GGML_UNUSED(backend);
17757
  }
17758
 
 
17752
 
17753
  GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
17754
  const int min_batch_size = 32;
17755
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
17756
  GGML_UNUSED(backend);
17757
  }
17758
 
ggml.c CHANGED
@@ -4594,21 +4594,32 @@ void ggml_mul_mat_set_prec(
4594
 
4595
  // ggml_mul_mat_id
4596
 
4597
- // NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
4598
- // this will allow computing all the used experts in a single matrix multiplication
 
 
 
 
 
 
 
 
 
 
4599
  struct ggml_tensor * ggml_mul_mat_id(
4600
  struct ggml_context * ctx,
4601
  struct ggml_tensor * as,
4602
- struct ggml_tensor * ids,
4603
- int id,
4604
- struct ggml_tensor * b) {
4605
-
4606
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
 
 
 
4607
  GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
4608
- GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
4609
- GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
4610
- GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
4611
  GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
 
4612
 
4613
  bool is_node = false;
4614
 
@@ -4616,11 +4627,9 @@ struct ggml_tensor * ggml_mul_mat_id(
4616
  is_node = true;
4617
  }
4618
 
4619
- const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
4620
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4621
 
4622
- ggml_set_op_params_i32(result, 0, id);
4623
-
4624
  result->op = GGML_OP_MUL_MAT_ID;
4625
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4626
  result->src[0] = as;
@@ -11071,11 +11080,6 @@ static void ggml_compute_forward_mul_mat_id(
11071
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11072
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
11073
 
11074
- GGML_ASSERT(ne0 == ne01);
11075
- GGML_ASSERT(ne1 == ne11);
11076
- GGML_ASSERT(ne2 == ne12);
11077
- GGML_ASSERT(ne3 == ne13);
11078
-
11079
  // we don't support permuted src0 or src1
11080
  GGML_ASSERT(nb00 == ggml_type_size(type));
11081
  GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@@ -11086,22 +11090,21 @@ static void ggml_compute_forward_mul_mat_id(
11086
  GGML_ASSERT(nb1 <= nb2);
11087
  GGML_ASSERT(nb2 <= nb3);
11088
 
11089
- // broadcast is not supported with mmid
11090
- assert(ne12 == 1);
11091
- assert(ne13 == 1);
11092
-
11093
  // row groups
11094
- const int id = ggml_get_op_params_i32(dst, 0);
11095
- const int n_as = src0->ne[2];
11096
 
11097
  char * wdata_src1_end = (src1->type == vec_dot_type) ?
11098
  (char *) params->wdata :
11099
  (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
11100
 
11101
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
11102
- int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
 
 
11103
 
11104
- #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
 
11105
 
11106
  if (params->type == GGML_TASK_TYPE_INIT) {
11107
  if (ith != 0) {
@@ -11127,13 +11130,18 @@ static void ggml_compute_forward_mul_mat_id(
11127
  // initialize matrix_row_counts
11128
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
11129
 
 
 
11130
  // group rows by src0 matrix
11131
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
11132
- const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
 
 
 
11133
 
11134
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
11135
- MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
11136
- matrix_row_counts[row_id] += 1;
11137
  }
11138
 
11139
  return;
@@ -11151,15 +11159,13 @@ static void ggml_compute_forward_mul_mat_id(
11151
  continue;
11152
  }
11153
 
11154
- size_t src0_offset = cur_a*src0->nb[2];
11155
 
11156
  const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
11157
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11158
 
11159
- const int64_t nr0 = ne01; // src0 rows
11160
- const int64_t nr1 = cne1*ne12*ne13; // src1 rows
11161
-
11162
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
11163
 
11164
  // distribute the thread work across the inner or outer loop based on which one is larger
11165
 
@@ -11178,13 +11184,11 @@ static void ggml_compute_forward_mul_mat_id(
11178
  const int64_t ir110 = dr1*ith1;
11179
  const int64_t ir111 = MIN(ir110 + dr1, nr1);
11180
 
11181
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
11182
-
11183
  // threads with no work simply yield (not sure if it helps)
11184
- if (ir010 >= ir011 || ir110 >= ir111) {
11185
- sched_yield();
11186
- continue;
11187
- }
11188
 
11189
  // block-tiling attempt
11190
  const int64_t blck_0 = 16;
@@ -11196,20 +11200,16 @@ static void ggml_compute_forward_mul_mat_id(
11196
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11197
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11198
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
11199
- const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
11200
- const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
11201
- const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
11202
- const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
11203
 
11204
- // broadcast src0 into src1
11205
- //const int64_t i03 = i13/r3;
11206
- //const int64_t i02 = i12/r2;
11207
 
11208
- const int64_t i1 = i11;
11209
- const int64_t i2 = i12;
11210
- const int64_t i3 = i13;
11211
 
11212
- const char * src0_row = (const char *) src0->data + src0_offset;
 
11213
 
11214
  // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
11215
  // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -11217,25 +11217,26 @@ static void ggml_compute_forward_mul_mat_id(
11217
  // TODO: this is a bit of a hack, we should probably have a better way to handle this
11218
  const char * src1_col = (const char *) wdata +
11219
  (src1_cont || src1->type != vec_dot_type
11220
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
11221
- : (i11*nb11 + i12*nb12 + i13*nb13));
11222
 
11223
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
11224
 
11225
  //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11226
  // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
11227
  //}
11228
 
11229
  for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11230
- vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
11231
  }
 
11232
  memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
11233
  }
11234
  }
11235
  }
11236
  }
11237
 
11238
- #undef MMID_MATRIX_ROW
11239
  }
11240
 
11241
  // ggml_compute_forward_out_prod
@@ -18583,7 +18584,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
18583
  const int n_as = src0->ne[2];
18584
  cur += GGML_PAD(cur, sizeof(int64_t)); // align
18585
  cur += n_as * sizeof(int64_t); // matrix_row_counts
18586
- cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
18587
  } break;
18588
  case GGML_OP_OUT_PROD:
18589
  {
@@ -21009,12 +21010,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
21009
 
21010
  ok = ok && cur != NULL;
21011
 
21012
- ggml_set_name(cur, ctx->infos[i].name.data);
21013
-
21014
  if (!ok) {
21015
  break;
21016
  }
21017
 
 
 
21018
  // point the data member to the appropriate location in the binary blob using the tensor infos
21019
  if (!params.no_alloc) {
21020
  //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
 
4594
 
4595
  // ggml_mul_mat_id
4596
 
4597
+ /*
4598
+ c = ggml_mul_mat_id(ctx, as, b, ids);
4599
+
4600
+ as -> [cols, rows, n_expert]
4601
+ ids -> [n_expert_used, n_tokens] (i32)
4602
+ b -> [cols, n_expert_used, n_tokens]
4603
+ c -> [rows, n_expert_used, n_tokens]
4604
+
4605
+ in b, n_expert_used can be broadcast to match the n_expert_used of ids
4606
+
4607
+ c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
4608
+ */
4609
  struct ggml_tensor * ggml_mul_mat_id(
4610
  struct ggml_context * ctx,
4611
  struct ggml_tensor * as,
4612
+ struct ggml_tensor * b,
4613
+ struct ggml_tensor * ids) {
4614
+ GGML_ASSERT(!ggml_is_transposed(as));
 
4615
  GGML_ASSERT(ids->type == GGML_TYPE_I32);
4616
+
4617
+ GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
4618
+ GGML_ASSERT(b->ne[3] == 1); // b is 3d
4619
  GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
4620
+ GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
 
 
4621
  GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
4622
+ GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
4623
 
4624
  bool is_node = false;
4625
 
 
4627
  is_node = true;
4628
  }
4629
 
4630
+ const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
4631
  struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
4632
 
 
 
4633
  result->op = GGML_OP_MUL_MAT_ID;
4634
  result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
4635
  result->src[0] = as;
 
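With the reordered arguments and the removal of the per-call expert index, a mixture-of-experts feed-forward can evaluate all selected experts with one op per weight tensor. Below is a sketch of how a caller might drive the new interface; tensor and variable names are illustrative, and the routing details (softmax, top-k) are only one possible choice:

#include "ggml.h"

// ffn_up_exps -> [n_embd, n_ff, n_expert], one matrix per expert
// gate_inp    -> [n_embd, n_expert], router weights
// cur         -> [n_embd, n_tokens], current activations
static struct ggml_tensor * moe_up(struct ggml_context * ctx,
                                   struct ggml_tensor * ffn_up_exps,
                                   struct ggml_tensor * gate_inp,
                                   struct ggml_tensor * cur,
                                   int n_expert_used) {
    struct ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur);       // [n_expert, n_tokens]
    struct ggml_tensor * probs  = ggml_soft_max(ctx, logits);
    struct ggml_tensor * ids    = ggml_top_k(ctx, probs, n_expert_used);  // i32 [n_expert_used, n_tokens]

    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);           // b: [n_embd, 1, n_tokens], broadcast over slots

    // all experts listed in ids are computed by a single mul_mat_id
    return ggml_mul_mat_id(ctx, ffn_up_exps, cur, ids);                   // [n_ff, n_expert_used, n_tokens]
}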
11080
  enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
11081
  ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
11082
 
 
 
 
 
 
11083
  // we don't support permuted src0 or src1
11084
  GGML_ASSERT(nb00 == ggml_type_size(type));
11085
  GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
11090
  GGML_ASSERT(nb1 <= nb2);
11091
  GGML_ASSERT(nb2 <= nb3);
11092
 
 
 
 
 
11093
  // row groups
11094
+ const int n_ids = ids->ne[0]; // n_expert_used
11095
+ const int n_as = ne02; // n_expert
11096
 
11097
  char * wdata_src1_end = (src1->type == vec_dot_type) ?
11098
  (char *) params->wdata :
11099
  (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
11100
 
11101
+ struct mmid_row_mapping {
11102
+ int32_t i1;
11103
+ int32_t i2;
11104
+ };
11105
 
11106
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
11107
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
11108
 
11109
  if (params->type == GGML_TASK_TYPE_INIT) {
11110
  if (ith != 0) {
 
11130
  // initialize matrix_row_counts
11131
  memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
11132
 
11133
+ #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
11134
+
11135
  // group rows by src0 matrix
11136
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
11137
+ for (int id = 0; id < n_ids; ++id) {
11138
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
11139
+
11140
+ assert(i02 >= 0 && i02 < n_as);
11141
 
11142
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
11143
+ matrix_row_counts[i02] += 1;
11144
+ }
11145
  }
11146
 
11147
  return;
 
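A small worked example (not part of the commit) of what the grouping above produces, with n_as = 3 experts, n_ids = 2 expert slots per token and 3 tokens:

// ids = [[0, 2],
//        [1, 2],
//        [0, 1]]            one row of expert indices per token
//
// matrix_row_counts = { 2, 2, 2 }
// matrix_rows[0] = { {0,0}, {0,2} }   // expert 0: (slot 0, token 0), (slot 0, token 2)
// matrix_rows[1] = { {0,1}, {1,2} }   // expert 1: (slot 0, token 1), (slot 1, token 2)
// matrix_rows[2] = { {1,0}, {1,1} }   // expert 2: (slot 1, token 0), (slot 1, token 1)
//
// each {i1, i2} pair is an mmid_row_mapping: i1 = expert slot, i2 = token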
11159
  continue;
11160
  }
11161
 
11162
+ const char * src0_cur = (const char *) src0->data + cur_a*nb02;
11163
 
11164
  const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
11165
  const size_t row_size = ggml_row_size(vec_dot_type, ne10);
11166
 
11167
+ const int64_t nr0 = ne01; // src0 rows
11168
+ const int64_t nr1 = cne1; // src1 rows
 
 
11169
 
11170
  // distribute the thread work across the inner or outer loop based on which one is larger
11171
 
 
11184
  const int64_t ir110 = dr1*ith1;
11185
  const int64_t ir111 = MIN(ir110 + dr1, nr1);
11186
 
 
 
11187
  // threads with no work simply yield (not sure if it helps)
11188
+ //if (ir010 >= ir011 || ir110 >= ir111) {
11189
+ // sched_yield();
11190
+ // continue;
11191
+ //}
11192
 
11193
  // block-tiling attempt
11194
  const int64_t blck_0 = 16;
 
11200
  for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
11201
  for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
11202
  for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
11203
+ const int64_t _i12 = ir1; // logical row index for this expert
 
 
 
11204
 
11205
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
11206
+ const int id = row_mapping.i1; // selected expert index
 
11207
 
11208
+ const int64_t i11 = id % ne11;
11209
+ const int64_t i12 = row_mapping.i2; // row index in src1
 
11210
 
11211
+ const int64_t i1 = id; // selected expert index
11212
+ const int64_t i2 = i12; // row
11213
 
11214
  // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
11215
  // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
 
11217
  // TODO: this is a bit of a hack, we should probably have a better way to handle this
11218
  const char * src1_col = (const char *) wdata +
11219
  (src1_cont || src1->type != vec_dot_type
11220
+ ? (i11 + i12*ne11)*row_size
11221
+ : (i11*nb11 + i12*nb12));
11222
 
11223
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
11224
 
11225
  //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11226
  // vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
11227
  //}
11228
 
11229
  for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
11230
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
11231
  }
11232
+
11233
  memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
11234
  }
11235
  }
11236
  }
11237
  }
11238
 
11239
+ #undef MMID_MATRIX_ROW
11240
  }
11241
 
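The compute loop then turns each stored mapping back into a src1 column and a dst column. The helper below is an illustrative restatement of that index math in C, not code from the commit:

#include <stdint.h>
#include <stddef.h>

struct mmid_row_mapping { int32_t i1; int32_t i2; };

// Resolve one grouped row: i1 is the expert slot (also the dst row), i2 is the
// token (the src1/dst slice). ne11 == 1 broadcasts b over the expert slots.
static void resolve_row(struct mmid_row_mapping m, int64_t ne11,
                        size_t nb11, size_t nb12, size_t nb1, size_t nb2,
                        const char * src1, char * dst,
                        const char ** src1_col, float ** dst_col) {
    const int64_t i11 = m.i1 % ne11;
    const int64_t i12 = m.i2;
    *src1_col = src1 + i11*nb11 + i12*nb12;
    *dst_col  = (float *)(dst + m.i1*nb1 + i12*nb2);
}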
11242
  // ggml_compute_forward_out_prod
 
18584
  const int n_as = src0->ne[2];
18585
  cur += GGML_PAD(cur, sizeof(int64_t)); // align
18586
  cur += n_as * sizeof(int64_t); // matrix_row_counts
18587
+ cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
18588
  } break;
18589
  case GGML_OP_OUT_PROD:
18590
  {
 
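The work-buffer accounting stays in 8-byte units because one mmid_row_mapping (two int32) occupies exactly sizeof(int64_t); only the number of reserved slots changes, from src1->ne[1] to src1->ne[2] rows per expert. A self-contained restatement of the sizing, with the padding macro redefined locally for the sketch:

#include <stdint.h>
#include <stddef.h>

#define PAD_TO(x, n) (((x) + (n) - 1) / (n) * (n))   // stand-in for GGML_PAD

static size_t mmid_work_size(size_t cur, int64_t n_as /* n_expert */, int64_t n_tokens /* src1->ne[2] */) {
    cur += PAD_TO(cur, sizeof(int64_t));                          // align
    cur += (size_t) n_as * sizeof(int64_t);                       // matrix_row_counts[n_as]
    cur += (size_t) n_as * (size_t) n_tokens * sizeof(int64_t);   // matrix_rows: one mapping per (expert, token) slot
    return cur;
}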
21010
 
21011
  ok = ok && cur != NULL;
21012
 
 
 
21013
  if (!ok) {
21014
  break;
21015
  }
21016
 
21017
+ ggml_set_name(cur, ctx->infos[i].name.data);
21018
+
21019
  // point the data member to the appropriate location in the binary blob using the tensor infos
21020
  if (!params.no_alloc) {
21021
  //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
ggml.h CHANGED
@@ -1170,13 +1170,11 @@ extern "C" {
1170
  enum ggml_prec prec);
1171
 
1172
  // indirect matrix multiplication
1173
- // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
1174
  GGML_API struct ggml_tensor * ggml_mul_mat_id(
1175
  struct ggml_context * ctx,
1176
  struct ggml_tensor * as,
1177
- struct ggml_tensor * ids,
1178
- int id,
1179
- struct ggml_tensor * b);
1180
 
1181
  // A: m columns, n rows,
1182
  // B: p columns, n rows,
 
1170
  enum ggml_prec prec);
1171
 
1172
  // indirect matrix multiplication
 
1173
  GGML_API struct ggml_tensor * ggml_mul_mat_id(
1174
  struct ggml_context * ctx,
1175
  struct ggml_tensor * as,
1176
+ struct ggml_tensor * b,
1177
+ struct ggml_tensor * ids);
 
1178
 
1179
  // A: m columns, n rows,
1180
  // B: p columns, n rows,
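In short, callers migrate from selecting a single expert per call to passing the full ids matrix; a before/after sketch of the call site:

// Before: one expert per call, chosen by the `id` column of ids
//   cur = ggml_mul_mat_id(ctx, as, ids, /*id=*/0, b);
// After: all experts referenced by ids are computed by one op
//   cur = ggml_mul_mat_id(ctx, as, b, ids);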