ggml : group all experts in a single ggml_mul_mat_id (llama/6505)
* ggml : group all experts in a single ggml_mul_mat_id
* cuda : improve mmid row copy
* cuda : fix bin bcast with non-cont src0
* test-backend-ops : only run all mul mat tests for base types
* llama : disable moe offloading with SYCL
---------
Co-authored-by: Georgi Gerganov <[email protected]>
- ggml-cuda.cu +134 -45
- ggml-cuda/binbcast.cu +68 -24
- ggml-cuda/convert.cu +2 -0
- ggml-metal.m +61 -68
- ggml-metal.metal +400 -478
- ggml-sycl.cpp +1 -1
- ggml.c +62 -61
- ggml.h +2 -4
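
For context: GGML_OP_MUL_MAT_ID is the mixture-of-experts matrix multiplication. src0 stacks one weight matrix per expert along dimension 2, and the ids tensor selects, for every output row, which expert matrix is applied to which src1 row. Before this commit the operator was evaluated once per expert; after it, a single ggml_mul_mat_id node covers all experts at once. A minimal CPU sketch of the intended semantics follows; the flat layouts and names are illustrative, not the ggml API, and the real operator additionally lets several expert slots share one src1 row (i11 = id % ne11):

#include <cstdint>
#include <vector>

// Reference semantics of mul_mat_id (illustrative sketch, not the ggml API).
// src0: n_expert matrices of size [m x k] (row-major, one per expert)
// src1: one k-sized input row per (expert slot, token)
// ids : ids[t*n_used + e] selects the expert matrix for that row
static void mul_mat_id_ref(
        const std::vector<std::vector<float>> & src0, // [n_expert][m*k]
        const std::vector<float>   & src1,            // [n_tokens*n_used*k]
        const std::vector<int32_t> & ids,             // [n_tokens*n_used]
        std::vector<float>         & dst,             // [n_tokens*n_used*m]
        int64_t n_tokens, int64_t n_used, int64_t m, int64_t k) {
    for (int64_t t = 0; t < n_tokens; ++t) {
        for (int64_t e = 0; e < n_used; ++e) {
            const int32_t ex = ids[t*n_used + e];     // expert index for this row
            const float * A = src0[ex].data();        // [m x k] expert matrix
            const float * x = &src1[(t*n_used + e)*k];
            float       * y = &dst [(t*n_used + e)*m];
            for (int64_t i = 0; i < m; ++i) {
                float acc = 0.0f;
                for (int64_t j = 0; j < k; ++j) {
                    acc += A[i*k + j] * x[j];
                }
                y[i] = acc;
            }
        }
    }
}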
ggml-cuda.cu
CHANGED
@@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
     if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
-        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         }
         const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();

-        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+        ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
         if (src1->type != GGML_TYPE_F16) {
             const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
             GGML_ASSERT(to_fp16_cuda != nullptr);
@@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
             to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
         }
         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
-        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+        ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);

         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
@@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
     }
 }

+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+                                                 int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+                                                 const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+                                                 int64_t ne11, int64_t ne10,
+                                                 size_t nb11, size_t nb12) {
+    int32_t iid1 = blockIdx.x;
+    int32_t id   = blockIdx.y;
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    __shared__ int src1_row;
+    if (threadIdx.x == 0) {
+        src1_row = atomicAdd(cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    __syncthreads();
+
+    const float * src1_row_original   = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float       * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+    for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+                                                  const mmid_row_mapping * __restrict__ row_mapping,
+                                                  int64_t ne0,
+                                                  size_t nb1, size_t nb2) {
+    int32_t i = blockIdx.x;
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float       * dst_row_original   = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+    for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
     const ggml_tensor * ids  = dst->src[2];

+    GGML_TENSOR_BINARY_OP_LOCALS
+
     GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");

     cudaStream_t stream = ctx.stream();

-    const int32_t id = ((int32_t *) dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as  = ne02;
+    const int64_t n_ids = ids->ne[0];

     std::vector<char> ids_host(ggml_nbytes(ids));
     const char * ids_dev = (const char *) ids->data;
@@ -1982,7 +2035,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row  = *dst;

     char * src0_original = (char *) src0->data;
     char * src1_original = (char *) src1->data;
@@ -1990,19 +2043,39 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *

     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
+
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + i11*nb11 + i12*nb12;
+                dst_row.data  =  dst_original + i1*nb1   + i2*nb2;
+
+                ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
@@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
         src1_row.data = src1_contiguous.get();
         dst_row.data  = dst_contiguous.get();

+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+
+                    if (row_id_i != i02) {
+                        continue;
+                    }
+
+                    num_src1_rows++;
+                }
             }

             if (num_src1_rows == 0) {
                 continue;
             }

+            ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+            CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
+
+            {
+                dim3 block_dims(std::min((unsigned int)ne10, 768u));
+                dim3 grid_dims(ids->ne[1], n_ids);
+                k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        src1_original, src1_contiguous.get(),
+                        dev_cur_src1_row.get(), dev_row_mapping.get(),
+                        ids_dev, i02, ids->nb[1], ids->nb[0],
+                        ne11, ne10,
+                        nb11, nb12);
+                CUDA_CHECK(cudaGetLastError());
+            }
+
+            src0_row.data = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1  == sizeof(float)*ne0);
+
+            src1_row.ne[1] = num_src1_rows;
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;

+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;

             ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);

-            CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
-                                       nb1, cudaMemcpyDeviceToDevice, stream));
-            num_src1_rows++;
+            {
+                dim3 block_dims(std::min((unsigned int)ne0, 768u));
+                dim3 grid_dims(num_src1_rows);
+                k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+                        dst_original, dst_contiguous.get(),
+                        dev_row_mapping.get(),
+                        ne0,
+                        nb1, nb2);
+                CUDA_CHECK(cudaGetLastError());
+            }
         }
     }
 }
@@ -2491,7 +2579,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
 GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;

-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);

     GGML_UNUSED(backend);
 }
ggml-cuda/binbcast.cu
CHANGED
@@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1, int s2, int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;

+    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i3*s3  + i2*s2  + i1*s1;

     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
@@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
         int ne0, int ne1, int ne2, int ne3,
         int ne10, int ne11, int ne12, int ne13,
         /*int s0, */ int s1, int s2, int s3,
+        /*int s00,*/ int s01, int s02, int s03,
         /*int s10,*/ int s11, int s12, int s13) {

     const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     const int i12 = i2 % ne12;
     const int i13 = i3 % ne13;

+    const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i3*s3  + i2*s2  + i1*s1;

     const src0_t * src0_row = src0 + i_src0;
     const src1_t * src1_row = src1 + i_src1;
@@ -101,10 +103,14 @@ struct bin_bcast_cuda {
     int nr[4] = { nr0, nr1, nr2, nr3 };

     // collapse dimensions until first broadcast dimension
+    int64_t cne[] = {ne0, ne1, ne2, ne3};
+    int64_t cne0[] = {ne00, ne01, ne02, ne03};
     int64_t cne1[] = {ne10, ne11, ne12, ne13};
+
+    size_t cnb[] = {nb0, nb1, nb2, nb3};
+    size_t cnb0[] = {nb00, nb01, nb02, nb03};
     size_t cnb1[] = {nb10, nb11, nb12, nb13};
+
     auto collapse = [](int64_t cne[]) {
         cne[0] *= cne[1];
         cne[1] = cne[2];
@@ -118,32 +124,47 @@ struct bin_bcast_cuda {
         cnb[3] *= cne[3];
     };

+    if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb, cne);
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne);
+                collapse(cne0);
+                collapse(cne1);
+            }
         }
     }
+
     {
+        int64_t ne0 = cne[0];
+        int64_t ne1 = cne[1];
+        int64_t ne2 = cne[2];
+        int64_t ne3 = cne[3];
+
+        //int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
+        //int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
+        //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
+        //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);

        int64_t ne10 = cne1[0];
        int64_t ne11 = cne1[1];
        int64_t ne12 = cne1[2];
        int64_t ne13 = cne1[3];

+        size_t nb0 = cnb[0];
+        size_t nb1 = cnb[1];
+        size_t nb2 = cnb[2];
+        size_t nb3 = cnb[3];
+
+        size_t nb00 = cnb0[0];
+        size_t nb01 = cnb0[1];
+        size_t nb02 = cnb0[2];
+        size_t nb03 = cnb0[3];

        size_t nb10 = cnb1[0];
        size_t nb11 = cnb1[1];
@@ -160,7 +181,28 @@ struct bin_bcast_cuda {
        size_t s12 = nb12 / sizeof(src1_t);
        size_t s13 = nb13 / sizeof(src1_t);

+        size_t s00 = nb00 / sizeof(src0_t);
+        size_t s01 = nb01 / sizeof(src0_t);
+        size_t s02 = nb02 / sizeof(src0_t);
+        size_t s03 = nb03 / sizeof(src0_t);
+
+        GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
+        GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
+
+        GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
+        GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
+
+        GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
+        GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
+
        GGML_ASSERT(s0 == 1);
+        GGML_ASSERT(s00 == 1);
        GGML_ASSERT(s10 == 1);

        const int block_size = 128;
@@ -179,13 +221,14 @@ struct bin_bcast_cuda {
            );

            if (block_nums.z > 65535) {
-                // this is the maximum number of blocks in z
+                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
                    src0_dd, src1_dd, dst_dd,
                    ne0, ne1, ne2, ne3,
                    ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                    /* s10, */ s11, s12, s13);
            } else {
                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
@@ -193,6 +236,7 @@ struct bin_bcast_cuda {
                    ne0, ne1, ne2, ne3,
                    ne10, ne11, ne12, ne13,
                    /* s0, */ s1, s2, s3,
+                    /* s00, */ s01, s02, s03,
                    /* s10, */ s11, s12, s13);
            }
        }
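
Two things change here: the kernels gain their own src0 strides (s01/s02/s03) instead of implicitly reusing dst's layout, which is the actual fix for non-contiguous src0; and the dimension-collapse optimization is now applied only when src0, src1, and dst are all contiguous, since folding dimensions into larger ones is invalid for strided layouts. A small standalone illustration of the folding (hypothetical helper, not the ggml code):

#include <cstdint>
#include <cstdio>

// Fold dimension i into dimension i-1 while the repeat (broadcast) factor
// nr[i] is 1: e.g. ne = [4, 8, 2, 3] broadcast only along dim 2 becomes
// [32, 2, 3, 1], so the kernel iterates fewer, larger dimensions.
static void collapse_dims(int64_t ne[4], const int nr[4]) {
    for (int i = 0; i < 4; i++) {
        if (nr[i] != 1) {
            break; // stop at the first broadcast dimension
        }
        if (i > 0) {
            ne[0] *= ne[1];
            ne[1] = ne[2];
            ne[2] = ne[3];
            ne[3] = 1;
        }
    }
}

int main() {
    int64_t ne[4] = {4, 8, 2, 3};
    const int nr[4] = {1, 1, 2, 1};
    collapse_dims(ne, nr);
    printf("%lld %lld %lld %lld\n", (long long) ne[0], (long long) ne[1],
           (long long) ne[2], (long long) ne[3]); // prints: 32 2 3 1
}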
ggml-cuda/convert.cu
CHANGED
@@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
         vals[ix] = x0[ix];
     }

+    __syncthreads();
+
 #pragma unroll
     for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
         if (need_check && i0 + iy + 2*threadIdx.x >= k) {
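
The added __syncthreads() closes a race: the vals buffer is filled cooperatively by the whole block, and without a barrier a thread could read elements that another thread has not written yet. The general pattern, as a self-contained sketch (hypothetical kernel, not the ggml code; assumes the array length is a multiple of the 256-thread block size to keep it short):

#include <cuda_runtime.h>

// Stage a tile through shared memory: every thread writes one element,
// then the block must synchronize before any thread reads a neighbor's slot.
__global__ void reverse_within_block(const float * in, float * out) {
    __shared__ float tile[256];

    const int base = blockIdx.x*256;
    tile[threadIdx.x] = in[base + threadIdx.x];

    __syncthreads(); // every element must be written before anyone reads it

    out[base + threadIdx.x] = tile[255 - threadIdx.x];
}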
ggml-metal.m
CHANGED
@@ -1747,15 +1747,10 @@ static enum ggml_status ggml_metal_graph_compute(
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
-                //GGML_ASSERT(ne00 == ne10);
-                //GGML_ASSERT(ne03 == ne13);
                 const int n_as = src0->ne[2];

-                // max size of the src1ids array in the kernel shared buffer
-                GGML_ASSERT(ne11 <= 4096);
-
                 // src2 = ids
+                const int64_t ne20 = src2->ne[0];
                 const int64_t ne21 = src2->ne[1];
                 const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
                 const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@@ -1776,15 +1771,13 @@ static enum ggml_status ggml_metal_graph_compute(

                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
+                // ne20 = n_used_experts
+                // ne21 = n_rows
+                const int dst_rows = ne20*ne21;
+                const int dst_rows_min = n_as;
+
+                // max size of the rowids array in the kernel shared buffer
+                GGML_ASSERT(dst_rows <= 2048);
-                GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
-                const uint r2 = 1;
-                const uint r3 = 1;

                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@@ -1794,7 +1787,7 @@ static enum ggml_status ggml_metal_graph_compute(
                 // !!!
                 if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                     ne00 % 32 == 0 && ne00 >= 64 &&
+                    dst_rows > dst_rows_min) {

                     // some Metal matrix data types require aligned pointers
                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@@ -1836,26 +1829,26 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                    [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                    [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                    [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                    [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+                    [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+                    [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+                    [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                    [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                    [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+                    [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+                    [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+                    [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+                    [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                    [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
+                    [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:19];
+
+                    [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
@@ -2008,72 +2001,72 @@ static enum ggml_status ggml_metal_graph_compute(
                     GGML_ASSERT(ne00 >= nth0*nth1);
                 }

-                const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
                 [encoder setComputePipelineState:pipeline];
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
                 [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+                [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+                [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+                [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+                [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+                [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:20];
+                [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:21];
+                [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:22];
+
+                const int64_t _ne1 = 1;
+                const int tgz = dst_rows;

                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
                     src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
                     src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
                     const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
                     const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
                     const int mem_size = 32*sizeof(float);
                     [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_Q4_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_Q3_K) {
 #ifdef GGML_QKK_64
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #else
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
 #endif
                 }
                 else if (src0t == GGML_TYPE_Q5_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
                 else if (src0t == GGML_TYPE_Q6_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 } else {
-                    const int64_t ny = (_ne1 + nrows - 1)/nrows;
+                    const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
             }
         } break;
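
The Metal path now chooses between the simdgroup matrix-matrix kernel and the matrix-vector kernel based on the total number of destination rows across all experts (ne20*ne21) rather than ne11 alone; below that break-even point, per-row dispatch is cheaper. A condensed restatement of the decision in plain C++ (illustrative, the parameter names are assumptions):

#include <cstdint>

// Mirror of the heuristic above: prefer the matrix-matrix kernel only when
// there are more destination rows than experts, on supported hardware.
static bool use_mat_mat_kernel(int64_t n_used_experts, int64_t n_rows,
                               int64_t n_as, int64_t ne00, bool is_apple7_or_newer) {
    const int64_t dst_rows     = n_used_experts*n_rows; // total rows to produce
    const int64_t dst_rows_min = n_as;                  // below this, mat-vec wins

    return is_apple7_or_newer &&
           ne00 % 32 == 0 && ne00 >= 64 &&
           dst_rows > dst_rows_min;
}

When the matrix-matrix kernel is chosen, the encoder reserves 8192 bytes of tile memory plus 4 bytes (one ushort2) per destination row for the rowids list, and dispatches a ((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadgroup grid.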
ggml-metal.metal
CHANGED
|
@@ -899,16 +899,16 @@ void mul_vec_q_n_f32_impl(
|
|
| 899 |
device const void * src0,
|
| 900 |
device const float * src1,
|
| 901 |
device float * dst,
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
|
| 911 |
-
|
| 912 |
uint3 tgpig, uint tiisg, uint sgitg) {
|
| 913 |
const int nb = ne00/QK4_0;
|
| 914 |
|
|
@@ -1073,19 +1073,19 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
| 1073 |
device const void * src0,
|
| 1074 |
device const float * src1,
|
| 1075 |
device float * dst,
|
| 1076 |
-
|
| 1077 |
-
|
| 1078 |
-
|
| 1079 |
-
|
| 1080 |
-
|
| 1081 |
-
|
| 1082 |
-
|
| 1083 |
-
|
| 1084 |
-
|
| 1085 |
-
threadgroup int8_t
|
| 1086 |
-
|
| 1087 |
-
|
| 1088 |
-
|
| 1089 |
const int nr = N_DST;
|
| 1090 |
const int nsg = N_SIMDGROUP;
|
| 1091 |
const int nw = N_SIMDWIDTH;
|
|
@@ -1172,24 +1172,24 @@ void kernel_mul_mv_f32_f32_impl(
|
|
| 1172 |
device const char * src0,
|
| 1173 |
device const char * src1,
|
| 1174 |
device float * dst,
|
| 1175 |
-
|
| 1176 |
-
|
| 1177 |
-
|
| 1178 |
-
|
| 1179 |
-
|
| 1180 |
-
|
| 1181 |
-
|
| 1182 |
-
|
| 1183 |
-
|
| 1184 |
-
|
| 1185 |
-
|
| 1186 |
-
|
| 1187 |
-
|
| 1188 |
-
|
| 1189 |
-
|
| 1190 |
-
|
| 1191 |
-
|
| 1192 |
-
|
| 1193 |
|
| 1194 |
const int64_t r0 = tgpig.x;
|
| 1195 |
const int64_t rb = tgpig.y*N_F32_F32;
|
|
@@ -1442,24 +1442,24 @@ void kernel_mul_mv_f16_f32_impl(
|
|
| 1442 |
device const char * src0,
|
| 1443 |
device const char * src1,
|
| 1444 |
device float * dst,
|
| 1445 |
-
|
| 1446 |
-
|
| 1447 |
-
|
| 1448 |
-
|
| 1449 |
-
|
| 1450 |
-
|
| 1451 |
-
|
| 1452 |
-
|
| 1453 |
-
|
| 1454 |
-
|
| 1455 |
-
|
| 1456 |
-
|
| 1457 |
-
|
| 1458 |
-
|
| 1459 |
-
|
| 1460 |
-
|
| 1461 |
-
|
| 1462 |
-
|
| 1463 |
|
| 1464 |
const int64_t r0 = tgpig.x;
|
| 1465 |
const int64_t rb = tgpig.y*N_F16_F32;
|
|
@@ -2744,19 +2744,19 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|
| 2744 |
device const void * src0,
|
| 2745 |
device const float * src1,
|
| 2746 |
device float * dst,
|
| 2747 |
-
|
| 2748 |
-
|
| 2749 |
-
|
| 2750 |
-
|
| 2751 |
-
|
| 2752 |
-
|
| 2753 |
-
|
| 2754 |
-
|
| 2755 |
-
|
| 2756 |
-
threadgroup int8_t * shared_values
|
| 2757 |
-
|
| 2758 |
-
|
| 2759 |
-
|
| 2760 |
|
| 2761 |
const int nb = ne00/QK_K;
|
| 2762 |
const int r0 = tgpig.x;
|
|
@@ -2924,19 +2924,19 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|
| 2924 |
device const void * src0,
|
| 2925 |
device const float * src1,
|
| 2926 |
device float * dst,
|
| 2927 |
-
|
| 2928 |
-
|
| 2929 |
-
|
| 2930 |
-
|
| 2931 |
-
|
| 2932 |
-
|
| 2933 |
-
|
| 2934 |
-
|
| 2935 |
-
|
| 2936 |
-
threadgroup int8_t * shared_values
|
| 2937 |
-
|
| 2938 |
-
|
| 2939 |
-
|
| 2940 |
|
| 2941 |
const int nb = ne00/QK_K;
|
| 2942 |
|
|
@@ -3190,19 +3190,19 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
| 3190 |
device const void * src0,
|
| 3191 |
device const float * src1,
|
| 3192 |
device float * dst,
|
| 3193 |
-
|
| 3194 |
-
|
| 3195 |
-
|
| 3196 |
-
|
| 3197 |
-
|
| 3198 |
-
|
| 3199 |
-
|
| 3200 |
-
|
| 3201 |
-
|
| 3202 |
-
threadgroup int8_t * shared_values
|
| 3203 |
-
|
| 3204 |
-
|
| 3205 |
-
|
| 3206 |
|
| 3207 |
const uint16_t kmask1 = 0x3f3f;
|
| 3208 |
const uint16_t kmask2 = 0x0f0f;
|
|
@@ -3429,19 +3429,19 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
| 3429 |
device const void * src0,
|
| 3430 |
device const float * src1,
|
| 3431 |
device float * dst,
|
| 3432 |
-
|
| 3433 |
-
|
| 3434 |
-
|
| 3435 |
-
|
| 3436 |
-
|
| 3437 |
-
|
| 3438 |
-
|
| 3439 |
-
|
| 3440 |
-
|
| 3441 |
-
threadgroup int8_t * shared_values
|
| 3442 |
-
|
| 3443 |
-
|
| 3444 |
-
|
| 3445 |
|
| 3446 |
const int nb = ne00/QK_K;
|
| 3447 |
|
|
@@ -3636,19 +3636,19 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
| 3636 |
device const void * src0,
|
| 3637 |
device const float * src1,
|
| 3638 |
device float * dst,
|
| 3639 |
-
|
| 3640 |
-
|
| 3641 |
-
|
| 3642 |
-
|
| 3643 |
-
|
| 3644 |
-
|
| 3645 |
-
|
| 3646 |
-
|
| 3647 |
-
|
| 3648 |
-
threadgroup int8_t * shared_values
|
| 3649 |
-
|
| 3650 |
-
|
| 3651 |
-
|
| 3652 |
|
| 3653 |
const uint8_t kmask1 = 0x03;
|
| 3654 |
const uint8_t kmask2 = 0x0C;
|
|
@@ -3773,19 +3773,19 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
| 3773 |
device const void * src0,
|
| 3774 |
device const float * src1,
|
| 3775 |
device float * dst,
|
| 3776 |
-
|
| 3777 |
-
|
| 3778 |
-
|
| 3779 |
-
|
| 3780 |
-
|
| 3781 |
-
|
| 3782 |
-
|
| 3783 |
-
|
| 3784 |
-
|
| 3785 |
-
threadgroup int8_t * shared_values
|
| 3786 |
-
|
| 3787 |
-
|
| 3788 |
-
|
| 3789 |
|
| 3790 |
const int nb = ne00/QK_K;
|
| 3791 |
const int r0 = tgpig.x;
|
|
@@ -3902,19 +3902,19 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|
| 3902 |
device const void * src0,
|
| 3903 |
device const float * src1,
|
| 3904 |
device float * dst,
|
| 3905 |
-
|
| 3906 |
-
|
| 3907 |
-
|
| 3908 |
-
|
| 3909 |
-
|
| 3910 |
-
|
| 3911 |
-
|
| 3912 |
-
|
| 3913 |
-
|
| 3914 |
-
threadgroup int8_t * shared_values
|
| 3915 |
-
|
| 3916 |
-
|
| 3917 |
-
|
| 3918 |
|
| 3919 |
const int nb = ne00/QK_K;
|
| 3920 |
const int r0 = tgpig.x;
|
|
@@ -4041,19 +4041,19 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
| 4041 |
device const void * src0,
|
| 4042 |
device const float * src1,
|
| 4043 |
device float * dst,
|
| 4044 |
-
|
| 4045 |
-
|
| 4046 |
-
|
| 4047 |
-
|
| 4048 |
-
|
| 4049 |
-
|
| 4050 |
-
|
| 4051 |
-
|
| 4052 |
-
|
| 4053 |
-
threadgroup int8_t * shared_values
|
| 4054 |
-
|
| 4055 |
-
|
| 4056 |
-
|
| 4057 |
|
| 4058 |
const int nb = ne00/QK_K;
|
| 4059 |
const int r0 = tgpig.x;
|
|
@@ -4173,19 +4173,19 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|
| 4173 |
device const void * src0,
|
| 4174 |
device const float * src1,
|
| 4175 |
device float * dst,
|
| 4176 |
-
|
| 4177 |
-
|
| 4178 |
-
|
| 4179 |
-
|
| 4180 |
-
|
| 4181 |
-
|
| 4182 |
-
|
| 4183 |
-
|
| 4184 |
-
|
| 4185 |
-
threadgroup int8_t * shared_values
|
| 4186 |
-
|
| 4187 |
-
|
| 4188 |
-
|
| 4189 |
|
| 4190 |
const int nb = ne00/QK_K;
|
| 4191 |
const int r0 = tgpig.x;
|
|
@@ -4305,19 +4305,19 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|
| 4305 |
device const void * src0,
|
| 4306 |
device const float * src1,
|
| 4307 |
device float * dst,
|
| 4308 |
-
|
| 4309 |
-
|
| 4310 |
-
|
| 4311 |
-
|
| 4312 |
-
|
| 4313 |
-
|
| 4314 |
-
|
| 4315 |
-
|
| 4316 |
-
|
| 4317 |
-
threadgroup int8_t * shared_values
|
| 4318 |
-
|
| 4319 |
-
|
| 4320 |
-
|
| 4321 |
|
| 4322 |
const int nb = ne00/QK_K;
|
| 4323 |
const int r0 = tgpig.x;
|
|
@@ -4438,19 +4438,19 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|
| 4438 |
device const void * src0,
|
| 4439 |
device const float * src1,
|
| 4440 |
device float * dst,
|
| 4441 |
-
|
| 4442 |
-
|
| 4443 |
-
|
| 4444 |
-
|
| 4445 |
-
|
| 4446 |
-
|
| 4447 |
-
|
| 4448 |
-
|
| 4449 |
-
|
| 4450 |
-
threadgroup int8_t *
|
| 4451 |
-
|
| 4452 |
-
|
| 4453 |
-
|
| 4454 |
|
| 4455 |
const int nb = ne00/QK_K;
|
| 4456 |
const int r0 = tgpig.x;
|
|
@@ -4528,19 +4528,19 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|
| 4528 |
device const void * src0,
|
| 4529 |
device const float * src1,
|
| 4530 |
device float * dst,
|
| 4531 |
-
|
| 4532 |
-
|
| 4533 |
-
|
| 4534 |
-
|
| 4535 |
-
|
| 4536 |
-
|
| 4537 |
-
|
| 4538 |
-
|
| 4539 |
-
|
| 4540 |
-
threadgroup int8_t *
|
| 4541 |
-
|
| 4542 |
-
|
| 4543 |
-
|
| 4544 |
|
| 4545 |
const int nb = ne00/QK_K;
|
| 4546 |
const int r0 = tgpig.x;
|
|
@@ -4637,19 +4637,19 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|
| 4637 |
device const void * src0,
|
| 4638 |
device const float * src1,
|
| 4639 |
device float * dst,
|
| 4640 |
-
|
| 4641 |
-
|
| 4642 |
-
|
| 4643 |
-
|
| 4644 |
-
|
| 4645 |
-
|
| 4646 |
-
|
| 4647 |
-
|
| 4648 |
-
|
| 4649 |
-
threadgroup int8_t * shared_values_i8
|
| 4650 |
-
|
| 4651 |
-
|
| 4652 |
-
|
| 4653 |
|
| 4654 |
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
| 4655 |
const int nb = ne00/QK4_NL;
|
|
@@ -4732,19 +4732,20 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|
| 4732 |
device const void * src0,
|
| 4733 |
device const float * src1,
|
| 4734 |
device float * dst,
|
| 4735 |
-
|
| 4736 |
-
|
| 4737 |
-
|
| 4738 |
-
|
| 4739 |
-
|
| 4740 |
-
|
| 4741 |
-
|
| 4742 |
-
|
| 4743 |
-
|
| 4744 |
-
threadgroup int8_t
|
| 4745 |
-
|
| 4746 |
-
|
| 4747 |
-
|
|
|
|
| 4748 |
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
|
| 4749 |
const int nb = ne00/QK_K;
|
| 4750 |
const int r0 = tgpig.x;
|
|
@@ -5686,25 +5687,25 @@ void kernel_mul_mm_impl(device const uchar * src0,
|
|
| 5686 |
}
|
| 5687 |
}
|
| 5688 |
|
| 5689 |
-
// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in
|
| 5690 |
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
| 5691 |
void kernel_mul_mm_id_impl(
|
| 5692 |
device const uchar * src0,
|
| 5693 |
device const uchar * src1,
|
| 5694 |
-
threadgroup
|
| 5695 |
device float * dst,
|
| 5696 |
constant int64_t & ne00,
|
| 5697 |
constant int64_t & ne02,
|
| 5698 |
constant uint64_t & nb01,
|
| 5699 |
constant uint64_t & nb02,
|
|
|
|
| 5700 |
constant int64_t & ne12,
|
| 5701 |
constant uint64_t & nb10,
|
| 5702 |
constant uint64_t & nb11,
|
| 5703 |
constant uint64_t & nb12,
|
| 5704 |
constant int64_t & ne0,
|
| 5705 |
int64_t ne1,
|
| 5706 |
-
|
| 5707 |
-
constant uint & r3,
|
| 5708 |
threadgroup uchar * shared_memory,
|
| 5709 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5710 |
uint tiitg[[thread_index_in_threadgroup]],
|
|
@@ -5715,7 +5716,6 @@ void kernel_mul_mm_id_impl(
|
|
| 5715 |
|
| 5716 |
const uint r0 = tgpig.y;
|
| 5717 |
const uint r1 = tgpig.x;
|
| 5718 |
-
const uint im = tgpig.z;
|
| 5719 |
|
| 5720 |
if (r1 * BLOCK_SIZE_N >= ne1) return;
|
| 5721 |
|
|
@@ -5733,19 +5733,16 @@ void kernel_mul_mm_id_impl(
|
|
| 5733 |
for (int i = 0; i < 8; i++){
|
| 5734 |
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
|
| 5735 |
}
|
| 5736 |
-
|
| 5737 |
short il = (tiitg % THREAD_PER_ROW);
|
| 5738 |
|
| 5739 |
-
const uint i12 = im%ne12;
|
| 5740 |
-
const uint i13 = im/ne12;
|
| 5741 |
-
|
| 5742 |
-
uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
|
| 5743 |
ushort offset1 = il/nl;
|
| 5744 |
|
| 5745 |
-
|
|
|
|
|
|
|
| 5746 |
device const float * y = (device const float *)(src1
|
| 5747 |
-
+ nb12 *
|
| 5748 |
-
+ nb11 *
|
| 5749 |
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
| 5750 |
|
| 5751 |
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
|
@@ -5774,11 +5771,11 @@ void kernel_mul_mm_id_impl(
|
|
| 5774 |
|
| 5775 |
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
|
| 5776 |
for (int i = 0; i < 4; i++) {
|
| 5777 |
-
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
|
| 5778 |
}
|
| 5779 |
simdgroup_barrier(mem_flags::mem_none);
|
| 5780 |
for (int i = 0; i < 2; i++) {
|
| 5781 |
-
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
|
| 5782 |
}
|
| 5783 |
|
| 5784 |
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
|
|
@@ -5800,11 +5797,13 @@ void kernel_mul_mm_id_impl(
|
|
| 5800 |
|
| 5801 |
threadgroup_barrier(mem_flags::mem_threadgroup);
|
| 5802 |
|
| 5803 |
-
device float * C = dst + (BLOCK_SIZE_M * r0)
|
| 5804 |
if (sgitg == 0) {
|
| 5805 |
-
for (int
|
| 5806 |
-
|
| 5807 |
-
|
|
|
|
|
|
|
| 5808 |
}
|
| 5809 |
}
|
| 5810 |
}
|
|
@@ -5859,11 +5858,14 @@ kernel void kernel_mul_mm_id(
|
|
| 5859 |
device const uchar * src1,
|
| 5860 |
device float * dst,
|
| 5861 |
device const uchar * ids,
|
|
|
|
|
|
|
| 5862 |
constant uint64_t & nbi1,
|
| 5863 |
constant int64_t & ne00,
|
| 5864 |
constant int64_t & ne02,
|
| 5865 |
constant uint64_t & nb01,
|
| 5866 |
constant uint64_t & nb02,
|
|
|
|
| 5867 |
constant int64_t & ne12,
|
| 5868 |
constant int64_t & ne13,
|
| 5869 |
constant uint64_t & nb10,
|
|
@@ -5872,47 +5874,52 @@ kernel void kernel_mul_mm_id(
|
|
| 5872 |
constant int64_t & ne0,
|
| 5873 |
constant int64_t & ne1,
|
| 5874 |
constant uint64_t & nb1,
|
| 5875 |
-
constant uint & r2,
|
| 5876 |
-
constant uint & r3,
|
| 5877 |
-
constant int & idx,
|
| 5878 |
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
| 5879 |
uint3 tgpig[[threadgroup_position_in_grid]],
|
| 5880 |
uint tiitg[[thread_index_in_threadgroup]],
|
| 5881 |
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
| 5882 |
|
| 5883 |
-
|
| 5884 |
-
|
| 5885 |
-
device const uchar * src0 = src0s + id*nb02;
|
| 5886 |
|
| 5887 |
-
|
| 5888 |
|
| 5889 |
-
// row indices
|
| 5890 |
-
threadgroup
|
| 5891 |
|
|
|
|
| 5892 |
int64_t _ne1 = 0;
|
| 5893 |
-
for (
|
| 5894 |
-
|
| 5895 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5896 |
}
|
| 5897 |
}
|
| 5898 |
|
|
|
|
|
|
|
| 5899 |
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
| 5900 |
src0,
|
| 5901 |
src1,
|
| 5902 |
-
|
| 5903 |
dst,
|
| 5904 |
ne00,
|
| 5905 |
ne02,
|
| 5906 |
nb01,
|
| 5907 |
nb02,
|
|
|
|
| 5908 |
ne12,
|
| 5909 |
nb10,
|
| 5910 |
nb11,
|
| 5911 |
nb12,
|
| 5912 |
ne0,
|
| 5913 |
_ne1,
|
| 5914 |
-
|
| 5915 |
-
r3,
|
| 5916 |
shared_memory,
|
| 5917 |
tgpig,
|
| 5918 |
tiitg,
|
|
@@ -5973,24 +5980,7 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_r
|
|
| 5973 |
// matrix-matrix multiplication
|
| 5974 |
//
|
| 5975 |
|
| 5976 |
-
typedef
|
| 5977 |
-
device const uchar * src0,
|
| 5978 |
-
device const uchar * src1,
|
| 5979 |
-
device float * dst,
|
| 5980 |
-
constant int64_t & ne00,
|
| 5981 |
-
constant int64_t & ne02,
|
| 5982 |
-
constant uint64_t & nb01,
|
| 5983 |
-
constant uint64_t & nb02,
|
| 5984 |
-
constant int64_t & ne12,
|
| 5985 |
-
constant uint64_t & nb10,
|
| 5986 |
-
constant uint64_t & nb11,
|
| 5987 |
-
constant uint64_t & nb12,
|
| 5988 |
-
constant int64_t & ne0,
|
| 5989 |
-
constant int64_t & ne1,
|
| 5990 |
-
constant uint & r2,
|
| 5991 |
-
constant uint & r3,
|
| 5992 |
-
threadgroup uchar *,
|
| 5993 |
-
uint3, uint, uint);
|
| 5994 |
|
| 5995 |
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
|
| 5996 |
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
|
@@ -6022,29 +6012,7 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|
| 6022 |
// indirect matrix-matrix multiplication
|
| 6023 |
//
|
| 6024 |
|
| 6025 |
-
typedef
|
| 6026 |
-
device const uchar * src0s,
|
| 6027 |
-
device const uchar * src1,
|
| 6028 |
-
device float * dst,
|
| 6029 |
-
device const uchar * ids,
|
| 6030 |
-
constant uint64_t & nbi1,
|
| 6031 |
-
constant int64_t & ne00,
|
| 6032 |
-
constant int64_t & ne02,
|
| 6033 |
-
constant uint64_t & nb01,
|
| 6034 |
-
constant uint64_t & nb02,
|
| 6035 |
-
constant int64_t & ne12,
|
| 6036 |
-
constant int64_t & ne13,
|
| 6037 |
-
constant uint64_t & nb10,
|
| 6038 |
-
constant uint64_t & nb11,
|
| 6039 |
-
constant uint64_t & nb12,
|
| 6040 |
-
constant int64_t & ne0,
|
| 6041 |
-
constant int64_t & ne1,
|
| 6042 |
-
constant uint64_t & nb1,
|
| 6043 |
-
constant uint & r2,
|
| 6044 |
-
constant uint & r3,
|
| 6045 |
-
constant int & idx,
|
| 6046 |
-
threadgroup uchar *,
|
| 6047 |
-
uint3, uint, uint);
|
| 6048 |
|
| 6049 |
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
|
| 6050 |
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
|
|
@@ -6080,71 +6048,71 @@ typedef void (kernel_mul_mv_impl_t)(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        [... `constant ... &` parameter lines, blank in this render ...]
 
 typedef void (kernel_mul_mv2_impl_t)(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
-        [...]
-        threadgroup int8_t * shared_values,
-        [...]
 
 template<kernel_mul_mv_impl_t impl_fn>
 void mmv_fn(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        [...]
-        threadgroup int8_t * shared_values,
-        uint3 tgpig,
-        uint  tiitg,
-        uint  tiisg,
-        uint  sgitg) {
     impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
 }
 
@@ -6153,59 +6121,33 @@ void mmv_fn(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
-        [...]
-        threadgroup int8_t * shared_values,
-        uint3 tgpig,
-        uint  tiitg,
-        uint  tiisg,
-        uint  sgitg) {
     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
-typedef void (mul_mv_impl_fn_t)(
-        device const    char * src0,
-        device const    char * src1,
-        device         float * dst,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        threadgroup   int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]);
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -6213,6 +6155,8 @@ kernel void kernel_mul_mv_id(
         device const    char * src1,
         device         float * dst,
         device const    char * ids,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -6230,43 +6174,50 @@ kernel void kernel_mul_mv_id(
         constant     int64_t & ne0,
         constant     int64_t & ne1,
         constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
         threadgroup   int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    const [...]
 
-    [...]
 
-    const [...]
-    [...]
 
     impl_fn(
-        src0,
-        src1 [...]
-        dst [...]
-        ne00,
-        ne01,
-        ne02,
-        nb00,
-        nb01,
-        nb02,
-        ne10,
-        ne11,
-        ne12,
-        ne13,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        ne1,
-        nb1,
-        r2,
-        r3,
         shared_values,
         tgpig,
         tiitg,
@@ -6274,36 +6225,7 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef void (kernel_mul_mv_id_t)(
-        device const    char * src0s,
-        device const    char * src1,
-        device         float * dst,
-        device const    char * ids,
-        constant    uint64_t & nbi1,
-        constant     int64_t & ne00,
-        constant     int64_t & ne01,
-        constant     int64_t & ne02,
-        constant    uint64_t & nb00,
-        constant    uint64_t & nb01,
-        constant    uint64_t & nb02,
-        constant     int64_t & ne10,
-        constant     int64_t & ne11,
-        constant     int64_t & ne12,
-        constant     int64_t & ne13,
-        constant    uint64_t & nb10,
-        constant    uint64_t & nb11,
-        constant    uint64_t & nb12,
-        constant     int64_t & ne0,
-        constant     int64_t & ne1,
-        constant    uint64_t & nb1,
-        constant        uint & r2,
-        constant        uint & r3,
-        constant         int & idx,
-        threadgroup   int8_t * shared_values [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiitg[[thread_index_in_threadgroup]],
-        uint  tiisg[[thread_index_in_simdgroup]],
-        uint  sgitg[[simdgroup_index_in_threadgroup]]);
 
 template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
@@ 899 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
         uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
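
Across this and all the following hunks the change is the same: every *_impl function trades its `constant ... &` argument-table references for plain value parameters. The impls are now ordinary functions called from wrapper code rather than kernel entry points, so a caller can pass an adjusted extent (for example ne02 = 1 in kernel_mul_mv_id further down) without needing an addressable constant behind it. A minimal standalone C++ sketch of the before/after shape, with invented function names (MSL `constant ... &` must point into the argument buffer; a C++ const ref is looser, so this only mirrors the shape):

#include <cstdint>

// before: the extent arrives bound through a reference, as with the old
// `constant int64_t &` kernel parameters
void scale_before(const float * x, float * y, const int64_t & ne00) {
    for (int64_t i = 0; i < ne00; ++i) y[i] = 0.5f * x[i];
}

// after: a plain value, so a wrapper may pass ne00, or 1, or any adjusted
// extent directly
void scale_after(const float * x, float * y, int64_t ne00) {
    for (int64_t i = 0; i < ne00; ++i) y[i] = 0.5f * x[i];
}

int main() {
    const float x[4] = {2, 4, 6, 8};
    float y[4];
    scale_before(x, y, 4);
    scale_after (x, y, 1); // adjusted extent, no stored constant required
    return 0;
}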
@@ 1073 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
     const int nr  = N_DST;
     const int nsg = N_SIMDGROUP;
     const int nw  = N_SIMDWIDTH;
@@ 1172 @@
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
+        int64_t  ne00,
+        int64_t  ne01,
+        int64_t  ne02,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        int64_t  ne10,
+        int64_t  ne11,
+        int64_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int64_t  ne0,
+        int64_t  ne1,
+        uint     r2,
+        uint     r3,
+        uint3    tgpig,
+        uint     tiisg) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F32_F32;
@@ 1442 @@
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
+        int64_t  ne00,
+        int64_t  ne01,
+        int64_t  ne02,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        int64_t  ne10,
+        int64_t  ne11,
+        int64_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int64_t  ne0,
+        int64_t  ne1,
+        uint     r2,
+        uint     r3,
+        uint3    tgpig,
+        uint     tiisg) {
 
     const int64_t r0 = tgpig.x;
     const int64_t rb = tgpig.y*N_F16_F32;
@@ 2744 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 2924 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
 
@@ 3190 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const uint16_t kmask1 = 0x3f3f;
     const uint16_t kmask2 = 0x0f0f;
@@ 3429 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
 
@@ 3636 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ 3773 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 3902 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4041 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4173 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4305 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4438 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_value,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4528 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_value,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 4637 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values_i8,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
 
     threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK4_NL;
@@ 4732 @@
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values_i8,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg) {
+
     threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
     const int nb = ne00/QK_K;
     const int r0 = tgpig.x;
@@ 5687 @@
     }
 }
 
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 void kernel_mul_mm_id_impl(
         device const  uchar * src0,
         device const  uchar * src1,
+        threadgroup ushort2 * rowids,
         device        float * dst,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
         constant   uint64_t & nb01,
         constant   uint64_t & nb02,
+        constant    int64_t & ne11,
         constant    int64_t & ne12,
         constant   uint64_t & nb10,
         constant   uint64_t & nb11,
         constant   uint64_t & nb12,
         constant    int64_t & ne0,
                     int64_t   ne1,
+                    int64_t   ne0ne1,
         threadgroup   uchar * shared_memory,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
@@ 5716 @@
 
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
 
     if (r1 * BLOCK_SIZE_N >= ne1) return;
 
@@ 5733 @@
     for (int i = 0; i < 8; i++){
         c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
     short il = (tiitg % THREAD_PER_ROW);
 
     ushort offset1 = il/nl;
 
+    threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
     device const float   * y = (device const float   *)(src1
+        + nb12 * id[1]
+        + nb11 * (id[0] % ne11)
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
 
@@ 5771 @@
 
     for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
         for (int i = 0; i < 4; i++) {
+            simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
         }
         simdgroup_barrier(mem_flags::mem_none);
         for (int i = 0; i < 2; i++) {
+            simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
         }
 
         lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
 
@@ 5797 @@
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
+    device float * C = dst + (BLOCK_SIZE_M * r0);
     if (sgitg == 0) {
+        for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+            threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+            int joff =  jid[0] * ne0 + jid[1] * ne0ne1;
+            for (int i = 0; i < n_rows; i++) {
+                *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
             }
         }
     }
 
@@ 5858 @@ kernel void kernel_mul_mm_id(
         device const   uchar * src1,
         device         float * dst,
         device const   uchar * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
         constant    uint64_t & nb01,
         constant    uint64_t & nb02,
+        constant     int64_t & ne11,
         constant     int64_t & ne12,
         constant     int64_t & ne13,
         constant    uint64_t & nb10,
@@ 5874 @@
         constant     int64_t & ne0,
         constant     int64_t & ne1,
         constant    uint64_t & nb1,
         threadgroup    uchar * shared_memory [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
+    const int32_t i02 = tgpig.z;
+    tgpig.z = 0;
 
+    device const uchar * src0 = src0s + i02*nb02;
 
+    // row indices
+    threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
 
+    // TODO: parallelize this loop
     int64_t _ne1 = 0;
+    for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+        for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+            int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+            if (id == i02) {
+                //if (tiitg == 0) {
+                rowids[_ne1] = ushort2(ii0, ii1);
+                //}
+                _ne1++;
+            }
         }
     }
 
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
         src0,
         src1,
+        rowids,
         dst,
         ne00,
         ne02,
         nb01,
         nb02,
+        ne11,
         ne12,
         nb10,
         nb11,
         nb12,
         ne0,
         _ne1,
+        ne0*ne1,
         shared_memory,
         tgpig,
         tiitg,
@@ 5980 @@
 // matrix-matrix multiplication
 //
 
+typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
 
 template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
 
@@ 6012 @@
 // indirect matrix-matrix multiplication
 //
 
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
 
 template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
 template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
@@ 6048 @@ typedef void (kernel_mul_mv_impl_t)(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
+        int64_t  ne00,
+        int64_t  ne01,
+        int64_t  ne02,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        int64_t  ne10,
+        int64_t  ne11,
+        int64_t  ne12,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int64_t  ne0,
+        int64_t  ne1,
+        uint     r2,
+        uint     r3,
+        uint3    tgpig,
+        uint     tiisg);
 
 typedef void (kernel_mul_mv2_impl_t)(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
+        int64_t ne00,
+        int64_t ne01,
+        int64_t ne02,
+        int64_t ne10,
+        int64_t ne12,
+        int64_t ne0,
+        int64_t ne1,
+        uint    r2,
+        uint    r3,
+        threadgroup int8_t * shared_values,
+        uint3   tgpig,
+        uint    tiisg,
+        uint    sgitg);
 
 template<kernel_mul_mv_impl_t impl_fn>
 void mmv_fn(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
+        int64_t  ne00,
+        int64_t  ne01,
+        int64_t  ne02,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        int64_t  ne10,
+        int64_t  ne11,
+        int64_t  ne12,
+        int64_t  ne13,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int64_t  ne0,
+        int64_t  ne1,
+        uint64_t nb1,
+        uint     r2,
+        uint     r3,
+        threadgroup int8_t * shared_values,
+        uint3    tgpig,
+        uint     tiitg,
+        uint     tiisg,
+        uint     sgitg) {
     impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
 }
 
@@ 6121 @@ void mmv_fn(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
+        int64_t  ne00,
+        int64_t  ne01,
+        int64_t  ne02,
+        uint64_t nb00,
+        uint64_t nb01,
+        uint64_t nb02,
+        int64_t  ne10,
+        int64_t  ne11,
+        int64_t  ne12,
+        int64_t  ne13,
+        uint64_t nb10,
+        uint64_t nb11,
+        uint64_t nb12,
+        int64_t  ne0,
+        int64_t  ne1,
+        uint64_t nb1,
+        uint     r2,
+        uint     r3,
+        threadgroup int8_t * shared_values,
+        uint3    tgpig,
+        uint     tiitg,
+        uint     tiisg,
+        uint     sgitg) {
     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
+typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
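
The two mmv_fn overloads exist so that every mat-vec impl, whichever flavor of signature it has, can hide behind the single mul_mv_impl_fn_t shape that kernel_mul_mv_id dispatches through. A standalone C++ rehearsal of that trick, with invented names, where overload resolution on the template parameter's type picks the forwarding body:

#include <cstdint>

// two impl flavors with different parameter lists, mirroring
// kernel_mul_mv_impl_t vs kernel_mul_mv2_impl_t
void impl_plain(const float * x, float * y, int64_t n) {
    for (int64_t i = 0; i < n; ++i) y[i] = 2.0f * x[i];
}
void impl_scratch(const float * x, float * y, int64_t n, int8_t * scratch) {
    (void) scratch; // a real impl would stage data through the scratch buffer
    for (int64_t i = 0; i < n; ++i) y[i] = x[i] + 1.0f;
}

typedef void (impl_plain_t)(const float *, float *, int64_t);
typedef void (impl_scratch_t)(const float *, float *, int64_t, int8_t *);

// one common wrapper signature for both; each overload forwards only
// what its impl flavor actually takes
template <impl_plain_t impl_fn>
void mv_fn(const float * x, float * y, int64_t n, int8_t * scratch) {
    (void) scratch;
    impl_fn(x, y, n);
}
template <impl_scratch_t impl_fn>
void mv_fn(const float * x, float * y, int64_t n, int8_t * scratch) {
    impl_fn(x, y, n, scratch);
}

int main() {
    const float x[2] = {1, 2};
    float y[2];
    int8_t scratch[16];
    mv_fn<impl_plain>  (x, y, 2, scratch); // resolves to the first overload
    mv_fn<impl_scratch>(x, y, 2, scratch); // resolves to the second
    return 0;
}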
@@ 6151 @@
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
         device const    char * src1,
         device         float * dst,
         device const    char * ids,
+        constant     int64_t & nei0,
+        constant     int64_t & nei1,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ 6174 @@
         constant     int64_t & ne0,
         constant     int64_t & ne1,
         constant    uint64_t & nb1,
         threadgroup   int8_t * shared_values [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiitg[[thread_index_in_threadgroup]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const int iid1 = tgpig.z/nei0;
+    const int idx  = tgpig.z%nei0;
+
+    tgpig.z = 0;
 
+    const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
 
+    const int64_t i11 = idx % ne11;
+    const int64_t i12 = iid1;
+
+    const int64_t i1 = idx;
+    const int64_t i2 = i12;
+
+    device const char * src0_cur = src0s + i02*nb02;
+    device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+    device      float * dst_cur  = dst + i1*ne0 + i2*ne1*ne0;
 
     impl_fn(
+        /* src0 */ src0_cur,
+        /* src1 */ src1_cur,
+        /* dst  */ dst_cur,
+        /* ne00 */ ne00,
+        /* ne01 */ ne01,
+        /* ne02 */ 1,//ne02,
+        /* nb00 */ nb00,
+        /* nb01 */ nb01,
+        /* nb02 */ nb02,
+        /* ne10 */ ne10,
+        /* ne11 */ 1,//ne11,
+        /* ne12 */ 1,//ne12,
+        /* ne13 */ 1,//ne13,
+        /* nb10 */ nb10,
+        /* nb11 */ nb11,
+        /* nb12 */ nb12,
+        /* ne0  */ ne0,
+        /* ne1  */ 1,//ne1,
+        /* nb1  */ nb1,
+        /* r2   */ 1,
+        /* r3   */ 1,
         shared_values,
         tgpig,
         tiitg,
         sgitg);
 }
 
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
 
 template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;

ggml-sycl.cpp
CHANGED
@@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
 
 GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
     const int min_batch_size = 32;
-    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
+    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
     GGML_UNUSED(backend);
 }
 
ggml.c
CHANGED
@@ -4594,21 +4594,32 @@ void ggml_mul_mat_set_prec(
 
 // ggml_mul_mat_id
 
-[...]
+/*
+    c = ggml_mul_mat_id(ctx, as, b, ids);
+
+    as  -> [cols, rows, n_expert]
+    ids -> [n_experts_used, n_tokens] (i32)
+    b   -> [cols, n_expert_used, n_tokens]
+    c   -> [cols, n_expert_used, n_tokens]
+
+    in b, n_experts_used can be broadcasted to match the n_expert_used of ids
+
+    c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
 struct ggml_tensor * ggml_mul_mat_id(
         struct ggml_context * ctx,
         struct ggml_tensor  * as,
-        struct ggml_tensor  * ids,
-        int                   id,
-        struct ggml_tensor  * b) {
-
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * ids) {
+    GGML_ASSERT(!ggml_is_transposed(as));
     GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+    GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+    GGML_ASSERT(b->ne[3] == 1); // b is 3d
     GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
-    GGML_ASSERT(ids->ne[1] == b->ne[1]);
-    GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
-    GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+    GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
     GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+    GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
 
     bool is_node = false;
 
@@ -4616,11 +4627,9 @@ struct ggml_tensor * ggml_mul_mat_id(
         is_node = true;
     }
 
-    const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+    const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    ggml_set_op_params_i32(result, 0, id);
-
     result->op   = GGML_OP_MUL_MAT_ID;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = as;
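
A hedged usage sketch of the reworked call, matching the shape comment above (the helper and its tensor names are invented; `ggml_mul_mat_id` and `ggml_reshape_3d` are the real API):

#include "ggml.h"

// sketch: one fused expert matmul for a MoE FFN layer (sizes illustrative)
static struct ggml_tensor * moe_mul_mat(
        struct ggml_context * ctx,
        struct ggml_tensor  * as,    // [n_embd, n_ff, n_expert] expert weights
        struct ggml_tensor  * inp,   // [n_embd, n_tokens] token activations
        struct ggml_tensor  * ids) { // [n_expert_used, n_tokens] I32 routing
    // give b a middle dimension of 1 so it broadcasts over the expert slots
    struct ggml_tensor * b = ggml_reshape_3d(ctx, inp, inp->ne[0], 1, inp->ne[1]);

    // one op covers every expert: the per-expert loop over n_as is gone
    return ggml_mul_mat_id(ctx, as, b, ids); // [n_ff, n_expert_used, n_tokens]
}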
@@ -11071,11 +11080,6 @@ static void ggml_compute_forward_mul_mat_id(
     enum ggml_type    const vec_dot_type          = type_traits[type].vec_dot_type;
     ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
 
-    GGML_ASSERT(ne0 == ne01);
-    GGML_ASSERT(ne1 == ne11);
-    GGML_ASSERT(ne2 == ne12);
-    GGML_ASSERT(ne3 == ne13);
-
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
     GGML_ASSERT(nb10 == ggml_type_size(src1->type));
 
@@ -11086,22 +11090,21 @@ static void ggml_compute_forward_mul_mat_id(
     GGML_ASSERT(nb1 <= nb2);
     GGML_ASSERT(nb2 <= nb3);
 
-    // broadcast is not supported with mmid
-    assert(ne12 == 1);
-    assert(ne13 == 1);
-
     // row groups
-    const int [...]
-    const int n_as [...]
+    const int n_ids = ids->ne[0]; // n_expert_used
+    const int n_as  = ne02;       // n_expert
 
     char * wdata_src1_end = (src1->type == vec_dot_type) ?
             (char *) params->wdata :
             (char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
 
-    [...]
+    struct mmid_row_mapping {
+        int32_t i1;
+        int32_t i2;
+    };
+
+    int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+    struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
 
     if (params->type == GGML_TASK_TYPE_INIT) {
         if (ith != 0) {
 
@@ -11127,13 +11130,18 @@ static void ggml_compute_forward_mul_mat_id(
         // initialize matrix_row_counts
         memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
 
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
+
         // group rows by src0 matrix
-        for (int64_t [...]
-            [...]
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+            for (int id = 0; id < n_ids; ++id) {
+                const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+                assert(i02 >= 0 && i02 < n_as);
+
+                MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+                matrix_row_counts[i02] += 1;
+            }
         }
 
         return;
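
The INIT pass is effectively a counting pass: every (slot, token) pair is appended to the bucket of the expert it routes to, so the compute pass can sweep one expert's rows contiguously and touch each weight matrix once. A toy standalone rendition of that bucketing (plain C++; sizes, ids values, and the deliberately over-generous bucket stride are illustrative, not ggml's exact layout):

#include <cassert>
#include <cstdint>
#include <vector>

struct row_mapping { int32_t i1; int32_t i2; }; // (expert slot, token), like mmid_row_mapping

int main() {
    const int32_t n_as = 4, n_ids = 2, n_tokens = 3;      // experts, top-k, tokens
    const int32_t ids[3][2] = { {0, 2}, {1, 0}, {2, 3} }; // routing table

    std::vector<int64_t>     counts(n_as, 0);
    std::vector<row_mapping> rows(n_as * n_ids * n_tokens); // per-expert buckets

    for (int32_t t = 0; t < n_tokens; ++t) {
        for (int32_t s = 0; s < n_ids; ++s) {
            const int32_t e = ids[t][s];
            assert(e >= 0 && e < n_as);
            rows[e*n_ids*n_tokens + counts[e]] = { s, t }; // append to expert e's bucket
            counts[e]++;
        }
    }
    // counts[e] entries now wait in expert e's bucket for its single matmul sweep
    return 0;
}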
@@ -11151,15 +11159,13 @@ static void ggml_compute_forward_mul_mat_id(
             continue;
         }
 
-
+        const char * src0_cur = (const char *) src0->data + cur_a*nb02;
 
         const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
         const size_t row_size = ggml_row_size(vec_dot_type, ne10);
 
-        const int64_t nr0 = ne01;
-        const int64_t nr1 = cne1 [...]
-
-        //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+        const int64_t nr0 = ne01; // src0 rows
+        const int64_t nr1 = cne1; // src1 rows
 
         // distribute the thread work across the inner or outer loop based on which one is larger
 
@@ -11178,13 +11184,11 @@ static void ggml_compute_forward_mul_mat_id(
         const int64_t ir110 = dr1*ith1;
         const int64_t ir111 = MIN(ir110 + dr1, nr1);
 
-        //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
-
         // threads with no work simply yield (not sure if it helps)
-        if (ir010 >= ir011 || ir110 >= ir111) {
-            [...]
-            [...]
-        }
+        //if (ir010 >= ir011 || ir110 >= ir111) {
+        //    sched_yield();
+        //    continue;
+        //}
 
         // block-tiling attempt
         const int64_t blck_0 = 16;
 
@@ -11196,20 +11200,16 @@ static void ggml_compute_forward_mul_mat_id(
         for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
             for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
                 for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
-                    const int64_t [...]
-                    const int64_t i12  = (ir1 - i13*ne12*cne1)/cne1;
-                    const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
-                    const int64_t i11  = MMID_MATRIX_ROW(cur_a, _i11);
+                    const int64_t _i12 = ir1; // logical row index for this expert
 
-                    [...]
-                    //const int64_t i02 = i12/r2;
+                    struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+                    const int id       = row_mapping.i1; // selected expert index
 
-                    const int64_t [...]
-                    const int64_t i2 [...]
-                    const int64_t i3 = i13;
+                    const int64_t i11 = id % ne11;
+                    const int64_t i12 = row_mapping.i2; // row index in src1
 
-                    const [...]
+                    const int64_t i1 = id;  // selected expert index
+                    const int64_t i2 = i12; // row
 
                     // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
                     //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
 
@@ -11217,25 +11217,26 @@ static void ggml_compute_forward_mul_mat_id(
                     //       TODO: this is a bit of a hack, we should probably have a better way to handle this
                     const char * src1_col = (const char *) wdata +
                         (src1_cont || src1->type != vec_dot_type
-                         ? (i11      + i12*ne11 [...]
-                         : (i11*nb11 + i12*nb12 [...]
+                         ? (i11      + i12*ne11)*row_size
+                         : (i11*nb11 + i12*nb12));
 
-                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 [...]
+                    float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
 
                     //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
                     //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
                     //}
 
                     for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
-                        vec_dot(ne00, &tmp[ir0 - iir0], 0, [...]
+                        vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
                     }
+
                     memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
                 }
             }
         }
     }
 
-
+#undef MMID_MATRIX_ROW
 }
 
 // ggml_compute_forward_out_prod
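
The destination indexing is worth spelling out: with c laid out as [ne0, n_expert_used, n_tokens], the row written for a given pair is dst + i1*nb1 + i2*nb2, where i1 is the expert slot and i2 the token. A tiny worked example with contiguous strides and invented sizes:

#include <cstdint>
#include <cstdio>

int main() {
    // dst is [ne0, n_expert_used, n_tokens]: each (slot, token) pair owns one row
    const int64_t ne0 = 4, n_used = 2;   // illustrative sizes
    const int64_t nb1 = ne0;             // stride between slots  (in elements)
    const int64_t nb2 = ne0 * n_used;    // stride between tokens (in elements)

    const int64_t i1 = 1, i2 = 2;        // expert slot 1 of token 2
    const int64_t off = i1*nb1 + i2*nb2; // 4 + 16 = 20 elements in
    printf("dst row for (slot %lld, token %lld) starts at element %lld\n",
           (long long) i1, (long long) i2, (long long) off);
    return 0;
}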
@@ -18583,7 +18584,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 const int n_as = src0->ne[2];
                 cur += GGML_PAD(cur, sizeof(int64_t));       // align
                 cur += n_as * sizeof(int64_t);               // matrix_row_counts
-                cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
+                cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
             } break;
         case GGML_OP_OUT_PROD:
             {
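
Note the sizing switched from src1->ne[1] to src1->ne[2], matching the new b layout where ne[2] counts rows (tokens). A quick back-of-envelope rendition of what this reserves, with invented values (the real code pads via GGML_PAD on top of the converted-src1 buffer):

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t n_as = 8, ne12 = 512; // experts and src1 rows, illustrative
    size_t cur = 0;                     // bytes already claimed for converted src1

    cur += (sizeof(int64_t) - cur % sizeof(int64_t)) % sizeof(int64_t); // align
    cur += n_as * sizeof(int64_t);        // matrix_row_counts: 64 bytes here
    cur += n_as * ne12 * sizeof(int64_t); // matrix_rows: 8*512*8 = 32 KiB here

    printf("extra MUL_MAT_ID wdata: %zu bytes\n", cur);
    return 0;
}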
@@ -21009,12 +21010,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
         ok = ok && cur != NULL;
 
-        ggml_set_name(cur, ctx->infos[i].name.data);
-
         if (!ok) {
             break;
         }
 
+        ggml_set_name(cur, ctx->infos[i].name.data);
+
         // point the data member to the appropriate location in the binary blob using the tensor infos
         if (!params.no_alloc) {
             //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file
ggml.h
CHANGED
@@ -1170,13 +1170,11 @@ extern "C" {
             enum ggml_prec        prec);
 
     // indirect matrix multiplication
-    // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
     GGML_API struct ggml_tensor * ggml_mul_mat_id(
             struct ggml_context * ctx,
             struct ggml_tensor  * as,
-            struct ggml_tensor  * ids,
-            int                   id,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * ids);
 
     // A: m columns, n rows,
     // B: p columns, n rows,
|