Łukasz Ślusarczyk committed
Commit 2008e08 · 1 Parent(s): 03048ea

sycl: use oneDNN for matrices multiplication (llama/12972)

ggml/CMakeLists.txt CHANGED
@@ -193,6 +193,7 @@ option(GGML_RPC "ggml: use RPC"
 option(GGML_SYCL "ggml: use SYCL" OFF)
 option(GGML_SYCL_F16 "ggml: use 16 bit floats for sycl calculations" OFF)
 option(GGML_SYCL_GRAPH "ggml: enable graphs in the SYCL backend" ON)
+option(GGML_SYCL_DNN "ggml: enable oneDNN in the SYCL backend" ON)
 set (GGML_SYCL_TARGET "INTEL" CACHE STRING
     "ggml: sycl target device")
 set (GGML_SYCL_DEVICE_ARCH "" CACHE STRING
ggml/src/ggml-sycl/CMakeLists.txt CHANGED
@@ -49,34 +49,38 @@ endif()
 target_compile_options(ggml-sycl PRIVATE "-Wno-narrowing")
 
 # Link against oneDNN
-find_package(DNNL)
 set(GGML_SYCL_DNNL 0)
-if(DNNL_FOUND)
-    if (NOT DEFINED DNNL_GPU_VENDOR)
-        # default to intel target
-        set(DNNL_GPU_VENDOR "INTEL")
-        if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
-            message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+if(GGML_SYCL_DNN)
+    find_package(DNNL)
+    if(DNNL_FOUND)
+        if (NOT DEFINED DNNL_GPU_VENDOR)
+            # default to intel target
+            set(DNNL_GPU_VENDOR "INTEL")
+            if(NOT "${GGML_SYCL_TARGET}" STREQUAL "INTEL")
+                message(WARNING "oneDNN builds bundled with oneapi release only support INTEL target")
+            endif()
         endif()
-    endif()
 
-    # Verify oneDNN was compiled for the same target as llama
-    if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
-        target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
-        set(GGML_SYCL_DNNL 1)
-        get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
-        foreach(CONFIG ${CONFIGS})
-            get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
-            message(STATUS "Found oneDNN: ${DNNL_LIB}")
-        endforeach()
+        # Verify oneDNN was compiled for the same target as llama
+        if("${GGML_SYCL_TARGET}" STREQUAL "${DNNL_GPU_VENDOR}")
+            target_link_libraries(ggml-sycl PRIVATE DNNL::dnnl)
+            set(GGML_SYCL_DNNL 1)
+            get_target_property(CONFIGS DNNL::dnnl IMPORTED_CONFIGURATIONS)
+            foreach(CONFIG ${CONFIGS})
+                get_target_property(DNNL_LIB DNNL::dnnl IMPORTED_LOCATION_${CONFIG})
+                message(STATUS "Found oneDNN: ${DNNL_LIB}")
+            endforeach()
+        else()
+            message(WARNING
+                "oneDNN must be compiled for the same target as llama.cpp.
+                llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
+                Disabling oneDNN support.")
+        endif()
     else()
-        message(WARNING
-            "oneDNN must be compiled for the same target as llama.cpp.
-            llama.cpp: ${GGML_SYCL_TARGET}, oneDNN: ${DNNL_GPU_VENDOR}.
-            Disabling oneDNN support.")
+        message(STATUS "oneDNN not found, disabling oneDNN support")
     endif()
 else()
-    message(STATUS "oneDNN not found, disabling oneDNN support")
+    message(STATUS "oneDNN support disabled by the user")
 endif()
 target_compile_definitions(ggml-sycl PRIVATE GGML_SYCL_DNNL=${GGML_SYCL_DNNL})
 
ggml/src/ggml-sycl/gemm.hpp CHANGED
@@ -32,16 +32,36 @@ public:
         else static_assert(0);
     }
 
-    static inline void row_gemm(ggml_backend_sycl_context & ctx, bool a_trans, bool b_trans, int m, int n, int k,
-            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+    // matrix A has m rows, k columns
+    // matrix B has k rows, n columns
+    // nra - number of elements to skip when moving into next row in A
+    // nrb - number of elements to skip when moving into next row in B
+    // nca - number of elements to skip when moving into next column in A
+    // ncb - number of elements to skip when moving into next column in B
+    // stride_a - number of elements to skip when moving to next A matrix
+    // stride_b - number of elements to skip when moving to next B matrix
+    // batches_a - number of A matrices
+    // batches_b - number of B matrices
+    static void gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, dnnl_dim_t nra, dnnl_dim_t nca, dnnl_dim_t stride_a,
+            const void * b, dt bt, dnnl_dim_t nrb, dnnl_dim_t ncb, dnnl_dim_t stride_b,
+            void * c, dt ct, const queue_ptr & q, dnnl_dim_t batches_a, dnnl_dim_t batches_b) {
+
         auto stream = ctx.stream_dnnl(q);
         auto eng = ctx.engine_dnnl(q);
-        dnnl::memory::dims a_dims = { m, k };
-        dnnl::memory::dims b_dims = { k, n };
-        dnnl::memory::dims c_dims = { m, n };
-        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_trans ? tag::ba : tag::ab);
-        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_trans ? tag::ba : tag::ab);
-        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::ab);
+
+        // { # strides, # rows, # columns }
+        dnnl::memory::dims a_dims = { batches_a, m, k };
+        dnnl::memory::dims b_dims = { batches_b, k, n };
+        dnnl::memory::dims c_dims = { std::max(batches_a, batches_b), m, n };
+
+        // { # elements to skip to next stride, # elements to skip to next row, # elements to skip to next column }
+        dnnl::memory::dims a_strides = { stride_a, nra, nca };
+        dnnl::memory::dims b_strides = { stride_b, nrb, ncb };
+
+        const auto a_in_md = dnnl::memory::desc(a_dims, at, a_strides);
+        const auto b_in_md = dnnl::memory::desc(b_dims, bt, b_strides);
+        const auto c_md = dnnl::memory::desc(c_dims, ct, tag::abc);
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
@@ -63,6 +83,15 @@ public:
 
         matmul_prim.execute(stream, matmul_args);
     }
+
+    // matrices A and B are column major, both having k rows
+    // matrix A has m columns, matrix B has n columns
+    // output: column major matrix C = A transposed * B
+    static void row_gemm(ggml_backend_sycl_context & ctx, int m, int n, int k,
+            const void * a, dt at, const void * b, dt bt, void * c, dt ct, const queue_ptr & q) {
+
+        gemm(ctx, m, n, k, a, at, k, 1, k * m, b, bt, 1, k, n * k, c, ct, q, 1, 1);
+    }
 };
 
 #endif
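The comment block on the new gemm() describes the operands purely through dims and per-dimension strides instead of fixed layout tags. As a reading aid, here is a hypothetical CPU-only sketch (not part of the patch) of that indexing convention: element (b, i, j) of a logical {batches, rows, cols} operand is read at offset b*stride + i*nr + j*nc, and row_gemm() maps its two column-major inputs onto it with nra=k, nca=1 for A and nrb=1, ncb=k for B. The helper name ref_batched_gemm and the modulo-based batch broadcast are illustrative assumptions, not code from this commit.

#include <cstdio>
#include <vector>

// Reference semantics only: C[b] = A[b % batches_a] * B[b % batches_b],
// where A is read as m x k and B as k x n through the given strides.
static void ref_batched_gemm(int m, int n, int k,
                             const float * a, long nra, long nca, long stride_a, long batches_a,
                             const float * b, long nrb, long ncb, long stride_b, long batches_b,
                             float * c) {
    long batches = batches_a > batches_b ? batches_a : batches_b;
    for (long ib = 0; ib < batches; ++ib) {
        const float * ab = a + (ib % batches_a) * stride_a;   // pick the A matrix for this batch
        const float * bb = b + (ib % batches_b) * stride_b;   // pick the B matrix for this batch
        float * cb = c + ib * (long) m * n;                   // C is dense per batch (tag::abc)
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float acc = 0.0f;
                for (int p = 0; p < k; ++p) {
                    // A(i,p) at i*nra + p*nca, B(p,j) at p*nrb + j*ncb
                    acc += ab[i * nra + p * nca] * bb[p * nrb + j * ncb];
                }
                cb[i * (long) n + j] = acc;
            }
        }
    }
}

int main() {
    // row_gemm(m, n, k, ...) passes nra=k, nca=1, stride_a=k*m for A and
    // nrb=1, ncb=k, stride_b=n*k for B: both buffers are column major with
    // k rows, and the product computed is A transposed times B.
    const int m = 2, n = 3, k = 4;
    std::vector<float> A(k * m), B(k * n), C(m * n);
    for (size_t i = 0; i < A.size(); ++i) A[i] = (float) i;   // A's first column is 0,1,2,3
    for (size_t i = 0; i < B.size(); ++i) B[i] = 1.0f;
    ref_batched_gemm(m, n, k, A.data(), k, 1, k * m, 1, B.data(), 1, k, n * k, 1, C.data());
    std::printf("C(0,0) = %g\n", C[0]);   // 0+1+2+3 = 6
    return 0;
}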
ggml/src/ggml-sycl/ggml-sycl.cpp CHANGED
@@ -49,6 +49,7 @@ static bool g_sycl_loaded = false;
 int g_ggml_sycl_debug = 0;
 int g_ggml_sycl_disable_optimize = 0;
 int g_ggml_sycl_disable_graph = 0;
+int g_ggml_sycl_disable_dnn = 0;
 int g_ggml_sycl_prioritize_dmmv = 0;
 
 static ggml_sycl_device_info ggml_sycl_init() {
@@ -196,12 +197,22 @@ static void ggml_check_sycl() try {
     g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
     g_ggml_sycl_disable_optimize= get_sycl_env("GGML_SYCL_DISABLE_OPT", 1);
     g_ggml_sycl_disable_graph = get_sycl_env("GGML_SYCL_DISABLE_GRAPH", 1);
+    g_ggml_sycl_disable_dnn = get_sycl_env("GGML_SYCL_DISABLE_DNN", 0);
     g_ggml_sycl_prioritize_dmmv = get_sycl_env("GGML_SYCL_PRIORITIZE_DMMV", 0);
     GGML_SYCL_DEBUG("[SYCL] call ggml_check_sycl\n");
     GGML_LOG_INFO("Running with Environment Variables:\n");
     GGML_LOG_INFO(" GGML_SYCL_DEBUG: %d\n", g_ggml_sycl_debug);
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_OPT: %d\n", g_ggml_sycl_disable_optimize);
+#ifdef GGML_SYCL_GRAPH
     GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: %d\n", g_ggml_sycl_disable_graph);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_GRAPH: graph disabled by compile flag\n");
+#endif
+#if GGML_SYCL_DNNL
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: %d\n", g_ggml_sycl_disable_dnn);
+#else
+    GGML_LOG_INFO(" GGML_SYCL_DISABLE_DNN: DNN disabled by compile flag\n");
+#endif
     GGML_LOG_INFO(" GGML_SYCL_PRIORITIZE_DMMV: %d\n", g_ggml_sycl_prioritize_dmmv);
     GGML_LOG_INFO("Build with Macros:\n");
 #if defined(GGML_SYCL_FORCE_MMQ)
@@ -1985,19 +1996,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne10 = src1->ne[0];
-
+    GGML_ASSERT(ne00 == ne10);
 
     const int64_t row_diff = row_high - row_low;
 
     int id;
     SYCL_CHECK(
         CHECK_TRY_ERROR(id = get_current_device_id()));
-#if !GGML_SYCL_DNNL
-    const int64_t ne0 = dst->ne[0];
+
+    const int64_t ne0 = dst->ne[0]; // used by MKL only
     // the main device has a larger memory buffer to hold the results from all GPUs
     // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
-#endif
+    int ldc = id == ctx.device ? ne0 : row_diff; // used by MKL only
 
 #ifdef GGML_SYCL_F16
     bool use_fp16 = true; // TODO(Yu) SYCL capability check
@@ -2033,25 +2043,29 @@ inline void ggml_sycl_op_mul_mat_sycl(
                                      : src1_as_f16.get();
         ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-#if !GGML_SYCL_DNNL
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ptr,
-            DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
+                DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+                dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+        }
+        else
 #endif
+        {
+            const sycl::half alpha_f16 = 1.0f;
+            const sycl::half beta_f16 = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+                *stream, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, row_diff, src1_ncols, ne10,
+                &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+                src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+                dst_f16.get(), dpct::library_data_t::real_half, ldc,
+                dpct::library_data_t::real_half)));
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
+            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+        }
     }
     else {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
@@ -2072,18 +2086,22 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
         const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#if !GGML_SYCL_DNNL
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
-            get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
-            src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
-            dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
-#else
-        DnnlGemmWrapper::row_gemm(ctx, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i,
-            DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
-            dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+#if GGML_SYCL_DNNL
+        if (!g_ggml_sycl_disable_dnn) {
+            DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ddf1_i,
+                DnnlGemmWrapper::to_dt<float>(), src0_ddf_i, DnnlGemmWrapper::to_dt<float>(),
+                dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
+        }
+        else
 #endif
+        {
+            const float alpha = 1.0f;
+            const float beta = 0.0f;
+            SYCL_CHECK(CHECK_TRY_ERROR(oneapi::math::blas::column_major::gemm(
+                get_onemath_backend(*stream), oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, row_diff,
+                src1_ncols, ne10, dpct::get_value(&alpha, *stream), src0_ddf_i, ne00, src1_ddf1_i, ne10,
+                dpct::get_value(&beta, *stream), dst_dd_i, ldc)));
+        }
     }
     GGML_UNUSED(dst);
     GGML_UNUSED(src1_ddq_i);
@@ -2697,7 +2715,7 @@ catch (sycl::exception const &exc) {
     std::exit(1);
 }
 
-static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, char * dst,
+static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::half * src1_as_f16, void * dst,
                                    const void ** ptrs_src, void ** ptrs_dst, int64_t ne12, int64_t ne13, int64_t ne23,
                                    size_t nb02, size_t nb03, size_t nb12, size_t nb13, size_t nbd2, size_t nbd3,
                                    int64_t r2, int64_t r3, const sycl::nd_item<3> & item_ct1) {
@@ -2713,7 +2731,7 @@ static void k_compute_batched_ptrs(const sycl::half * src0_as_f16, const sycl::h
 
     const uint8_t * src0_bytes = reinterpret_cast<const uint8_t *>(src0_as_f16);
     const uint8_t * src1_bytes = reinterpret_cast<const uint8_t *>(src1_as_f16);
-    uint8_t * dst_bytes = reinterpret_cast<uint8_t *>(dst);
+    uint8_t * dst_bytes = static_cast<uint8_t *>(dst);
 
     ptrs_src[0 * ne23 + i12 + i13 * ne12] = src0_bytes + i02 * nb02 + i03 * nb03;
     ptrs_src[1 * ne23 + i12 + i13 * ne12] = src1_bytes + i12 * nb12 + i13 * nb13;
@@ -2726,6 +2744,7 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     GGML_ASSERT(!ggml_is_transposed(src1));
    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer));
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
@@ -2766,7 +2785,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
     }
 
     ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
-    char * dst_t = reinterpret_cast<char *>(dst_ddf);
 
     dpct::library_data_t mkl_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t mkl_data_type = dpct::library_data_t::real_float;
@@ -2783,42 +2801,83 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx, cons
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
+    GGML_ASSERT(ne01 == static_cast<int64_t>(nb1/nb0));
+    GGML_ASSERT(ne10 == ne00);
 
     // broadcast factors
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
-            oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
-            src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_t,
-            mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
-    } else {
-        const int ne23 = ne12 * ne13;
-
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
-        ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
-        ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
-
-        sycl::range<3> block_dims(1, ne12, ne13);
-        queue->submit([&](sycl::handler & cgh) {
-            const void ** ptrs_src_get = ptrs_src.get();
-            void ** ptrs_dst_get = ptrs_dst.get();
-            size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
-            size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
-            cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
-                k_compute_batched_ptrs(src0_f16, src1_f16, dst_t, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
-                    nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+#if GGML_SYCL_DNNL
+    if (!g_ggml_sycl_disable_dnn) {
+        auto dnn_gemm = [&ctx, queue, ne11, ne01, ne10, nb00, nb01, nb02, s11, s12]
+            (const sycl::half* src1, const sycl::half* src0, float* dst, const dnnl_dim_t batches_a, const dnnl_dim_t batches_b) {
+
+            DnnlGemmWrapper::gemm(ctx, ne11,ne01, ne10,
+                src1, DnnlGemmWrapper::to_dt<sycl::half>(), s11, 1, s12,
+                src0, DnnlGemmWrapper::to_dt<sycl::half>(), 1, nb01/nb00, nb02/nb00,
+                dst, DnnlGemmWrapper::to_dt<float>(), queue, batches_a, batches_b);
+        };
+
+        if (r2 == 1 && r3 == 1) {
+            if (ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+                dnn_gemm(src1_f16, src0_f16, dst_ddf, ne12*ne13, ne02 * ne03);
+            }
+            else {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie03*nb03)/sizeof(sycl::half)); // nb is in bytes
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie03*s13;
+                    float* dst_shifted = dst_ddf + ((ie03*nb3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, ne12, ne02);
+                }
+            }
+        } else {
+            // iterate over batches from smaller set of matrices (matrix 0)
+            for (int64_t ie02 = 0; ie02 < ne02; ++ie02) {
+                for (int64_t ie03 = 0; ie03 < ne03; ++ie03) {
+                    const sycl::half* src0_f16_shifted = src0_f16 + ((ie02*nb02 + ie03*nb03)/sizeof(sycl::half));
+                    const sycl::half* src1_f16_shifted = src1_f16 + ie02*s12*r2 + ie03*s13*r3;
+                    float* dst_shifted = dst_ddf + ((ie02*nb2*r2 + ie03*nb3*r3)/sizeof(float));
+                    dnn_gemm(src1_f16_shifted, src0_f16_shifted, dst_shifted, r2*r3, 1);
+                }
+            }
+        }
+    }
+    else
+#endif
+    {
+        if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+            // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(*queue, oneapi::math::transpose::trans,
+                oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                src0_f16, dpct::library_data_t::real_half, nb01 / nb00, nb02 / nb00,
+                src1_f16, dpct::library_data_t::real_half, s11, s12, beta, dst_ddf,
+                mkl_data_type, ne0, ne1 * ne0, ne12 * ne13, mkl_compute_type)));
+        } else {
+            const int ne23 = ne12 * ne13;
+
+            ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2 * ne23);
+            ggml_sycl_pool_alloc<void *> ptrs_dst(ctx.pool(), 1 * ne23);
+            ggml_sycl_pool_alloc<matrix_info_t<float>> matrix_info(ctx.host_pool(), 1);
+
+            sycl::range<3> block_dims(1, ne12, ne13);
+            queue->submit([&](sycl::handler & cgh) {
+                const void ** ptrs_src_get = ptrs_src.get();
+                void ** ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : s12 * sizeof(sycl::half);
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : s13 * sizeof(sycl::half);
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims), [=](sycl::nd_item<3> item_ct1) {
+                    k_compute_batched_ptrs(src0_f16, src1_f16, dst_ddf, ptrs_src_get, ptrs_dst_get, ne12, ne13, ne23, nb02,
+                        nb03, nb12_scaled, nb13_scaled, nbd2, nbd3, r2, r3, item_ct1);
+                });
             });
-        });
 
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
-            (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+            SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+                *queue, oneapi::math::transpose::trans, oneapi::math::transpose::nontrans, ne01, ne11, ne10, alpha,
+                (const void **) (ptrs_src.get() + 0 * ne23), dpct::library_data_t::real_half, nb01 / nb00,
+                (const void **) (ptrs_src.get() + 1 * ne23), dpct::library_data_t::real_half, s11, beta,
+                (void **) (ptrs_dst.get() + 0 * ne23), mkl_data_type, ne0, ne23, mkl_compute_type, matrix_info.get())));
+        }
     }
 } catch (const sycl::exception & exc) {
     std::cerr << exc.what() << "Exception caught at file:" << __FILE__ << ", line:" << __LINE__ << std::endl;
@@ -3713,7 +3772,8 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
         return GGML_STATUS_SUCCESS;
     }
 
-    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()));
+    sycl_ex::command_graph model_sycl_graph(*(sycl_ctx->stream()), {sycl_ex::property::graph::assume_buffer_outlives_graph{}});
+
     model_sycl_graph.begin_recording(*(sycl_ctx->stream()));
     ggml_backend_sycl_graph_compute_impl(sycl_ctx, cgraph);
     model_sycl_graph.end_recording();
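The broadcast handling in the new oneDNN path of ggml_sycl_mul_mat_batched_sycl is easiest to see with concrete indices. Below is a hypothetical, host-only illustration (not code from this commit) of how the broadcast factors r2 = ne12/ne02 and r3 = ne13/ne03 pair each dst/src1 batch (i12, i13) with a src0 batch (i02, i03); under that pairing the DNN branch issues, per src0 matrix, one matmul whose src1 side carries r2*r3 batches while src0 is broadcast (batches_b = 1), and with r2 == r3 == 1 and contiguous dims 2/3 a single call over ne12*ne13 matrices suffices.

#include <cstdint>
#include <cstdio>

int main() {
    const int64_t ne02 = 2, ne03 = 1;   // src0 batch dims (illustrative sizes)
    const int64_t ne12 = 4, ne13 = 2;   // src1 / dst batch dims
    const int64_t r2 = ne12 / ne02;     // broadcast factor along dim 2 (= 2 here)
    const int64_t r3 = ne13 / ne03;     // broadcast factor along dim 3 (= 2 here)

    for (int64_t i13 = 0; i13 < ne13; ++i13) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            // integer division expresses the broadcast: r2*r3 consecutive
            // src1/dst batches reuse the same src0 matrix
            const int64_t i02 = i12 / r2;
            const int64_t i03 = i13 / r3;
            std::printf("dst batch (%lld,%lld) uses src0 batch (%lld,%lld)\n",
                        (long long) i12, (long long) i13, (long long) i02, (long long) i03);
        }
    }
    return 0;
}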