Commit 2008e08
Author: Łukasz Ślusarczyk
Parent: 03048ea

sycl: use oneDNN for matrices multiplication (llama/12972)
Files changed (4):
- ggml/CMakeLists.txt (+1, -0)
- ggml/src/ggml-sycl/CMakeLists.txt (+26, -22)
- ggml/src/ggml-sycl/gemm.hpp (+37, -8)
- ggml/src/ggml-sycl/ggml-sycl.cpp (+127, -67)
ggml/CMakeLists.txt CHANGED

@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
 option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
+option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING
     "ggml: sycl target device")
 set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
ggml/src/ggml-sycl/CMakeLists.txt CHANGED

@@ -49,34 +49,38 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
 
 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
-    if(NOT DEFINED DNNL_GPU_VENDOR)
-        # default to intel target
-        set(DNNL_GPU_VENDOR "INTEL")
-        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
-            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
+        if (NOT DEFINED DNNL_GPU_VENDOR)
+            # default to intel target
+            set(DNNL_GPU_VENDOR "INTEL")
+            if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+                message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+            endif()
         endif()
-    endif()
 
-    # Verify oneDNN was compiled for the same target as llama
-    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
-        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-        set(GGML_SYCL_DNNL 1)
-        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
-        foreach(CONFIG ${CONFIGS})
-            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
-            message(STATUS "Found oneDNN: ${DNNL_LIB}")
-        endforeach()
+        # Verify oneDNN was compiled for the same target as llama
+        if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+            target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+            set(GGML_SYCL_DNNL 1)
+            get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+            foreach(CONFIG ${CONFIGS})
+                get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+                message(STATUS "Found oneDNN: ${DNNL_LIB}")
+            endforeach()
+        else()
+            message(WARNING
+                "oneDNN must be compiled for the same target as llama.cpp.
+                llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+                Disabling oneDNN support.")
+        endif()
     else()
-        message(
-            "oneDNN must be compiled for the same target as llama.cpp.
-            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
-            Disabling oneDNN support.")
+        message(STATUS "oneDNN not found, disabling oneDNN support")
     endif()
 else()
-    message(STATUS "oneDNN not found, disabling oneDNN support")
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
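Note: the net effect of this build logic is a two-level switch. The GGML_SYCL_DNNL compile definition above bakes oneDNN support in or out at build time, while the GGML_SYCL_DISABLE_DNN environment variable (read in the ggml-sycl.cpp hunks below) turns it off at run time. A minimal standalone sketch of the same gating pattern, using std::getenv in place of the backend's get_sycl_env() helper (illustration only, not part of the commit):

    // Two-level gating sketch: compile with -DGGML_SYCL_DNNL=1 to take the
    // oneDNN branch; set GGML_SYCL_DISABLE_DNN=1 to opt out at run time.
    #include <cstdlib>
    #include <cstdio>

    static int get_env_int(const char * name, int fallback) {
        const char * v = std::getenv(name);   // runtime switch
        return v ? std::atoi(v) : fallback;
    }

    int main() {
    #if GGML_SYCL_DNNL                        // compile-time switch set by CMake
        if (!get_env_int("GGML_SYCL_DISABLE_DNN", 0)) {
            std::puts("matmul dispatched to oneDNN");
        } else {
            std::puts("oneDNN compiled in, disabled at runtime");
        }
    #else
        std::puts("oneDNN support compiled out");
    #endif
        return 0;
    }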
ggml/src/ggml-sycl/gemm.hpp CHANGED

@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }
 
-    static void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+            const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+            void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
 };
 
 #endif
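The stride parameters documented on gemm() above are easiest to check against a plain CPU loop nest. The following reference is a sketch for understanding only and is not part of the backend: it implements the documented semantics, including the oneDNN matmul rule that a batch count of 1 broadcasts against the other operand, and feeds in the exact stride arguments row_gemm() forwards, so it computes C = A^T * B for column-major A and B:

    // CPU reference for DnnlGemmWrapper::gemm's stride semantics (sketch).
    // A[p][i][l] lives at a[p*stride_a + i*nra + l*nca], B likewise; C is a
    // dense row-major (batches, m, n) block, matching memory tag abc.
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    static void ref_gemm(int m, int n, int k,
                         const float * a, long nra, long nca, long stride_a,
                         const float * b, long nrb, long ncb, long stride_b,
                         float * c, long batches_a, long batches_b) {
        const long batches = std::max(batches_a, batches_b);
        for (long p = 0; p < batches; ++p) {
            const float * pa = a + (batches_a == 1 ? 0 : p) * stride_a; // batch of 1 broadcasts
            const float * pb = b + (batches_b == 1 ? 0 : p) * stride_b;
            float * pc = c + p * (long) m * n;
            for (int i = 0; i < m; ++i) {
                for (int j = 0; j < n; ++j) {
                    float acc = 0.0f;
                    for (int l = 0; l < k; ++l) {
                        acc += pa[i * nra + l * nca] * pb[l * nrb + j * ncb];
                    }
                    pc[i * (long) n + j] = acc;
                }
            }
        }
    }

    int main() {
        // row_gemm(m=2, n=3, k=4): A is k x m column major, B is k x n column major.
        const int m = 2, n = 3, k = 4;
        std::vector<float> A(k * m, 1.0f), B(n * k, 2.0f), C(m * n, 0.0f);
        // the same stride arguments row_gemm forwards: A {k, 1, k*m}, B {1, k, n*k}
        ref_gemm(m, n, k, A.data(), k, 1, (long) k * m,
                 B.data(), 1, k, (long) n * k, C.data(), 1, 1);
        std::printf("C[0][0] = %g (expect %g)\n", C[0], (double) (2 * k)); // 1*2 summed k times
        return 0;
    }

With nra = k and nca = 1, the logical m x k operand reads a[i*k + l], i.e. column i of the k x m column-major A, which is exactly the transposition the row_gemm comment promises.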
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED

@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
     g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO("  GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO("  GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
     GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+    GGML_LOG_INFO("  GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+    GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+    GGML_LOG_INFO("  GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
     GGML_LOG_INFO("  GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
             : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-            DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-            DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-            dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16,
 
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
 
     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
     GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char * dst_t = reinterpret_cast<char *>(dst_ddf);
 
     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0,
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
    const int64_t r3 = ne13 / ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
-                                                    oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-                                                    src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-                                                    src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
-                                                    mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
-    } else {
-        const int ne23 = ne12 * ne13;
-
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
-        ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
-        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
-
-        sycl::range<3> block_dims(1, ne12, ne13);
-        queue->submit([&](sycl::handler & cgh) {
-            const void ** ptrs_src_get = ptrs_src.get();
-            void ** ptrs_dst_get = ptrs_dst.get();
-            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
-            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
-                                       nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
-            });
-        });
-
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11, ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                                                        oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                                                        src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                                                        src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                                                        mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void ** ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                                           nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
+            });
+
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+        }
     }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
         return GGML_STATUS_SUCCESS;
     }
 
-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();
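One more note on the batched path: the shape asserts and broadcast factors added above drive how the new oneDNN branch groups matrices per DnnlGemmWrapper::gemm call. A simplified sketch of just that decision, with the tensor fields reduced to plain integers and the contiguity special cases omitted (illustration only, hypothetical example values):

    // Batching decision sketch for the oneDNN branch of
    // ggml_sycl_mul_mat_batched_sycl.
    #include <cstdio>

    int main() {
        const long ne02 = 2, ne03 = 1;  // src0 batch dims
        const long ne12 = 6, ne13 = 1;  // src1 batch dims, multiples of src0's
        const long r2 = ne12 / ne02;    // broadcast factors, as in the diff
        const long r3 = ne13 / ne03;

        if (r2 == 1 && r3 == 1) {
            // no broadcast: one strided call covering every matrix,
            // batches_a = ne12*ne13 (src1), batches_b = ne02*ne03 (src0)
            std::printf("single gemm over %ld batches\n", ne12 * ne13);
        } else {
            // broadcast: one call per src0 matrix, multiplying it against the
            // r2*r3 src1 matrices mapped onto it (batches_a = r2*r3, batches_b = 1)
            for (long ie02 = 0; ie02 < ne02; ++ie02) {
                for (long ie03 = 0; ie03 < ne03; ++ie03) {
                    std::printf("gemm for src0[%ld][%ld] over %ld batches\n", ie02, ie03, r2 * r3);
                }
            }
        }
        return 0;
    }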