Diego Devesa, ggerganov, R0CKSTAR committed
Commit 3dc93f3 · 1 Parent(s): 1741306

ggml : build backends as libraries (llama/10256)


* ggml : build backends as libraries

---------

Signed-off-by: Xiaodong Ye <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: R0CKSTAR <[email protected]>

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. ggml/CMakeLists.txt +7 -3
  2. ggml/include/ggml-amx.h +5 -5
  3. ggml/include/ggml-backend.h +14 -0
  4. ggml/include/ggml-blas.h +4 -4
  5. ggml/include/ggml-cann.h +8 -8
  6. ggml/include/ggml-cpu.h +64 -40
  7. ggml/include/ggml-cuda.h +12 -12
  8. ggml/include/ggml-kompute.h +4 -4
  9. ggml/include/ggml-metal.h +8 -8
  10. ggml/include/ggml-rpc.h +7 -7
  11. ggml/include/ggml-sycl.h +13 -13
  12. ggml/include/ggml-vulkan.h +9 -9
  13. ggml/include/ggml.h +5 -38
  14. ggml/src/ggml-aarch64.c +0 -0
  15. ggml/src/ggml-aarch64.h +0 -20
  16. ggml/src/ggml-amx/CMakeLists.txt +107 -0
  17. ggml/src/ggml-amx/common.h +2 -1
  18. ggml/src/ggml-amx/ggml-amx.cpp +449 -0
  19. ggml/src/ggml-amx/mmq.cpp +4 -3
  20. ggml/src/ggml-backend-reg.cpp +195 -0
  21. ggml/src/ggml-blas/CMakeLists.txt +91 -0
  22. ggml/src/ggml-blas/ggml-blas.cpp +514 -0
  23. ggml/src/ggml-cann/CMakeLists.txt +46 -0
  24. ggml/src/ggml-cann/ggml-cann.cpp +2128 -0
  25. ggml/src/ggml-cpu/CMakeLists.txt +244 -0
  26. ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
  27. ggml/src/ggml-cpu/ggml-cpu-aarch64.c +0 -0
  28. ggml/src/ggml-cpu/ggml-cpu-aarch64.h +27 -0
  29. ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  30. ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -0
  31. ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  32. ggml/src/ggml-cpu/ggml-cpu.c +0 -0
  33. ggml/src/ggml-cpu/ggml-cpu.cpp +575 -0
  34. ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
  35. ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
  36. ggml/src/ggml-cuda/common.cuh +25 -25
  37. ggml/src/ggml-cuda/fattn-common.cuh +2 -2
  38. ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
  39. ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
  40. ggml/src/ggml-cuda/fattn-vec-f16.cuh +2 -2
  41. ggml/src/ggml-cuda/fattn-vec-f32.cuh +2 -2
  42. ggml/src/ggml-cuda/fattn-wmma-f16.cuh +2 -2
  43. ggml/src/ggml-cuda/ggml-cuda.cu +0 -0
  44. ggml/src/ggml-cuda/CMakeLists.txt +165 -0
  45. ggml/src/ggml-cuda/mmq.cuh +11 -11
  46. ggml/src/ggml-cuda/mmvq.cu +4 -4
  47. ggml/src/ggml-cuda/sum.cu +2 -2
  48. ggml/src/ggml-hip/CMakeLists.txt +113 -0
  49. ggml/src/ggml-impl.h +267 -13
  50. ggml/src/ggml-kompute/CMakeLists.txt +162 -0
ggml/CMakeLists.txt CHANGED
@@ -116,6 +116,7 @@ endif()
 
 # ggml core
 set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
+option(GGML_CPU "ggml: enable CPU backend" ON)
 
 # 3rd party libs / backends
 option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
@@ -141,7 +142,7 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
 option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
 option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
 
-option(GGML_HIPBLAS "ggml: use hipBLAS" OFF)
+option(GGML_HIP "ggml: use HIP" OFF)
 option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
 option(GGML_VULKAN "ggml: use Vulkan" OFF)
 option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
@@ -238,12 +239,15 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
 install(TARGETS ggml PUBLIC_HEADER)
 
 if (BUILD_SHARED_LIBS)
-    install(TARGETS ggml LIBRARY)
+    install(TARGETS ggml LIBRARY)
+    install(TARGETS ggml-base LIBRARY)
 endif()
 
+# FIXME: this should be done in the backend cmake files
 if (GGML_METAL)
+    # FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
     install(
-        FILES src/ggml-metal.metal
+        FILES ggml/src/ggml-metal/ggml-metal.metal
         PERMISSIONS
             OWNER_READ
             OWNER_WRITE

ggml/include/ggml-amx.h CHANGED
@@ -9,16 +9,16 @@ extern "C" {
 #endif
 
 // buffer_type API
-GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
 
-GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_amx(ggml_backend_t backend);
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_amx_init(void);
+GGML_BACKEND_API ggml_backend_t ggml_backend_amx_init(void);
 
-GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
+GGML_BACKEND_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
 
-GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_amx_reg(void);
 
 #ifdef __cplusplus
 }

ggml/include/ggml-backend.h CHANGED
@@ -3,6 +3,20 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 
+#ifdef GGML_BACKEND_SHARED
+#    if defined(_WIN32) && !defined(__MINGW32__)
+#        ifdef GGML_BACKEND_BUILD
+#            define GGML_BACKEND_API __declspec(dllexport) extern
+#        else
+#            define GGML_BACKEND_API __declspec(dllimport) extern
+#        endif
+#    else
+#        define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
+#    endif
+#else
+#    define GGML_BACKEND_API extern
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
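
The macro above is what allows each backend to be compiled as its own shared library: the backend itself is built with GGML_BACKEND_SHARED and GGML_BACKEND_BUILD defined so its entry points are exported, while code that merely includes the header sees the import (or plain extern) form. Below is a minimal sketch of a backend header written against this pattern; the "example" backend and its functions are purely illustrative and not part of this commit.

    // hypothetical backend header built on the new export macro (illustrative only)
    // the backend library itself would be compiled with -DGGML_BACKEND_SHARED -DGGML_BACKEND_BUILD
    #include "ggml-backend.h"

    #ifdef __cplusplus
    extern "C" {
    #endif

    // exported from the backend's shared library, imported (or plain extern) everywhere else
    GGML_BACKEND_API ggml_backend_reg_t ggml_backend_example_reg(void);
    GGML_BACKEND_API ggml_backend_t     ggml_backend_example_init(void);

    #ifdef __cplusplus
    }
    #endif
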
ggml/include/ggml-blas.h CHANGED
@@ -9,15 +9,15 @@ extern "C" {
 #endif
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_blas_init(void);
+GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
 
-GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
 
 // number of threads used for conversion to float
 // for openblas and blis, this will also set the number of threads used for blas operations
-GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
+GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
 
-GGML_API ggml_backend_reg_t ggml_backend_blas_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
 
 
 #ifdef __cplusplus

ggml/include/ggml-cann.h CHANGED
@@ -34,7 +34,7 @@ extern "C" {
  */
 #define GGML_CANN_MAX_DEVICES 16
 
-GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
 
 /**
  * @brief Initializes the CANN backend for a specified device.
@@ -46,7 +46,7 @@ GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
  * @param device The index of the device to initialize.
  * @return A pointer to the initialized backend instance, or nullptr on failure.
  */
-GGML_API ggml_backend_t ggml_backend_cann_init(int32_t device);
+GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
 
 /**
  * @brief Checks if a given backend is a CANN backend.
@@ -57,7 +57,7 @@ GGML_API ggml_backend_t ggml_backend_cann_init(int32_t device);
  * @param backend The backend instance to check.
  * @return True if the backend is a CANN backend, false otherwise.
  */
-GGML_API bool ggml_backend_is_cann(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
 
 /**
  * @brief Retrieves the CANN buffer type for a specified device.
@@ -69,7 +69,7 @@ GGML_API bool ggml_backend_is_cann(ggml_backend_t backend);
  * @return A pointer to the buffer type interface for the specified device, or
  *         nullptr if the device index is out of range.
  */
-GGML_API ggml_backend_buffer_type_t
+GGML_BACKEND_API ggml_backend_buffer_type_t
 ggml_backend_cann_buffer_type(int32_t device);
 
 /**
@@ -80,14 +80,14 @@ ggml_backend_cann_buffer_type(int32_t device);
  *
 * @return The number of CANN devices available.
 */
-GGML_API int32_t ggml_backend_cann_get_device_count(void);
+GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
 
 /**
  * @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
  *
  * @return A pointer to the host buffer type interface.
  */
-GGML_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
 
 /**
  * @brief Retrieves the description of a specific CANN device.
@@ -99,7 +99,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
  * @param description Pointer to a buffer where the description will be written.
  * @param description_size Size of the description buffer.
  */
-GGML_API void ggml_backend_cann_get_device_description(
+GGML_BACKEND_API void ggml_backend_cann_get_device_description(
     int32_t device, char* description, size_t description_size);
 
 /**
@@ -114,7 +114,7 @@ GGML_API void ggml_backend_cann_get_device_description(
  * @param total Pointer to a variable where the total memory size will be
  *              stored.
  */
-GGML_API void ggml_backend_cann_get_device_memory(int32_t device,
+GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
                                                    size_t* free,
                                                    size_t* total);
 

ggml/include/ggml-cpu.h CHANGED
@@ -54,54 +54,77 @@ extern "C" {
     GGML_NUMA_STRATEGY_COUNT
 };
 
-GGML_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
-GGML_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
+GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
+GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
 
-GGML_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
-GGML_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
+GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
+GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
 
-GGML_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
-GGML_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
+GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
+GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
 
-GGML_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
-GGML_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
+GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
+GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
 
-GGML_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
-GGML_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
+GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
 
-GGML_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
-GGML_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
+GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
+GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
 
-GGML_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
-GGML_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
+GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
+GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
 
-GGML_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
-GGML_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
-GGML_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
-GGML_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
-GGML_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
-GGML_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool);
-GGML_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
-GGML_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
+GGML_BACKEND_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
+GGML_BACKEND_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
+GGML_BACKEND_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
+GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
+GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
+GGML_BACKEND_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool);
+GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
+GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
 
 // ggml_graph_plan() has to be called before ggml_graph_compute()
 // when plan.work_size > 0, caller must allocate memory for plan.work_data
-GGML_API struct ggml_cplan ggml_graph_plan(
+GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
     const struct ggml_cgraph * cgraph,
     int n_threads, /* = GGML_DEFAULT_N_THREADS */
     struct ggml_threadpool * threadpool /* = NULL */ );
-GGML_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
+GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
 
 // same as ggml_graph_compute() but the work data is allocated as a part of the context
 // note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
-GGML_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
+GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
 
-// TODO: move to backend interface
-GGML_API int ggml_cpu_has_neon (void);
-GGML_API int ggml_cpu_has_sve (void);
-GGML_API int ggml_cpu_has_matmul_int8(void);
-// get the sve vector length in bytes
-GGML_API int ggml_cpu_get_sve_cnt(void);
+//
+// system info
+//
+
+// x86
+GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
+GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
+GGML_BACKEND_API int ggml_cpu_has_avx (void);
+GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
+GGML_BACKEND_API int ggml_cpu_has_f16c (void);
+GGML_BACKEND_API int ggml_cpu_has_fma (void);
+GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
+GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
+GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
+GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
+GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
+GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
+// ARM
+GGML_BACKEND_API int ggml_cpu_has_neon (void);
+GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
+GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
+GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
+GGML_BACKEND_API int ggml_cpu_has_sve (void);
+GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
+// other
+GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
+GGML_BACKEND_API int ggml_cpu_has_vsx (void);
+GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
+GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
 
 // Internal types and functions exposed for tests and benchmarks
 
@@ -115,6 +138,7 @@ extern "C" {
     const void * GGML_RESTRICT y, int nr, int nc);
 
 struct ggml_type_traits_cpu {
+    ggml_from_float_t from_float;
     ggml_from_float_to_mat_t from_float_to_mat;
     ggml_vec_dot_t vec_dot;
     enum ggml_type vec_dot_type;
@@ -124,25 +148,25 @@ extern "C" {
     ggml_gemm_t gemm;
 };
 
-GGML_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
+GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
 
-GGML_API void ggml_cpu_init(void);
+GGML_BACKEND_API void ggml_cpu_init(void);
 
 //
 // CPU backend
 //
 
-GGML_API ggml_backend_t ggml_backend_cpu_init(void);
+GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
 
-GGML_API bool ggml_backend_is_cpu (ggml_backend_t backend);
-GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
-GGML_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
-GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
+GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
+GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
+GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
+GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
 
-GGML_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
 #ifdef GGML_USE_CPU_HBM
-GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
 #endif
 
 #ifdef __cplusplus
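
Since the system-info queries and the CPU backend entry points now live in ggml-cpu.h, application code is expected to include that header directly. A minimal sketch of a caller against the declarations above; the thread count and the printed features are arbitrary choices of this sketch, not something the commit prescribes.

    // sketch: query CPU features and run the CPU backend through the new header
    #include <stdio.h>
    #include "ggml-backend.h"
    #include "ggml-cpu.h"

    int main(void) {
        printf("AVX2: %d  NEON: %d  llamafile: %d\n",
               ggml_cpu_has_avx2(), ggml_cpu_has_neon(), ggml_cpu_has_llamafile());

        ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_backend_cpu_set_n_threads(backend, 4);
        // ... allocate tensors, build a ggml_cgraph and call ggml_backend_graph_compute() here ...
        ggml_backend_free(backend);
        return 0;
    }
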
ggml/include/ggml-cuda.h CHANGED
@@ -7,7 +7,7 @@
 extern "C" {
 #endif
 
-#ifdef GGML_USE_HIPBLAS
+#ifdef GGML_USE_HIP
 #define GGML_CUDA_NAME "ROCm"
 #define GGML_CUBLAS_NAME "hipBLAS"
 #elif defined(GGML_USE_MUSA)
@@ -20,27 +20,27 @@ extern "C" {
 #define GGML_CUDA_MAX_DEVICES 16
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
+GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
 
-GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
 
 // device buffer
-GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
 
 // split tensor buffer that splits matrices by rows across multiple devices
-GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
 
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
-GGML_API int ggml_backend_cuda_get_device_count(void);
-GGML_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
-GGML_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
+GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
+GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
 
-GGML_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
-GGML_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
+GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
+GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
 
-GGML_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
 
 #ifdef __cplusplus
 }
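
With the CUDA backend exporting its API through GGML_BACKEND_API, an application links against the CUDA backend library and calls the same functions as before. A minimal device-enumeration sketch against the declarations above; the buffer size and output format are arbitrary choices of this sketch.

    // sketch: enumerate CUDA devices through the exported backend API
    #include <stdio.h>
    #include "ggml-cuda.h"

    int main(void) {
        const int n = ggml_backend_cuda_get_device_count();
        for (int i = 0; i < n; i++) {
            char   desc[128];
            size_t free_mem = 0, total_mem = 0;
            ggml_backend_cuda_get_device_description(i, desc, sizeof(desc));
            ggml_backend_cuda_get_device_memory(i, &free_mem, &total_mem);
            printf("device %d: %s (%zu / %zu bytes free)\n", i, desc, free_mem, total_mem);
        }
        return 0;
    }
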
ggml/include/ggml-kompute.h CHANGED
@@ -37,13 +37,13 @@ struct ggml_vk_device ggml_vk_current_device(void);
 // forward declaration
 typedef struct ggml_backend * ggml_backend_t;
 
-GGML_API ggml_backend_t ggml_backend_kompute_init(int device);
+GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
 
-GGML_API bool ggml_backend_is_kompute(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
 
-GGML_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
 
-GGML_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
 
 #ifdef __cplusplus
 }

ggml/include/ggml-metal.h CHANGED
@@ -39,27 +39,27 @@ extern "C" {
 // user-code should use only these functions
 //
 
-GGML_API ggml_backend_t ggml_backend_metal_init(void);
+GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
 
-GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
 GGML_DEPRECATED(
-        GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
         "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
 
-GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
+GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
-GGML_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
 
 // helper to check if the device supports a specific family
 // ideally, the user code should be doing these checks
 // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
-GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
+GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
 
 // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
-GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
+GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
 
-GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
 
 #ifdef __cplusplus
 }

ggml/include/ggml-rpc.h CHANGED
@@ -10,18 +10,18 @@ extern "C" {
 #define GGML_RPC_MAX_SERVERS 16
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
-GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend);
+GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
 
-GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
 
-GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
 
-GGML_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
 
-GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
 
-GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
 
 #ifdef __cplusplus
 }
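
A minimal sketch of how these declarations combine with a local backend to serve computation over RPC; the endpoint string and the advertised memory sizes are assumptions of the sketch, and the server call is expected to block while serving requests.

    // sketch: expose a locally initialized CPU backend over RPC
    #include "ggml-backend.h"
    #include "ggml-cpu.h"
    #include "ggml-rpc.h"

    int main(void) {
        ggml_backend_t backend = ggml_backend_cpu_init();

        const size_t free_mem  = (size_t) 8 * 1024 * 1024 * 1024; // advertised to clients, illustrative
        const size_t total_mem = (size_t) 8 * 1024 * 1024 * 1024;

        // serves graph computation requests for this backend on the given endpoint
        ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", free_mem, total_mem);

        ggml_backend_free(backend);
        return 0;
    }
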
ggml/include/ggml-sycl.h CHANGED
@@ -17,32 +17,32 @@ extern "C" {
 #endif
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
+GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);
 
-GGML_API bool ggml_backend_is_sycl(ggml_backend_t backend);
+GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);
 
 // devide buffer
-GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
 
 // split tensor buffer that splits matrices by rows across multiple devices
-GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
 
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
 
-GGML_API void ggml_backend_sycl_print_sycl_devices(void);
-GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
-GGML_API void ggml_backend_sycl_get_device_description(int device,
+GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);
+GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,
                                                         char *description,
                                                         size_t description_size);
-GGML_API int ggml_backend_sycl_get_device_count();
-GGML_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
+GGML_BACKEND_API int ggml_backend_sycl_get_device_count();
+GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
 
 // SYCL doesn't support registering host memory, keep here for reference
-// GGML_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
-// GGML_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
+// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
+// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
 
-GGML_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
 
 #ifdef __cplusplus
 }

ggml/include/ggml-vulkan.h CHANGED
@@ -10,21 +10,21 @@ extern "C" {
 #define GGML_VK_NAME "Vulkan"
 #define GGML_VK_MAX_DEVICES 16
 
-GGML_API void ggml_vk_instance_init(void);
+GGML_BACKEND_API void ggml_vk_instance_init(void);
 
 // backend API
-GGML_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
+GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
 
-GGML_API bool ggml_backend_is_vk(ggml_backend_t backend);
-GGML_API int ggml_backend_vk_get_device_count(void);
-GGML_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
-GGML_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
+GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);
+GGML_BACKEND_API int ggml_backend_vk_get_device_count(void);
+GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
+GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
 
-GGML_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
-GGML_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
+GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
 
-GGML_API ggml_backend_reg_t ggml_backend_vk_reg(void);
+GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);
 
 #ifdef __cplusplus
 }

ggml/include/ggml.h CHANGED
@@ -176,15 +176,15 @@
 #ifdef GGML_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)
 #        ifdef GGML_BUILD
-#            define GGML_API __declspec(dllexport)
+#            define GGML_API __declspec(dllexport) extern
 #        else
-#            define GGML_API __declspec(dllimport)
+#            define GGML_API __declspec(dllimport) extern
 #        endif
 #    else
-#        define GGML_API __attribute__ ((visibility ("default")))
+#        define GGML_API __attribute__ ((visibility ("default"))) extern
 #    endif
 #else
-#    define GGML_API
+#    define GGML_API extern
 #endif
 
 // TODO: support for clang
@@ -1490,7 +1490,7 @@ extern "C" {
             "use ggml_rope_ext_inplace instead");
 
     // compute correction dims for YaRN RoPE scaling
-    void ggml_rope_yarn_corr_dims(
+    GGML_API void ggml_rope_yarn_corr_dims(
         int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
 
     // rotary position embedding backward, i.e compute dx from dy
@@ -2384,38 +2384,6 @@ extern "C" {
     GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
     GGML_API void   gguf_get_meta_data(const struct gguf_context * ctx, void * data);
 
-    //
-    // system info
-    //
-
-    GGML_API int ggml_cpu_has_avx (void);
-    GGML_API int ggml_cpu_has_avx_vnni (void);
-    GGML_API int ggml_cpu_has_avx2 (void);
-    GGML_API int ggml_cpu_has_avx512 (void);
-    GGML_API int ggml_cpu_has_avx512_vbmi(void);
-    GGML_API int ggml_cpu_has_avx512_vnni(void);
-    GGML_API int ggml_cpu_has_avx512_bf16(void);
-    GGML_API int ggml_cpu_has_amx_int8 (void);
-    GGML_API int ggml_cpu_has_fma (void);
-    GGML_API int ggml_cpu_has_arm_fma (void);
-    GGML_API int ggml_cpu_has_metal (void);
-    GGML_API int ggml_cpu_has_f16c (void);
-    GGML_API int ggml_cpu_has_fp16_va (void);
-    GGML_API int ggml_cpu_has_wasm_simd (void);
-    GGML_API int ggml_cpu_has_blas (void);
-    GGML_API int ggml_cpu_has_cuda (void);
-    GGML_API int ggml_cpu_has_vulkan (void);
-    GGML_API int ggml_cpu_has_kompute (void);
-    GGML_API int ggml_cpu_has_gpublas (void);
-    GGML_API int ggml_cpu_has_sse3 (void);
-    GGML_API int ggml_cpu_has_ssse3 (void);
-    GGML_API int ggml_cpu_has_riscv_v (void);
-    GGML_API int ggml_cpu_has_sycl (void);
-    GGML_API int ggml_cpu_has_rpc (void);
-    GGML_API int ggml_cpu_has_vsx (void);
-    GGML_API int ggml_cpu_has_cann (void);
-    GGML_API int ggml_cpu_has_llamafile (void);
-
 #ifdef __cplusplus
 // restrict not standard in C++
 #define GGML_RESTRICT
@@ -2432,7 +2400,6 @@ extern "C" {
     size_t type_size;
     bool is_quantized;
     ggml_to_float_t to_float;
-    ggml_from_float_t from_float;
    ggml_from_float_t from_float_ref;
 };
 
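Because the ggml_cpu_has_* queries were removed from ggml.h, existing callers need to switch to the CPU backend header. A minimal before/after sketch; the wrapper function is illustrative and not part of the commit.

    // previously the feature queries came with "ggml.h";
    // after this change they are declared in the CPU backend header
    #include "ggml-cpu.h"

    int device_supports_int8_matmul(void) {
        // ggml_cpu_has_matmul_int8() moved from ggml.h to ggml-cpu.h
        return ggml_cpu_has_matmul_int8();
    }
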
ggml/src/ggml-aarch64.c CHANGED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-aarch64.h CHANGED
@@ -1,9 +1,5 @@
-// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
 #pragma once
 
-#define GGML_COMMON_DECL_C
-#include "ggml-common.h"
-
 #include "ggml.h"
 
 // GGML internal header
@@ -12,27 +8,11 @@
 extern "C" {
 #endif
 
-// Quantization
-void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
-
-void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
-
 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
 
-// GEMV
-void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-
-// GEMM
-void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-amx/CMakeLists.txt ADDED
@@ -0,0 +1,107 @@
+if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
+    (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
+        CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$") AND
+    CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.0)
+    message(STATUS "Using AMX")
+
+    file(GLOB   GGML_HEADERS_AMX "*.h")
+    list(APPEND GGML_HEADERS_AMX "../../include/ggml-amx.h")
+
+    file(GLOB   GGML_SOURCES_AMX "*.cpp")
+
+    add_library(ggml-amx
+                ${GGML_HEADERS_AMX}
+                ${GGML_SOURCES_AMX})
+
+    target_link_libraries(ggml-amx PRIVATE ggml-base)
+    target_include_directories(ggml-amx PRIVATE . ..)
+
+    # this is duplicated from the CPU backend, since the AMX backend also depends on the architecture flags
+    # TODO: integrate AMX backend into the CPU backend
+    if (MSVC)
+        # instruction set detection for MSVC only
+        if (GGML_NATIVE)
+            # TODO: improve, should not reference files from the parent folder
+            include(../ggml-cpu/cmake/FindSIMD.cmake)
+        endif ()
+        if (GGML_AVX512)
+            list(APPEND ARCH_FLAGS /arch:AVX512)
+            # MSVC has no compile-time flags enabling specific
+            # AVX512 extensions, neither it defines the
+            # macros corresponding to the extensions.
+            # Do it manually.
+            if (GGML_AVX512_VBMI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
+            endif()
+            if (GGML_AVX512_VNNI)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
+            endif()
+            if (GGML_AVX512_BF16)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
+            endif()
+            if (GGML_AMX_TILE)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
+            endif()
+            if (GGML_AMX_INT8)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
+            endif()
+            if (GGML_AMX_BF16)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
+            endif()
+        elseif (GGML_AVX2)
+            list(APPEND ARCH_FLAGS /arch:AVX2)
+        elseif (GGML_AVX)
+            list(APPEND ARCH_FLAGS /arch:AVX)
+        endif()
+    else()
+        if (GGML_NATIVE)
+            list(APPEND ARCH_FLAGS -march=native)
+        endif()
+        if (GGML_F16C)
+            list(APPEND ARCH_FLAGS -mf16c)
+        endif()
+        if (GGML_FMA)
+            list(APPEND ARCH_FLAGS -mfma)
+        endif()
+        if (GGML_AVX)
+            list(APPEND ARCH_FLAGS -mavx)
+        endif()
+        if (GGML_AVX2)
+            list(APPEND ARCH_FLAGS -mavx2)
+        endif()
+        if (GGML_AVX512)
+            list(APPEND ARCH_FLAGS -mavx512f)
+            list(APPEND ARCH_FLAGS -mavx512dq)
+            list(APPEND ARCH_FLAGS -mavx512bw)
+        endif()
+        if (GGML_AVX512_VBMI)
+            list(APPEND ARCH_FLAGS -mavx512vbmi)
+        endif()
+        if (GGML_AVX512_VNNI)
+            list(APPEND ARCH_FLAGS -mavx512vnni)
+        endif()
+        if (GGML_AVX512_BF16)
+            list(APPEND ARCH_FLAGS -mavx512bf16)
+        endif()
+        if (GGML_AMX_TILE)
+            list(APPEND ARCH_FLAGS -mamx-tile)
+        endif()
+        if (GGML_AMX_INT8)
+            list(APPEND ARCH_FLAGS -mamx-int8)
+        endif()
+        if (GGML_AMX_BF16)
+            list(APPEND ARCH_FLAGS -mamx-bf16)
+        endif()
+    endif()
+
+    target_compile_options(ggml-amx PRIVATE ${ARCH_FLAGS})
+else()
+    set(GGML_AMX OFF PARENT_SCOPE)
+    message(WARNING "AMX requires x86 and gcc version > 11.0. Turning off GGML_AMX.")
+endif()

ggml/src/ggml-amx/common.h CHANGED
@@ -1,7 +1,8 @@
 #pragma once
 
 #include "ggml.h"
-#include "ggml-cpu-impl.h" // <immintrin.h>
+// hack until AMX is moved into the CPU backend
+#include "../ggml-cpu/ggml-cpu-impl.h" // <immintrin.h>
 
 #include <algorithm>
 #include <memory>

ggml/src/ggml-amx/ggml-amx.cpp ADDED
@@ -0,0 +1,449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include "ggml-amx.h"
2
+ #include "ggml-amx/common.h"
3
+ #include "ggml-amx/mmq.h"
4
+ #include "ggml-backend-impl.h"
5
+ #include "ggml-impl.h"
6
+
7
+ #if defined(__gnu_linux__)
8
+ #include <sys/syscall.h>
9
+ #include <unistd.h>
10
+ #endif
11
+
12
+ #include <cstdlib>
13
+ #include <cstring>
14
+ #include <memory>
15
+
16
+ #if defined(__AMX_INT8__)
17
+
18
+ // AMX buffer interface
19
+ static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
20
+ free(buffer->context);
21
+ }
22
+
23
+ static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
24
+ return (void *)(buffer->context);
25
+ }
26
+
27
+ static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
28
+ memset((char *)tensor->data + offset, value, size);
29
+
30
+ GGML_UNUSED(buffer);
31
+ }
32
+
33
+ static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
34
+ if (qtype_has_amx_kernels(tensor->type)) {
35
+ ggml_backend_amx_convert_weight(tensor, data, offset, size);
36
+ } else {
37
+ memcpy((char *)tensor->data + offset, data, size);
38
+ }
39
+
40
+ GGML_UNUSED(buffer);
41
+ }
42
+
43
+ static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
44
+ GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
45
+ memcpy(data, (const char *)tensor->data + offset, size);
46
+
47
+ GGML_UNUSED(buffer);
48
+ }
49
+
50
+ static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
51
+ if (ggml_backend_buffer_is_host(src->buffer)) {
52
+ if (qtype_has_amx_kernels(src->type)) {
53
+ ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
54
+ } else {
55
+ memcpy(dst->data, src->data, ggml_nbytes(src));
56
+ }
57
+ return true;
58
+ }
59
+ return false;
60
+
61
+ GGML_UNUSED(buffer);
62
+ }
63
+
64
+ static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
65
+ memset(buffer->context, value, buffer->size);
66
+ }
67
+
68
+ static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
69
+ /* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
70
+ /* .get_base = */ ggml_backend_amx_buffer_get_base,
71
+ /* .init_tensor = */ NULL, // no initialization required
72
+ /* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
73
+ /* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
74
+ /* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
75
+ /* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
76
+ /* .clear = */ ggml_backend_amx_buffer_clear,
77
+ /* .reset = */ NULL,
78
+ };
79
+
80
+ static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
81
+ return "AMX";
82
+
83
+ GGML_UNUSED(buft);
84
+ }
85
+
86
+ static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
87
+ void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
88
+ if (data == NULL) {
89
+ fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
90
+ return NULL;
91
+ }
92
+
93
+ return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
94
+ }
95
+
96
+ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
97
+ return TENSOR_ALIGNMENT;
98
+
99
+ GGML_UNUSED(buft);
100
+ }
101
+
102
+ static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
103
+ return ggml_backend_amx_get_alloc_size(tensor);
104
+
105
+ GGML_UNUSED(buft);
106
+ }
107
+
108
+ static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
109
+ return false;
110
+
111
+ GGML_UNUSED(buft);
112
+ }
113
+
114
+ ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
115
+ static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
116
+ /* .iface = */ {
117
+ /* .get_name = */ ggml_backend_amx_buffer_type_get_name,
118
+ /* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
119
+ /* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
120
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
121
+ /* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
122
+ /* .is_host = */ ggml_backend_amx_buffer_type_is_host,
123
+ },
124
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
125
+ /* .context = */ NULL,
126
+ };
127
+
128
+ return &ggml_backend_buffer_type_amx;
129
+ }
130
+
131
+ // backend interface
132
+
133
+ static const char * ggml_backend_amx_name(ggml_backend_t backend) {
134
+ return "AMX";
135
+
136
+ GGML_UNUSED(backend);
137
+ }
138
+
139
+ static void ggml_backend_amx_free(ggml_backend_t backend) {
140
+ ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
141
+ delete ctx;
142
+ delete backend;
143
+ }
144
+
145
+ static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
146
+ ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
147
+
148
+ for (int i = 0; i < cgraph->n_nodes; i++) {
149
+ struct ggml_tensor * node = cgraph->nodes[i];
150
+
151
+ switch (node->op) {
152
+ case GGML_OP_MUL_MAT:
153
+ ggml_backend_amx_mul_mat(ctx, node);
154
+ break;
155
+
156
+ case GGML_OP_NONE:
157
+ case GGML_OP_RESHAPE:
158
+ case GGML_OP_VIEW:
159
+ case GGML_OP_PERMUTE:
160
+ case GGML_OP_TRANSPOSE:
161
+ break;
162
+
163
+ default:
164
+ fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
165
+ GGML_ASSERT(false);
166
+ }
167
+ }
168
+
169
+ return GGML_STATUS_SUCCESS;
170
+
171
+ GGML_UNUSED(backend);
172
+ }
173
+
174
+ static struct ggml_backend_i ggml_backend_amx_i = {
175
+ /* .get_name = */ ggml_backend_amx_name,
176
+ /* .free = */ ggml_backend_amx_free,
177
+ /* .set_tensor_async = */ NULL,
178
+ /* .get_tensor_async = */ NULL,
179
+ /* .cpy_tensor_async = */ NULL,
180
+ /* .synchronize = */ NULL,
181
+ /* .graph_plan_create = */ NULL,
182
+ /* .graph_plan_free = */ NULL,
183
+ /* .graph_plan_update = */ NULL,
184
+ /* .graph_plan_compute = */ NULL,
185
+ /* .graph_compute = */ ggml_backend_amx_graph_compute,
186
+ /* .event_record = */ NULL,
187
+ /* .event_wait = */ NULL,
188
+ };
189
+
190
+ static ggml_guid_t ggml_backend_amx_guid() {
191
+ static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
192
+ return &guid;
193
+ }
194
+
195
+ #define ARCH_GET_XCOMP_PERM 0x1022
196
+ #define ARCH_REQ_XCOMP_PERM 0x1023
197
+ #define XFEATURE_XTILECFG 17
198
+ #define XFEATURE_XTILEDATA 18
199
+
200
+ static bool ggml_amx_init() {
201
+ #if defined(__gnu_linux__)
202
+ if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
203
+ fprintf(stderr, "AMX is not ready to be used!\n");
204
+ return false;
205
+ }
206
+ return true;
207
+ #elif defined(_WIN32)
208
+ return true;
209
+ #endif
210
+ }
211
+
212
+ ggml_backend_t ggml_backend_amx_init() {
213
+
214
+ // invoke a Linux system call to request access to AMX features
215
+ ggml_amx_init();
216
+
217
+ // backend context
218
+ ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
219
+
220
+ // ggml amx backend
221
+ ggml_backend_t backend = new ggml_backend {
222
+ /* .guid = */ ggml_backend_amx_guid(),
223
+ /* .interface = */ ggml_backend_amx_i,
224
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
225
+ /* .context = */ ctx,
226
+ };
227
+
228
+ return backend;
229
+ }
230
+
231
+ bool ggml_backend_is_amx(ggml_backend_t backend) {
232
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
233
+ }
234
+
235
+ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
236
+ GGML_ASSERT(ggml_backend_is_amx(backend_amx));
237
+
238
+ ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
239
+ ctx->n_threads = n_threads;
240
+ }
241
+
242
+ // device interface
243
+
244
+ static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
245
+ return "AMX";
246
+
247
+ GGML_UNUSED(dev);
248
+ }
249
+
250
+ static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
251
+ return "Intel Advanced Matrix Extensions";
252
+
253
+ GGML_UNUSED(dev);
254
+ }
255
+
256
+ static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
257
+ // TODO
258
+ *free = 0;
259
+ *total = 0;
260
+
261
+ GGML_UNUSED(dev);
262
+ }
263
+
264
+ static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
265
+ return GGML_BACKEND_DEVICE_TYPE_ACCEL;
266
+
267
+ GGML_UNUSED(dev);
268
+ }
269
+
270
+ static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
271
+ props->name = ggml_backend_amx_device_get_name(dev);
272
+ props->description = ggml_backend_amx_device_get_description(dev);
273
+ props->type = ggml_backend_amx_device_get_type(dev);
274
+ ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
275
+
276
+ // `buffer_from_host_ptr` is intended to be used with mmap, when the memory layout is unchanged
277
+ props->caps = {
278
+ /* .async = */ false,
279
+ /* .host_buffer = */ false,
280
+ /* .buffer_from_host_ptr = */ false,
281
+ /* .events = */ false,
282
+ };
283
+ }
284
+
285
+ static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
286
+ return ggml_backend_amx_init();
287
+
288
+ GGML_UNUSED(dev);
289
+ GGML_UNUSED(params);
290
+ }
291
+
292
+ static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
293
+ return ggml_backend_amx_buffer_type();
294
+
295
+ GGML_UNUSED(dev);
296
+ }
297
+
298
+ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
299
+
300
+ // handle only 2d gemm for now
301
+ auto is_contiguous_2d = [](const struct ggml_tensor * t) {
302
+ return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
303
+ };
304
+
305
+ switch (op->op) {
306
+ case GGML_OP_NONE:
307
+ case GGML_OP_RESHAPE:
308
+ case GGML_OP_VIEW:
309
+ case GGML_OP_PERMUTE:
310
+ case GGML_OP_TRANSPOSE:
311
+ return true;
312
+
313
+ case GGML_OP_MUL_MAT: {
314
+ const struct ggml_tensor * src0 = op->src[0];
315
+ const struct ggml_tensor * src1 = op->src[1];
316
+
317
+ const enum ggml_type type = src0->type;
318
+ const int64_t ne0 = op->ne[0];
319
+
320
+ bool is_training = src0->grad || src1->grad;
321
+
322
+ // amx kernels are enabled for Q4_0, Q4_1, Q8_0, F16
323
+ // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
324
+ bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
325
+
326
+ bool can_use_amx =
327
+ is_contiguous_2d(src0) && // src0 must be contiguous
328
+ is_contiguous_2d(src1) && // src1 must be contiguous
329
+ !is_training && // inference only
330
+ src1->type == GGML_TYPE_F32 && // src1 must be float32
331
+ has_amx_kernels && // with amx kernel impls
332
+ ne0 % (TILE_N * 2) == 0; // out_features must be a multiple of 32
333
+
334
+ return can_use_amx;
335
+ }
336
+ default:
337
+ return false;
338
+ }
339
+
340
+ GGML_UNUSED(dev);
341
+ }
342
+
343
+ static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
344
+ return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
345
+
346
+ GGML_UNUSED(dev);
347
+ }
348
+
349
+ static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
350
+ /* .get_name = */ ggml_backend_amx_device_get_name,
351
+ /* .get_description = */ ggml_backend_amx_device_get_description,
352
+ /* .get_memory = */ ggml_backend_amx_device_get_memory,
353
+ /* .get_type = */ ggml_backend_amx_device_get_type,
354
+ /* .get_props = */ ggml_backend_amx_device_get_props,
355
+ /* .init_backend = */ ggml_backend_amx_device_init,
356
+ /* .get_buffer_type = */ ggml_backend_amx_device_get_buffer_type,
357
+ /* .get_host_buffer_type = */ NULL,
358
+ /* .buffer_from_host_ptr = */ NULL,
359
+ /* .supports_op = */ ggml_backend_amx_device_supports_op,
360
+ /* .supports_buft = */ ggml_backend_amx_device_supports_buft,
361
+ /* .offload_op = */ NULL,
362
+ /* .event_new = */ NULL,
363
+ /* .event_free = */ NULL,
364
+ /* .event_synchronize = */ NULL,
365
+ };
366
+
367
+ // backend reg interface
368
+
369
+ static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
370
+ return "AMX";
371
+
372
+ GGML_UNUSED(reg);
373
+ }
374
+
375
+ static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
376
+ return 1;
377
+
378
+ GGML_UNUSED(reg);
379
+ }
380
+
381
+ static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
382
+ GGML_ASSERT(index == 0);
383
+
384
+ static ggml_backend_device ggml_backend_amx_device = {
385
+ /* .iface = */ ggml_backend_amx_device_i,
386
+ /* .reg = */ reg,
387
+ /* .context = */ nullptr,
388
+ };
389
+
390
+ return &ggml_backend_amx_device;
391
+
392
+ GGML_UNUSED(reg);
393
+ GGML_UNUSED(index);
394
+ }
395
+
396
+ static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
397
+ if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
398
+ return (void *)ggml_backend_amx_set_n_threads;
399
+ }
400
+ return NULL;
401
+
402
+ GGML_UNUSED(reg);
403
+ GGML_UNUSED(name);
404
+ }
405
+
406
+ static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
407
+ /* .get_name = */ ggml_backend_amx_reg_get_name,
408
+ /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
409
+ /* .get_device = */ ggml_backend_amx_reg_get_device,
410
+ /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
411
+ };
412
+
413
+ ggml_backend_reg_t ggml_backend_amx_reg(void) {
414
+ static struct ggml_backend_reg ggml_backend_amx_reg = {
415
+ /* .iface = */ ggml_backend_amx_reg_i,
416
+ /* .context = */ NULL,
417
+ };
418
+
419
+ return &ggml_backend_amx_reg;
420
+ }
421
+
422
+ #else // if defined(__AMX_INT8__)
423
+
424
+ ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void) {
425
+ return nullptr;
426
+ }
427
+
428
+ bool ggml_backend_is_amx(ggml_backend_t backend) {
429
+ GGML_UNUSED(backend);
430
+ return false;
431
+ }
432
+
433
+ ggml_backend_t ggml_backend_amx_init(void) {
434
+ fprintf(stderr, "GGML is not compiled with AMX support!\n");
435
+ return nullptr;
436
+ }
437
+
438
+ void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
439
+ fprintf(stderr, "GGML is not compiled with AMX support!\n");
440
+
441
+ GGML_UNUSED(backend_amx);
442
+ GGML_UNUSED(n_threads);
443
+ }
444
+
445
+ ggml_backend_reg_t ggml_backend_amx_reg(void) {
446
+ return nullptr;
447
+ }
448
+
449
+ #endif
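For reference, the AMX backend above is consumed through the same public entry points as the other backends in this commit. A minimal usage sketch (illustrative only; it assumes a build with GGML_AMX enabled, the declarations from ggml-amx.h and ggml-backend.h, and an AMX-capable CPU):

#include "ggml-amx.h"
#include "ggml-backend.h"

#include <cstdio>

int main() {
    // returns nullptr when ggml was built without AMX support
    ggml_backend_t backend = ggml_backend_amx_init();
    if (backend == nullptr || !ggml_backend_is_amx(backend)) {
        std::fprintf(stderr, "AMX backend not available\n");
        return 1;
    }

    // cap the number of threads used by the AMX matmul kernels
    ggml_backend_amx_set_n_threads(backend, 8);

    // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...

    ggml_backend_free(backend);
    return 0;
}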
ggml/src/ggml-amx/mmq.cpp CHANGED
@@ -496,19 +496,20 @@ inline void from_float(const float * x, char * vy, int64_t k);
 
  template <>
  inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
- quantize_row_q8_0(x, vy, k);
+ // FIXME: using unoptimized reference impl until moved to CPU backend
+ quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k);
  }
 
  template <>
  inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
- quantize_row_q8_1(x, vy, k);
+ quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k);
  }
 
  template <>
  inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
  #if 1
  // TODO: this is reference impl!
- quantize_row_q8_K(x, vy, k);
+ quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
  #else
  quantize_row_q8_K_vnni(x, vy, k);
  #endif
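The from_float<> helpers changed in this hunk are explicit template specializations that select a quantizer per block type at compile time. A self-contained sketch of that dispatch pattern (the *_demo names are illustrative stand-ins, not real ggml symbols):

#include <cstdint>
#include <cstdio>
#include <vector>

// stand-in block types; the real ones are defined in ggml-common.h
struct block_q8_0_demo {};
struct block_q8_1_demo {};

// the primary template is only declared; each block type gets its own definition
template <typename BLOCK>
void from_float_demo(const float * x, char * vy, int64_t k);

template <>
void from_float_demo<block_q8_0_demo>(const float * x, char * vy, int64_t k) {
    // the real specialization forwards to quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k)
    std::printf("q8_0 path: %lld values\n", (long long) k);
    (void) x; (void) vy;
}

template <>
void from_float_demo<block_q8_1_demo>(const float * x, char * vy, int64_t k) {
    std::printf("q8_1 path: %lld values\n", (long long) k);
    (void) x; (void) vy;
}

int main() {
    std::vector<float> src(32, 1.0f);
    std::vector<char> dst(64);
    from_float_demo<block_q8_0_demo>(src.data(), dst.data(), (int64_t) src.size());
    from_float_demo<block_q8_1_demo>(src.data(), dst.data(), (int64_t) src.size());
    return 0;
}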
ggml/src/ggml-backend-reg.cpp ADDED
@@ -0,0 +1,195 @@
1
+ #include "ggml-backend-impl.h"
2
+ #include "ggml-backend.h"
3
+ #include "ggml-cpu.h"
4
+ #include "ggml-impl.h"
5
+ #include <cstring>
6
+ #include <vector>
7
+
8
+ // Backend registry
9
+
10
+ #ifdef GGML_USE_CUDA
11
+ #include "ggml-cuda.h"
12
+ #endif
13
+
14
+ #ifdef GGML_USE_METAL
15
+ #include "ggml-metal.h"
16
+ #endif
17
+
18
+ #ifdef GGML_USE_SYCL
19
+ #include "ggml-sycl.h"
20
+ #endif
21
+
22
+ #ifdef GGML_USE_VULKAN
23
+ #include "ggml-vulkan.h"
24
+ #endif
25
+
26
+ #ifdef GGML_USE_BLAS
27
+ #include "ggml-blas.h"
28
+ #endif
29
+
30
+ #ifdef GGML_USE_RPC
31
+ #include "ggml-rpc.h"
32
+ #endif
33
+
34
+ #ifdef GGML_USE_AMX
35
+ # include "ggml-amx.h"
36
+ #endif
37
+
38
+ #ifdef GGML_USE_CANN
39
+ #include "ggml-cann.h"
40
+ #endif
41
+
42
+ #ifdef GGML_USE_KOMPUTE
43
+ #include "ggml-kompute.h"
44
+ #endif
45
+
46
+ struct ggml_backend_registry {
47
+ std::vector<ggml_backend_reg_t> backends;
48
+ std::vector<ggml_backend_dev_t> devices;
49
+
50
+ ggml_backend_registry() {
51
+ #ifdef GGML_USE_CUDA
52
+ register_backend(ggml_backend_cuda_reg());
53
+ #endif
54
+ #ifdef GGML_USE_METAL
55
+ register_backend(ggml_backend_metal_reg());
56
+ #endif
57
+ #ifdef GGML_USE_SYCL
58
+ register_backend(ggml_backend_sycl_reg());
59
+ #endif
60
+ #ifdef GGML_USE_VULKAN
61
+ register_backend(ggml_backend_vk_reg());
62
+ #endif
63
+ #ifdef GGML_USE_CANN
64
+ register_backend(ggml_backend_cann_reg());
65
+ #endif
66
+ #ifdef GGML_USE_BLAS
67
+ register_backend(ggml_backend_blas_reg());
68
+ #endif
69
+ #ifdef GGML_USE_RPC
70
+ register_backend(ggml_backend_rpc_reg());
71
+ #endif
72
+ #ifdef GGML_USE_AMX
73
+ register_backend(ggml_backend_amx_reg());
74
+ #endif
75
+ #ifdef GGML_USE_KOMPUTE
76
+ register_backend(ggml_backend_kompute_reg());
77
+ #endif
78
+
79
+ register_backend(ggml_backend_cpu_reg());
80
+ }
81
+
82
+ void register_backend(ggml_backend_reg_t reg) {
83
+ if (!reg) {
84
+ return;
85
+ }
86
+
87
+ #ifndef NDEBUG
88
+ GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
89
+ __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
90
+ #endif
91
+ backends.push_back(reg);
92
+ for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
93
+ register_device(ggml_backend_reg_dev_get(reg, i));
94
+ }
95
+ }
96
+
97
+ void register_device(ggml_backend_dev_t device) {
98
+ #ifndef NDEBUG
99
+ GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
100
+ #endif
101
+ devices.push_back(device);
102
+ }
103
+ };
104
+
105
+ static ggml_backend_registry & get_reg() {
106
+ static ggml_backend_registry reg;
107
+ return reg;
108
+ }
109
+
110
+ // Internal API
111
+ void ggml_backend_register(ggml_backend_reg_t reg) {
112
+ get_reg().register_backend(reg);
113
+ }
114
+
115
+ void ggml_backend_device_register(ggml_backend_dev_t device) {
116
+ get_reg().register_device(device);
117
+ }
118
+
119
+ // Backend (reg) enumeration
120
+ size_t ggml_backend_reg_count() {
121
+ return get_reg().backends.size();
122
+ }
123
+
124
+ ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
125
+ GGML_ASSERT(index < ggml_backend_reg_count());
126
+ return get_reg().backends[index];
127
+ }
128
+
129
+ ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
130
+ for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
131
+ ggml_backend_reg_t reg = ggml_backend_reg_get(i);
132
+ if (std::strcmp(ggml_backend_reg_name(reg), name) == 0) {
133
+ return reg;
134
+ }
135
+ }
136
+ return NULL;
137
+ }
138
+
139
+ // Device enumeration
140
+ size_t ggml_backend_dev_count() {
141
+ return get_reg().devices.size();
142
+ }
143
+
144
+ ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
145
+ GGML_ASSERT(index < ggml_backend_dev_count());
146
+ return get_reg().devices[index];
147
+ }
148
+
149
+ ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
150
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
151
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
152
+ if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
153
+ return dev;
154
+ }
155
+ }
156
+ return NULL;
157
+ }
158
+
159
+ ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
160
+ for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
161
+ ggml_backend_dev_t dev = ggml_backend_dev_get(i);
162
+ if (ggml_backend_dev_type(dev) == type) {
163
+ return dev;
164
+ }
165
+ }
166
+ return NULL;
167
+ }
168
+
169
+ // Convenience functions
170
+ ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
171
+ ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
172
+ if (!dev) {
173
+ return NULL;
174
+ }
175
+ return ggml_backend_dev_init(dev, params);
176
+ }
177
+
178
+ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
179
+ ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
180
+ if (!dev) {
181
+ return NULL;
182
+ }
183
+ return ggml_backend_dev_init(dev, params);
184
+ }
185
+
186
+ ggml_backend_t ggml_backend_init_best(void) {
187
+ ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
188
+ if (!dev) {
189
+ dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
190
+ }
191
+ if (!dev) {
192
+ return NULL;
193
+ }
194
+ return ggml_backend_dev_init(dev, NULL);
195
+ }
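ggml-backend-reg.cpp is now the single place where compiled-in backends are collected, so frontends can stay backend-agnostic. A short sketch of how an application might enumerate registrations through this API (assuming only the public declarations in ggml-backend.h):

#include "ggml-backend.h"

#include <cstdio>

int main() {
    // list every registered backend and every device it exposes
    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
        std::printf("backend: %s (%zu devices)\n", ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
    }
    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
        std::printf("device: %s - %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
    }

    // prefer a GPU device, fall back to the CPU (see ggml_backend_init_best above)
    ggml_backend_t backend = ggml_backend_init_best();
    if (backend != nullptr) {
        std::printf("initialized: %s\n", ggml_backend_name(backend));
        ggml_backend_free(backend);
    }
    return 0;
}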
ggml/src/ggml-blas/CMakeLists.txt ADDED
@@ -0,0 +1,91 @@
1
+ if (GGML_STATIC)
2
+ set(BLA_STATIC ON)
3
+ endif()
4
+ #if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
5
+ # set(BLA_SIZEOF_INTEGER 8)
6
+ #endif()
7
+
8
+ set(BLA_VENDOR ${GGML_BLAS_VENDOR})
9
+ find_package(BLAS)
10
+
11
+ if (BLAS_FOUND)
12
+ message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
13
+
14
+ add_library(ggml-blas
15
+ ggml-blas.cpp
16
+ )
17
+
18
+ target_link_libraries(ggml-blas PRIVATE ggml-base)
19
+ target_include_directories(ggml-blas PRIVATE . ..)
20
+
21
+ if (${GGML_BLAS_VENDOR} MATCHES "Apple")
22
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
23
+ add_compile_definitions(ACCELERATE_LAPACK_ILP64)
24
+ add_compile_definitions(GGML_BLAS_USE_ACCELERATE)
25
+ elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
26
+ # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
27
+ # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
28
+ find_package(PkgConfig REQUIRED)
29
+ if (${GGML_BLAS_VENDOR} MATCHES "Generic")
30
+ pkg_check_modules(DepBLAS blas)
31
+ elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
32
+ # As of openblas v0.3.22, the 64-bit is named openblas64.pc
33
+ pkg_check_modules(DepBLAS openblas64)
34
+ if (NOT DepBLAS_FOUND)
35
+ pkg_check_modules(DepBLAS openblas)
36
+ endif()
37
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
38
+ add_compile_definitions(GGML_BLAS_USE_BLIS)
39
+ pkg_check_modules(DepBLAS blis)
40
+ elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
41
+ pkg_check_modules(DepBLAS blas-atlas)
42
+ elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
43
+ pkg_check_modules(DepBLAS flexiblas_api)
44
+ elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
45
+ add_compile_definitions(GGML_BLAS_USE_MKL)
46
+ # all Intel* libraries share the same include path
47
+ pkg_check_modules(DepBLAS mkl-sdl)
48
+ elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
49
+ # this doesn't provide pkg-config
50
+ # suggest to assign BLAS_INCLUDE_DIRS on your own
51
+ if ("${NVHPC_VERSION}" STREQUAL "")
52
+ message(WARNING "Better to set NVHPC_VERSION")
53
+ else()
54
+ set(DepBLAS_FOUND ON)
55
+ set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
56
+ endif()
57
+ endif()
58
+ if (DepBLAS_FOUND)
59
+ set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
60
+ else()
61
+ message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
62
+ " detected by pkgconfig, trying to find cblas.h from possible paths...")
63
+ find_path(BLAS_INCLUDE_DIRS
64
+ NAMES cblas.h
65
+ HINTS
66
+ /usr/include
67
+ /usr/local/include
68
+ /usr/include/openblas
69
+ /opt/homebrew/opt/openblas/include
70
+ /usr/local/opt/openblas/include
71
+ /usr/include/x86_64-linux-gnu/openblas/include
72
+ )
73
+ endif()
74
+ endif()
75
+
76
+ message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
77
+
78
+ #add_compile_options(${BLAS_LINKER_FLAGS})
79
+ target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
80
+
81
+ if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
82
+ add_compile_definitions(GGML_BLAS_USE_MKL)
83
+ endif()
84
+
85
+ target_link_libraries (ggml-blas PRIVATE ${BLAS_LIBRARIES})
86
+ target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
87
+ else()
88
+ message(ERROR "BLAS not found, please refer to "
89
+ "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
90
+ " to set correct GGML_BLAS_VENDOR")
91
+ endif()
ggml/src/ggml-blas/ggml-blas.cpp ADDED
@@ -0,0 +1,514 @@
1
+ #include "ggml-impl.h"
2
+ #include "ggml-blas.h"
3
+ #include "ggml-backend-impl.h"
4
+
5
+ #include <future>
6
+ #include <vector>
7
+ #include <cstring>
8
+
9
+ #if defined(GGML_BLAS_USE_ACCELERATE)
10
+ # include <Accelerate/Accelerate.h>
11
+ #elif defined(GGML_BLAS_USE_MKL)
12
+ # include <mkl.h>
13
+ #elif defined(GGML_BLAS_USE_BLIS)
14
+ # include <blis.h>
15
+ #elif defined(GGML_BLAS_USE_NVPL)
16
+ # include <nvpl_blas.h>
17
+ #else
18
+ # include <cblas.h>
19
+ #endif
20
+
21
+ struct ggml_backend_blas_context {
22
+ int n_threads = GGML_DEFAULT_N_THREADS;
23
+ std::unique_ptr<char[]> work_data;
24
+ size_t work_size = 0;
25
+ #ifndef GGML_USE_OPENMP
26
+ std::vector<std::future<void>> tasks;
27
+ #endif
28
+ };
29
+
30
+ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
31
+ const struct ggml_tensor * src0 = dst->src[0];
32
+ const struct ggml_tensor * src1 = dst->src[1];
33
+
34
+ GGML_TENSOR_BINARY_OP_LOCALS
35
+
36
+ const enum ggml_type type = src0->type;
37
+
38
+ GGML_ASSERT(ne0 == ne01);
39
+ GGML_ASSERT(ne1 == ne11);
40
+ GGML_ASSERT(ne2 == ne12);
41
+ GGML_ASSERT(ne3 == ne13);
42
+
43
+ // we don't support permuted src0 or src1
44
+ GGML_ASSERT(nb00 == ggml_type_size(type));
45
+ GGML_ASSERT(nb10 == ggml_type_size(src1->type));
46
+
47
+ // dst cannot be transposed or permuted
48
+ GGML_ASSERT(nb0 == sizeof(float));
49
+ GGML_ASSERT(nb0 <= nb1);
50
+ GGML_ASSERT(nb1 <= nb2);
51
+ GGML_ASSERT(nb2 <= nb3);
52
+
53
+ // broadcast factors
54
+ const int64_t r2 = ne12/ne02;
55
+ const int64_t r3 = ne13/ne03;
56
+
57
+ const int64_t ne_plane = ne01*ne00;
58
+ const size_t desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
59
+
60
+ if (ctx->work_size < desired_wsize) {
61
+ ctx->work_data.reset(new char[desired_wsize]);
62
+ ctx->work_size = desired_wsize;
63
+ }
64
+ void * wdata = ctx->work_data.get();
65
+
66
+ // convert src0 to float
67
+ if (type != GGML_TYPE_F32) {
68
+ const auto * type_traits = ggml_get_type_traits(type);
69
+ ggml_to_float_t const to_float = type_traits->to_float;
70
+
71
+ for (int64_t i03 = 0; i03 < ne03; i03++) {
72
+ for (int64_t i02 = 0; i02 < ne02; i02++) {
73
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
74
+ float * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
75
+
76
+ const int min_cols_per_thread = 4096;
77
+ const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
78
+ const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
79
+
80
+ #ifdef GGML_USE_OPENMP
81
+ #pragma omp parallel for num_threads(n_threads)
82
+ for (int64_t i01 = 0; i01 < ne01; i01++) {
83
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
84
+ }
85
+ #else
86
+ for (int i = 1; i < n_threads; i++) {
87
+ const int64_t start = i*ne01/n_threads;
88
+ const int64_t end = (i + 1)*ne01/n_threads;
89
+ if (start < end) {
90
+ ctx->tasks.push_back(std::async(std::launch::async, [=]() {
91
+ for (int64_t i01 = start; i01 < end; i01++) {
92
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
93
+ }
94
+ }));
95
+ }
96
+ }
97
+ {
98
+ // reuse the current thread for the first task
99
+ const int64_t start = 0;
100
+ const int64_t end = ne01/n_threads;
101
+ for (int64_t i01 = start; i01 < end; i01++) {
102
+ to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
103
+ }
104
+ }
105
+ #endif
106
+ }
107
+ }
108
+
109
+ #ifndef GGML_USE_OPENMP
110
+ // wait for all tasks to finish
111
+ for (auto & task : ctx->tasks) {
112
+ task.get();
113
+ }
114
+ ctx->tasks.clear();
115
+ #endif
116
+ }
117
+
118
+ #if defined(OPENBLAS_VERSION)
119
+ openblas_set_num_threads(ctx->n_threads);
120
+ #endif
121
+
122
+ #if defined(GGML_BLAS_USE_BLIS)
123
+ bli_thread_set_num_threads(ctx->n_threads);
124
+ #endif
125
+
126
+ #if defined(GGML_BLAS_USE_NVPL)
127
+ nvpl_blas_set_num_threads(ctx->n_threads);
128
+ #endif
129
+
130
+ for (int64_t i13 = 0; i13 < ne13; i13++) {
131
+ for (int64_t i12 = 0; i12 < ne12; i12++) {
132
+ const int64_t i03 = i13/r3;
133
+ const int64_t i02 = i12/r2;
134
+
135
+ const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
136
+ const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
137
+ float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
138
+
139
+ if (type != GGML_TYPE_F32) {
140
+ x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
141
+ }
142
+
143
+ cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
144
+ ne1, ne01, ne10,
145
+ 1.0f, y, ne10,
146
+ x, ne00,
147
+ 0.0f, d, ne01);
148
+ }
149
+ }
150
+ }
151
+
152
+ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
153
+ const struct ggml_tensor * src0 = dst->src[0];
154
+ const struct ggml_tensor * src1 = dst->src[1];
155
+
156
+ GGML_TENSOR_BINARY_OP_LOCALS
157
+
158
+ GGML_ASSERT(ne0 == ne00);
159
+ GGML_ASSERT(ne1 == ne10);
160
+ GGML_ASSERT(ne2 == ne02);
161
+ GGML_ASSERT(ne02 == ne12);
162
+ GGML_ASSERT(ne3 == ne13);
163
+ GGML_ASSERT(ne03 == ne13);
164
+
165
+ // we don't support permuted src0 or src1
166
+ GGML_ASSERT(nb00 == sizeof(float));
167
+
168
+ // dst cannot be transposed or permuted
169
+ GGML_ASSERT(nb0 == sizeof(float));
170
+ // GGML_ASSERT(nb0 <= nb1);
171
+ // GGML_ASSERT(nb1 <= nb2);
172
+ // GGML_ASSERT(nb2 <= nb3);
173
+
174
+ // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
175
+ // src0: (k,n)
176
+ // src1: (k,m)
177
+ // dst: (m,n)
178
+ //
179
+ // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
180
+ // Also expressed as (major,minor)
181
+ // a: (m,k): so src1 transposed
182
+ // b: (k,n): so src0
183
+ // c: (m,n)
184
+ //
185
+ // However, if ggml_is_transposed(src1) is true, then
186
+ // src1->data already contains a transposed version, so sgemm mustn't
187
+ // transpose it further.
188
+
189
+ int n = src0->ne[0];
190
+ int k = src0->ne[1];
191
+ int m = src1->ne[0];
192
+
193
+ CBLAS_TRANSPOSE transposeA;
194
+ int lda;
195
+
196
+ if (!ggml_is_transposed(src1)) {
197
+ transposeA = CblasTrans;
198
+ lda = m;
199
+ } else {
200
+ transposeA = CblasNoTrans;
201
+ lda = k;
202
+ }
203
+
204
+ float * a = (float *) ((char *) src1->data);
205
+ float * b = (float *) ((char *) src0->data);
206
+ float * c = (float *) ((char *) dst->data);
207
+
208
+ cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
209
+
210
+ GGML_UNUSED(ctx);
211
+ }
212
+
213
+ // backend interface
214
+
215
+ static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
216
+ return "BLAS";
217
+
218
+ GGML_UNUSED(backend);
219
+ }
220
+
221
+ static void ggml_backend_blas_free(ggml_backend_t backend) {
222
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
223
+ delete ctx;
224
+ delete backend;
225
+ }
226
+
227
+ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
228
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
229
+
230
+ for (int i = 0; i < cgraph->n_nodes; i++) {
231
+ struct ggml_tensor * node = cgraph->nodes[i];
232
+
233
+ switch (node->op) {
234
+ case GGML_OP_MUL_MAT:
235
+ ggml_backend_blas_mul_mat(ctx, node);
236
+ break;
237
+
238
+ case GGML_OP_OUT_PROD:
239
+ ggml_backend_blas_out_prod(ctx, node);
240
+ break;
241
+
242
+ case GGML_OP_NONE:
243
+ case GGML_OP_RESHAPE:
244
+ case GGML_OP_VIEW:
245
+ case GGML_OP_PERMUTE:
246
+ case GGML_OP_TRANSPOSE:
247
+ break;
248
+
249
+ default:
250
+ GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
251
+ }
252
+ }
253
+
254
+ return GGML_STATUS_SUCCESS;
255
+
256
+ GGML_UNUSED(backend);
257
+ }
258
+
259
+ static struct ggml_backend_i blas_backend_i = {
260
+ /* .get_name = */ ggml_backend_blas_get_name,
261
+ /* .free = */ ggml_backend_blas_free,
262
+ /* .set_tensor_async = */ NULL,
263
+ /* .get_tensor_async = */ NULL,
264
+ /* .cpy_tensor_async = */ NULL,
265
+ /* .synchronize = */ NULL,
266
+ /* .graph_plan_create = */ NULL,
267
+ /* .graph_plan_free = */ NULL,
268
+ /* .graph_plan_update = */ NULL,
269
+ /* .graph_plan_compute = */ NULL,
270
+ /* .graph_compute = */ ggml_backend_blas_graph_compute,
271
+ /* .event_record = */ NULL,
272
+ /* .event_wait = */ NULL,
273
+ };
274
+
275
+ static ggml_guid_t ggml_backend_blas_guid(void) {
276
+ static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
277
+ return &guid;
278
+ }
279
+
280
+ ggml_backend_t ggml_backend_blas_init(void) {
281
+ ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
282
+
283
+ ggml_backend_t backend = new ggml_backend {
284
+ /* .guid = */ ggml_backend_blas_guid(),
285
+ /* .interface = */ blas_backend_i,
286
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
287
+ /* .context = */ ctx,
288
+ };
289
+
290
+ #if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
291
+ if (openblas_get_parallel() != OPENBLAS_OPENMP) {
292
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
293
+ }
294
+ #endif
295
+
296
+ #if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
297
+ GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
298
+ #endif
299
+
300
+ return backend;
301
+ }
302
+
303
+ bool ggml_backend_is_blas(ggml_backend_t backend) {
304
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
305
+ }
306
+
307
+ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
308
+ GGML_ASSERT(ggml_backend_is_blas(backend_blas));
309
+
310
+ ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
311
+ ctx->n_threads = n_threads;
312
+ }
313
+
314
+ // device interface
315
+
316
+ static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
317
+ return "BLAS";
318
+
319
+ GGML_UNUSED(dev);
320
+ }
321
+
322
+ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
323
+ #if defined(GGML_BLAS_USE_ACCELERATE)
324
+ return "Accelerate";
325
+ #elif defined(GGML_BLAS_USE_MKL)
326
+ return "MKL";
327
+ #elif defined(GGML_BLAS_USE_BLIS)
328
+ return "BLIS";
329
+ #elif defined(GGML_BLAS_USE_NVPL)
330
+ return "NVPL";
331
+ #elif defined(OPENBLAS_VERSION)
332
+ return "OpenBLAS";
333
+ #else
334
+ return "BLAS";
335
+ #endif
336
+
337
+ GGML_UNUSED(dev);
338
+ }
339
+
340
+ static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
341
+ // TODO
342
+ *free = 0;
343
+ *total = 0;
344
+
345
+ GGML_UNUSED(dev);
346
+ }
347
+
348
+ static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
349
+ return GGML_BACKEND_DEVICE_TYPE_ACCEL;
350
+
351
+ GGML_UNUSED(dev);
352
+ }
353
+
354
+ static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
355
+ props->name = ggml_backend_blas_device_get_name(dev);
356
+ props->description = ggml_backend_blas_device_get_description(dev);
357
+ props->type = ggml_backend_blas_device_get_type(dev);
358
+ ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
359
+ props->caps = {
360
+ /* .async = */ false,
361
+ /* .host_buffer = */ false,
362
+ /* .buffer_from_host_ptr = */ true,
363
+ /* .events = */ false,
364
+ };
365
+ }
366
+
367
+ static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
368
+ return ggml_backend_blas_init();
369
+
370
+ GGML_UNUSED(dev);
371
+ GGML_UNUSED(params);
372
+ }
373
+
374
+ static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
375
+ return ggml_backend_cpu_buffer_type();
376
+
377
+ GGML_UNUSED(dev);
378
+ }
379
+
380
+ static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
381
+ return ggml_backend_cpu_buffer_from_ptr(ptr, size);
382
+
383
+ GGML_UNUSED(dev);
384
+ GGML_UNUSED(max_tensor_size);
385
+ }
386
+
387
+ static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
388
+ const struct ggml_tensor * src0 = op->src[0];
389
+ const struct ggml_tensor * src1 = op->src[1];
390
+
391
+ switch (op->op) {
392
+ case GGML_OP_NONE:
393
+ case GGML_OP_RESHAPE:
394
+ case GGML_OP_VIEW:
395
+ case GGML_OP_PERMUTE:
396
+ case GGML_OP_TRANSPOSE:
397
+ return true;
398
+
399
+ case GGML_OP_MUL_MAT:
400
+ {
401
+ // BLAS is usually only faster for large matrices
402
+ const struct ggml_tensor * src0 = op->src[0];
403
+ const struct ggml_tensor * src1 = op->src[1];
404
+
405
+ const int64_t ne10 = src1->ne[0];
406
+
407
+ const int64_t ne0 = op->ne[0];
408
+ const int64_t ne1 = op->ne[1];
409
+
410
+ // TODO: find the optimal value
411
+ const int64_t min_batch = 32;
412
+
413
+ return ggml_is_contiguous(src0) &&
414
+ ggml_is_contiguous(src1) &&
415
+ src1->type == GGML_TYPE_F32 &&
416
+ (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
417
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
418
+ }
419
+
420
+ case GGML_OP_OUT_PROD:
421
+ return op->src[0]->type == GGML_TYPE_F32 &&
422
+ op->src[1]->type == GGML_TYPE_F32 &&
423
+ ggml_is_matrix(src0) &&
424
+ ggml_is_matrix(src1) &&
425
+ ggml_is_contiguous(src0) &&
426
+ (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
427
+ (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
428
+
429
+ default:
430
+ return false;
431
+
432
+ }
433
+
434
+ GGML_UNUSED(dev);
435
+ }
436
+
437
+ static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
438
+ return ggml_backend_buft_is_host(buft);
439
+
440
+ GGML_UNUSED(dev);
441
+ }
442
+
443
+ static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
444
+ /* .get_name = */ ggml_backend_blas_device_get_name,
445
+ /* .get_description = */ ggml_backend_blas_device_get_description,
446
+ /* .get_memory = */ ggml_backend_blas_device_get_memory,
447
+ /* .get_type = */ ggml_backend_blas_device_get_type,
448
+ /* .get_props = */ ggml_backend_blas_device_get_props,
449
+ /* .init_backend = */ ggml_backend_blas_device_init_backend,
450
+ /* .get_buffer_type = */ ggml_backend_blas_device_get_buffer_type,
451
+ /* .get_host_buffer_type = */ NULL,
452
+ /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
453
+ /* .supports_op = */ ggml_backend_blas_device_supports_op,
454
+ /* .supports_buft = */ ggml_backend_blas_device_supports_buft,
455
+ /* .offload_op = */ NULL,
456
+ /* .event_new = */ NULL,
457
+ /* .event_free = */ NULL,
458
+ /* .event_synchronize = */ NULL,
459
+ };
460
+
461
+ // backend reg interface
462
+
463
+ static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
464
+ return "BLAS";
465
+
466
+ GGML_UNUSED(reg);
467
+ }
468
+
469
+ static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
470
+ return 1;
471
+
472
+ GGML_UNUSED(reg);
473
+ }
474
+
475
+ static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
476
+ GGML_ASSERT(index == 0);
477
+
478
+ static ggml_backend_device ggml_backend_blas_device = {
479
+ /* .iface = */ ggml_backend_blas_device_i,
480
+ /* .reg = */ reg,
481
+ /* .context = */ nullptr,
482
+ };
483
+
484
+ return &ggml_backend_blas_device;
485
+
486
+ GGML_UNUSED(reg);
487
+ GGML_UNUSED(index);
488
+ }
489
+
490
+ static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
491
+ if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
492
+ return (void *)ggml_backend_blas_set_n_threads;
493
+ }
494
+ return NULL;
495
+
496
+ GGML_UNUSED(reg);
497
+ GGML_UNUSED(name);
498
+ }
499
+
500
+ static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
501
+ /* .get_name = */ ggml_backend_blas_reg_get_name,
502
+ /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
503
+ /* .get_device = */ ggml_backend_blas_reg_get_device,
504
+ /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
505
+ };
506
+
507
+ ggml_backend_reg_t ggml_backend_blas_reg(void) {
508
+ static struct ggml_backend_reg ggml_backend_blas_reg = {
509
+ /* .iface = */ ggml_backend_blas_reg_i,
510
+ /* .context = */ NULL,
511
+ };
512
+
513
+ return &ggml_backend_blas_reg;
514
+ }
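Because the BLAS backend reuses CPU buffers and publishes its thread setter through get_proc_address, callers can either use the backend-specific API or resolve the generic entry point from the registry. A small sketch, assuming the declarations in ggml-blas.h and ggml-backend.h:

#include "ggml-backend.h"
#include "ggml-blas.h"

#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_blas_init();
    if (backend == nullptr) {
        std::fprintf(stderr, "BLAS backend not available\n");
        return 1;
    }

    // backend-specific setter, declared in ggml-blas.h
    ggml_backend_blas_set_n_threads(backend, 8);

    // or: resolve the generic "ggml_backend_set_n_threads" entry point, the way
    // a frontend that knows nothing about BLAS would
    typedef void (*set_n_threads_t)(ggml_backend_t, int);
    ggml_backend_reg_t reg = ggml_backend_reg_by_name("BLAS");
    if (reg != nullptr) {
        auto set_n_threads = (set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_n_threads != nullptr) {
            set_n_threads(backend, 8);
        }
    }

    ggml_backend_free(backend);
    return 0;
}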
ggml/src/ggml-cann/CMakeLists.txt ADDED
@@ -0,0 +1,46 @@
1
+ if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})
2
+ set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
3
+ message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}")
4
+ endif()
5
+
6
+ if (CANN_INSTALL_DIR)
7
+ # Only Support Linux.
8
+ if (NOT UNIX)
9
+ message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}")
10
+ endif()
11
+
12
+ # Supported platforms: x86-64, arm64
13
+ if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
14
+ elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64")
15
+ else()
16
+ message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}")
17
+ endif()
18
+
19
+ # Set header and libs
20
+ set(CANN_INCLUDE_DIRS
21
+ ${CANN_INSTALL_DIR}/include
22
+ ${CANN_INSTALL_DIR}/include/aclnn
23
+ ${CANN_INSTALL_DIR}/acllib/include
24
+ )
25
+
26
+ add_subdirectory(kernels)
27
+ list(APPEND CANN_LIBRARIES
28
+ ascendcl
29
+ nnopbase
30
+ opapi
31
+ acl_op_compiler
32
+ ascendc_kernels
33
+ )
34
+
35
+ file(GLOB GGML_SOURCES_CANN "*.cpp")
36
+
37
+ add_library(ggml-cann ${GGML_SOURCES_CANN})
38
+ target_link_libraries(ggml-cann PRIVATE ggml-base ${CANN_LIBRARIES})
39
+ target_include_directories(ggml-cann PRIVATE . .. ${CANN_INCLUDE_DIRS})
40
+ target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64)
41
+
42
+ message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
43
+ message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
44
+ else()
45
+ message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?")
46
+ endif()
ggml/src/ggml-cann/ggml-cann.cpp ADDED
@@ -0,0 +1,2128 @@
1
+ /*
2
+ * Copyright (c) 2023-2024 The ggml authors
3
+ *
4
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
5
+ * of this software and associated documentation files (the "Software"), to
6
+ * deal in the Software without restriction, including without limitation the
7
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
8
+ * sell copies of the Software, and to permit persons to whom the Software is
9
+ * furnished to do so, subject to the following conditions:
10
+ *
11
+ * The above copyright notice and this permission notice shall be included in
12
+ * all copies or substantial portions of the Software.
13
+ *
14
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
19
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
20
+ * IN THE SOFTWARE.
21
+ */
22
+
23
+ #include "ggml-cann.h"
24
+
25
+ #include <acl/acl.h>
26
+ #include <stdarg.h>
27
+
28
+ #include <cmath>
29
+ #include <cstdio>
30
+ #include <cstring>
31
+ #include <mutex>
32
+
33
+ #include "ggml-impl.h"
34
+ #include "ggml-backend-impl.h"
35
+ #include "ggml-cann/aclnn_ops.h"
36
+ #include "ggml-cann/common.h"
37
+
38
+ #define GGML_COMMON_DECL_C
39
+
40
+ #include "ggml-common.h"
41
+
42
+ #define GGML_CANN_NAME "CANN"
43
+
44
+ /**
45
+ * @brief Handles CANN errors by printing an error message and aborting.
46
+ *
47
+ * @param stmt The statement that caused the error.
48
+ * @param func The function in which the error occurred.
49
+ * @param file The file in which the error occurred.
50
+ * @param line The line number where the error occurred.
51
+ * @param msg The error message.
52
+ */
53
+ [[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
54
+ const char* file, int line, const char* msg) {
55
+ int32_t id = -1;
56
+ aclrtGetDevice(&id);
57
+
58
+ GGML_LOG_ERROR("CANN error: %s\n", msg);
59
+ GGML_LOG_ERROR(" current device: %d, in function %s at %s:%d\n", id, func,
60
+ file, line);
61
+ GGML_LOG_ERROR(" %s\n", stmt);
62
+ // abort with GGML_ASSERT to get a stack trace
63
+ GGML_ABORT("CANN error");
64
+ }
65
+
66
+ /**
67
+ * @brief Sets the device to be used by CANN.
68
+ *
69
+ * @param device The device ID to set.
70
+ */
71
+ void ggml_cann_set_device(const int32_t device) {
72
+ // TODO: uncomment these lines once the empty-context issue is fixed.
73
+ // int current_device;
74
+ // ACL_CHECK(aclrtGetDevice(&current_device));
75
+
76
+ // if (device == current_device) {
77
+ // return;
78
+ // }
79
+ ACL_CHECK(aclrtSetDevice(device));
80
+ }
81
+
82
+ /**
83
+ * @brief Retrieves the current device ID.
84
+ *
85
+ * @return The current device ID.
86
+ */
87
+ int32_t ggml_cann_get_device() {
88
+ int32_t id;
89
+ ACL_CHECK(aclrtGetDevice(&id));
90
+ return id;
91
+ }
92
+
93
+ /**
94
+ * @brief Initialize the CANN device information.
95
+ *
96
+ * This function initializes the CANN device information by obtaining the
97
+ * device count and setting the memory allocation granularity for each device.
98
+ *
99
+ * @return A structure containing the device information.
100
+ */
101
+ static ggml_cann_device_info ggml_cann_init() {
102
+ ggml_cann_device_info info = {};
103
+
104
+ aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);
105
+
106
+ if (err != ACL_SUCCESS) {
107
+ GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
108
+ __func__, aclGetRecentErrMsg());
109
+ return info;
110
+ }
111
+
112
+ GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);
113
+
114
+ for (int id = 0; id < info.device_count; ++id) {
115
+ aclrtPhysicalMemProp prop = {};
116
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
117
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
118
+ prop.memAttr = ACL_HBM_MEM_HUGE;
119
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
120
+ prop.location.id = id;
121
+ prop.reserve = 0;
122
+ ACL_CHECK(aclrtMemGetAllocationGranularity(
123
+ &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
124
+ &info.devices[id].vmm_granularity));
125
+ }
126
+
127
+ // TODO: add more device info later.
128
+ return info;
129
+ }
130
+
131
+ /**
132
+ * @brief Retrieve the CANN device information.
133
+ *
134
+ * This function returns a reference to a structure containing the CANN device
135
+ * information. The device information is initialized once and reused on
136
+ * subsequent calls.
137
+ *
138
+ * @return A reference to the structure containing the device information.
139
+ */
140
+ const ggml_cann_device_info& ggml_cann_info() {
141
+ static ggml_cann_device_info info = ggml_cann_init();
142
+ return info;
143
+ }
144
+
145
+ //#define DEBUG_CANN_MALLOC
146
+ /**
147
+ * @brief A pool of CANN buffers(legacy).
148
+ *
149
+ * This class manages a pool of CANN buffers for a specific device.
150
+ */
151
+ struct ggml_cann_pool_leg : public ggml_cann_pool {
152
+ /**
153
+ * @brief The maximum number of buffers in the pool.
154
+ */
155
+ static const int MAX_BUFFERS = 256;
156
+
157
+ /**
158
+ * @brief The device ID associated with this buffer pool.
159
+ */
160
+ int device;
161
+
162
+ /**
163
+ * @brief Structure representing a CANN buffer.
164
+ */
165
+ struct ggml_cann_buffer {
166
+ void* ptr = nullptr; ///< Pointer to the buffer memory.
167
+ size_t size = 0; ///< Size of the buffer.
168
+ };
169
+
170
+ /**
171
+ * @brief Array of CANN buffers in the pool.
172
+ */
173
+ ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};
174
+
175
+ /**
176
+ * @brief Total size of all buffers in the pool.
177
+ */
178
+ size_t pool_size = 0;
179
+
180
+ /**
181
+ * @brief Constructor to initialize the buffer pool for a specific device.
182
+ *
183
+ * @param device The device ID to associate with this buffer pool.
184
+ */
185
+ explicit ggml_cann_pool_leg(int device) : device(device) {}
186
+
187
+ /**
188
+ * @brief Destructor to free all buffers in the pool.
189
+ */
190
+ ~ggml_cann_pool_leg() {
191
+ ggml_cann_set_device(device);
192
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
193
+ ggml_cann_buffer& b = buffer_pool[i];
194
+ if (b.ptr != nullptr) {
195
+ ACL_CHECK(aclrtFree(b.ptr));
196
+ pool_size -= b.size;
197
+ }
198
+ }
199
+ GGML_ASSERT(pool_size == 0);
200
+ }
201
+
202
+ /**
203
+ * @brief Allocate a buffer of the given size.
204
+ *
205
+ * @param size The size of the buffer to allocate.
206
+ * @param actual_size A pointer to a variable to receive the actual size of
207
+ * the allocated buffer.
208
+ * @return A pointer to the allocated buffer.
209
+ */
210
+ void* alloc(size_t size, size_t* actual_size) override {
211
+ #ifdef DEBUG_CANN_MALLOC
212
+ int nnz = 0;
213
+ size_t max_size = 0;
214
+ #endif
215
+ size_t best_diff = 1ull << 36;
216
+ int ibest = -1;
217
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
218
+ ggml_cann_buffer& b = buffer_pool[i];
219
+ if (b.ptr != nullptr) {
220
+ #ifdef DEBUG_CANN_MALLOC
221
+ ++nnz;
222
+ if (b.size > max_size) max_size = b.size;
223
+ #endif
224
+ if (b.size >= size) {
225
+ size_t diff = b.size - size;
226
+ if (diff < best_diff) {
227
+ best_diff = diff;
228
+ ibest = i;
229
+ if (!best_diff) {
230
+ void* ptr = b.ptr;
231
+ *actual_size = b.size;
232
+ b.ptr = nullptr;
233
+ b.size = 0;
234
+ return ptr;
235
+ }
236
+ }
237
+ }
238
+ }
239
+ }
240
+ if (ibest >= 0) {
241
+ ggml_cann_buffer& b = buffer_pool[ibest];
242
+ void* ptr = b.ptr;
243
+ *actual_size = b.size;
244
+ b.ptr = nullptr;
245
+ b.size = 0;
246
+ return ptr;
247
+ }
248
+ void* ptr;
249
+ size_t look_ahead_size = (size_t)(1.05 * size);
250
+ look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
251
+ ggml_cann_set_device(device);
252
+ ACL_CHECK(
253
+ aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
254
+ *actual_size = look_ahead_size;
255
+ pool_size += look_ahead_size;
256
+ #ifdef DEBUG_CANN_MALLOC
257
+ GGML_LOG_INFO(
258
+ "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
259
+ "requested %u MB\n",
260
+ __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
261
+ (uint32_t)(pool_size / 1024 / 1024),
262
+ (uint32_t)(size / 1024 / 1024));
263
+ #endif
264
+ return ptr;
265
+ }
266
+
267
+ /**
268
+ * @brief Free a buffer and return it to the pool.
269
+ *
270
+ * @param ptr Pointer to the buffer to free.
271
+ * @param size Size of the buffer to free.
272
+ */
273
+ void free(void* ptr, size_t size) override {
274
+ for (int i = 0; i < MAX_BUFFERS; ++i) {
275
+ ggml_cann_buffer& b = buffer_pool[i];
276
+ if (b.ptr == nullptr) {
277
+ b.ptr = ptr;
278
+ b.size = size;
279
+ return;
280
+ }
281
+ }
282
+ // the memory should always be buffered: it may still be needed by
283
+ // tasks in the stream.
284
+ // TODO, fix me.
285
+ GGML_ABORT("Cann buffer pool full, increase MAX_CANN_BUFFERS\n");
286
+ }
287
+ };
288
+
289
+ /**
290
+ * @brief A pool of CANN buffers with virtual memory.
291
+ *
292
+ * This class manages a pool of CANN buffers with virtual memory for a specific
293
+ * device.
294
+ */
295
+ struct ggml_cann_pool_vmm : public ggml_cann_pool {
296
+ /**
297
+ * @brief The maximum size of the virtual memory pool (32 GB).
298
+ */
299
+ static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
300
+
301
+ /**
302
+ * @brief The device ID associated with this buffer pool.
303
+ */
304
+ int device;
305
+
306
+ /**
307
+ * @brief Pointer to the start of the virtual memory pool.
308
+ */
309
+ void* pool_addr = 0;
310
+
311
+ /**
312
+ * @brief Amount of virtual memory used in the pool.
313
+ */
314
+ size_t pool_used = 0;
315
+
316
+ /**
317
+ * @brief Total size of the virtual memory pool.
318
+ */
319
+ size_t pool_size = 0;
320
+
321
+ /**
322
+ * @brief Allocation granularity for the virtual memory pool.
323
+ */
324
+ size_t granularity;
325
+
326
+ /**
327
+ * @brief Handles for the physical memory allocated.
328
+ */
329
+ std::vector<aclrtDrvMemHandle> handles;
330
+
331
+ /**
332
+ * @brief Offsets for the mapped memory regions.
333
+ */
334
+ std::vector<void*> map_offsets;
335
+
336
+ /**
337
+ * @brief Constructor to initialize the buffer pool with virtual memory for
338
+ * a specific device.
339
+ *
340
+ * @param device The device ID to associate with this buffer pool.
341
+ */
342
+ explicit ggml_cann_pool_vmm(int device)
343
+ : device(device),
344
+ granularity(ggml_cann_info().devices[device].vmm_granularity) {}
345
+
346
+ /**
347
+ * @brief Destructor to free all buffers in the virtual memory pool.
348
+ */
349
+ ~ggml_cann_pool_vmm() {
350
+ if (pool_addr != 0) {
351
+ for (auto& offset : map_offsets) {
352
+ ACL_CHECK(aclrtUnmapMem(offset));
353
+ }
354
+ for (auto& handle : handles) {
355
+ ACL_CHECK(aclrtFreePhysical(handle));
356
+ }
357
+ ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
358
+ }
359
+ }
360
+
361
+ /**
362
+ * @brief Allocate a buffer of the given size in the virtual memory pool.
363
+ *
364
+ * @param size The size of the buffer to allocate.
365
+ * @param actual_size A pointer to a variable to receive the actual size of
366
+ * the allocated buffer.
367
+ * @return A pointer to the allocated buffer.
368
+ */
369
+ void* alloc(size_t size, size_t* actual_size) override {
370
+ // round up the allocation size to the alignment to ensure that all
371
+ // allocations are aligned for all data types
372
+ const size_t alignment = 128;
373
+ size = alignment * ((size + alignment - 1) / alignment);
374
+
375
+ size_t avail = pool_size - pool_used;
376
+
377
+ if (size > avail) {
378
+ // round up to the next multiple of the granularity
379
+ size_t reserve_size = size - avail;
380
+ reserve_size =
381
+ granularity * ((reserve_size + granularity - 1) / granularity);
382
+
383
+ GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);
384
+
385
+ // allocate more physical memory
386
+ aclrtPhysicalMemProp prop = {};
387
+ prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
388
+ prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
389
+ prop.memAttr = ACL_HBM_MEM_HUGE;
390
+ prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
391
+ prop.location.id = device;
392
+ prop.reserve = 0;
393
+ aclrtDrvMemHandle handle;
394
+ ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));
395
+
396
+ // reserve virtual address space (if not already reserved)
397
+ if (pool_addr == 0) {
398
+ ACL_CHECK(aclrtReserveMemAddress(
399
+ &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
400
+ }
401
+
402
+ // map at the end of the pool
403
+ ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
404
+ handle, 0));
405
+
406
+ handles.push_back(handle);
407
+ map_offsets.push_back((char*)pool_addr + pool_size);
408
+
409
+ // add to the pool
410
+ pool_size += reserve_size;
411
+
412
+ // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
413
+ // reserved %llu MB)\n",
414
+ // device, (unsigned long long) (pool_size/1024/1024),
415
+ // (unsigned long long) (reserve_size/1024/1024));
416
+ }
417
+
418
+ GGML_ASSERT(pool_addr != 0);
419
+
420
+ void* ptr = (void*)((char*)pool_addr + pool_used);
421
+ *actual_size = size;
422
+ pool_used += size;
423
+
424
+ #ifdef DEBUG_CANN_MALLOC
425
+ GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
426
+ (unsigned long long)size, (unsigned long long)ptr);
427
+ #endif
428
+ return ptr;
429
+ }
430
+
431
+ /**
432
+ * @brief Free a buffer and return it to the virtual memory pool.
433
+ *
434
+ * @param ptr Pointer to the buffer to free.
435
+ * @param size Size of the buffer to free.
436
+ */
437
+ void free(void* ptr, size_t size) override {
438
+ #ifdef DEBUG_CANN_MALLOC
439
+ GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
440
+ (unsigned long long)size, (unsigned long long)ptr);
441
+ #endif
442
+
443
+ pool_used -= size;
444
+
445
+ // all deallocations must be in reverse order of the allocations
446
+ GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
447
+ }
448
+ };
449
+
450
+ /**
451
+ * @brief Create a new CANN pool for a specific device.
452
+ *
453
+ * Factory method to create a new CANN pool object based on the device type.
454
+ *
455
+ * @param device The device ID for which to create the pool.
456
+ * @return A unique pointer to the created CANN pool.
457
+ */
458
+ std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
459
+ int device) {
460
+ // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
461
+ return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
462
+ }
463
+
464
+ // cann buffer
465
+ /**
466
+ * @brief Context for managing a CANN buffer associated with a specific device.
467
+ *
468
+ * This structure holds information about a CANN buffer, namely the device
469
+ * ID and the pointer to the device memory allocated for the buffer.
470
+ */
471
+ struct ggml_backend_cann_buffer_context {
472
+ int32_t device; ///< The device ID associated with this buffer context.
473
+ void* dev_ptr =
474
+ nullptr; ///< Pointer to the device memory allocated for the buffer.
475
+
476
+ /**
477
+ * @brief Constructor to initialize the CANN buffer context.
478
+ *
479
+ * @param device The device ID associated with this buffer context.
480
+ * @param dev_ptr Pointer to the device memory allocated for the buffer.
481
+ */
482
+ ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
483
+ : device(device),
484
+ dev_ptr(dev_ptr) {}
485
+
486
+ /**
487
+ * @brief Destructor to free the device memory allocated for the buffer.
488
+ */
489
+ ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
490
+ };
491
+
492
+ /**
493
+ * @brief Check if a buffer is a CANN buffer.
494
+ *
495
+ * This function checks if a given buffer is a CANN buffer by checking whether
496
+ * its buffer type is a CANN buffer type.
497
+ *
498
+ * @param buffer The buffer to check.
499
+ * @return true if the buffer is a CANN buffer, false otherwise.
500
+ */
501
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
502
+ static bool ggml_backend_buffer_is_cann(
503
+ ggml_backend_buffer_t buffer) {
504
+ return ggml_backend_buft_is_cann(buffer->buft);
505
+ }
506
+
507
+ /**
508
+ * @brief Free resources associated with a CANN buffer.
509
+ *
510
+ * This function frees the resources associated with a CANN buffer, including
511
+ * its context.
512
+ *
513
+ * @param buffer The CANN buffer to free.
514
+ */
515
+ static void ggml_backend_cann_buffer_free_buffer(
516
+ ggml_backend_buffer_t buffer) {
517
+ ggml_backend_cann_buffer_context* ctx =
518
+ (ggml_backend_cann_buffer_context*)buffer->context;
519
+ delete ctx;
520
+ }
521
+
522
+ /**
523
+ * @brief Retrieve the base pointer of a CANN buffer.
524
+ *
525
+ * This function returns the base pointer of a CANN buffer, which points to the
526
+ * device memory allocated for the buffer.
527
+ *
528
+ * @param buffer The CANN buffer whose base pointer is to be retrieved.
529
+ * @return A pointer to the base of the device memory allocated for the buffer.
530
+ */
531
+ static void* ggml_backend_cann_buffer_get_base(
532
+ ggml_backend_buffer_t buffer) {
533
+ ggml_backend_cann_buffer_context* ctx =
534
+ (ggml_backend_cann_buffer_context*)buffer->context;
535
+ return ctx->dev_ptr;
536
+ }
537
+
538
+ /**
539
+ * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
540
+ * processing.
541
+ *
542
+ * This function transforms quantized Q4.0 tensor data into a format suitable
543
+ * for CANN processing. It extracts quantization values and scales from the
544
+ * source data and prepares them in a format expected by CANN operations.
545
+ *
546
+ * @param tensor Pointer to the tensor information.
547
+ * @param src Pointer to the source data in Q4.0 format.
548
+ * @param dst Pointer to the destination buffer where transformed data will be
549
+ * stored.
550
+ */
551
+ static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
552
+ const void* src,
553
+ void* dst) {
554
+
555
+ int64_t n_elems = ggml_nelements(tensor);
556
+ int64_t groups = n_elems / QK4_0;
557
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
558
+
559
+ uint8_t* quant_offset = (uint8_t*)dst;
560
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
561
+
562
+ for (int i = 0; i < groups; i++) {
563
+ const block_q4_0* group =
564
+ (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
565
+ *scale_offset = group->d;
566
+ scale_offset++;
567
+
568
+ // 0-15
569
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
570
+ (*quant_offset) = (group->qs[j] & 0x0F);
571
+ (*quant_offset) |= ((group->qs[j + 1] << 4));
572
+ quant_offset++;
573
+ }
574
+
575
+ // 16-31
576
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
577
+ (*quant_offset) = (group->qs[j] >> 4);
578
+ (*quant_offset) |= (group->qs[j + 1] & 0xF0);
579
+ quant_offset++;
580
+ }
581
+ }
582
+
583
+ // put (uint4b_t -8) into int4b_t
584
+ for (quant_offset = (uint8_t*)dst;
585
+ quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
586
+ (*quant_offset) ^= 0x88;
587
+ }
588
+ }
589
+
590
+ /**
591
+ * @brief Transform CANN processed data back into quantized Q4.0 format.
592
+ *
593
+ * This function transforms CANN processed data back into quantized Q4.0 format.
594
+ * It reverses the transformation performed by
595
+ * ggml_backend_cann_transform_q4_0(), converting the data back into its
596
+ * original quantized form.
597
+ *
598
+ * @param tensor Pointer to the tensor information.
599
+ * @param src Pointer to the source buffer containing transformed data.
600
+ * @param dst Pointer to the destination buffer where the Q4.0 formatted data
601
+ * will be stored.
602
+ */
603
+ static void ggml_backend_cann_transform_back_q4_0(
604
+ const ggml_tensor* tensor, void* src, void* dst) {
605
+
606
+ int64_t n_elems = ggml_nelements(tensor);
607
+ int64_t groups = n_elems / QK4_0;
608
+ size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;
609
+
610
+ uint8_t* quant_offset = (uint8_t*)src;
611
+ uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);
612
+
613
+ for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
614
+ (*quant_offset) ^= 0x88;
615
+ }
616
+ quant_offset = (uint8_t*)src;
617
+
618
+ for (int i = 0; i < groups; i++) {
619
+ block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
620
+ group->d = *scale_offset;
621
+ scale_offset++;
622
+
623
+ // 0-15
624
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
625
+ group->qs[j] = ((*quant_offset) & 0x0F);
626
+ group->qs[j + 1] = ((*quant_offset) >> 4);
627
+ quant_offset++;
628
+ }
629
+
630
+ // 16-31
631
+ for (int j = 0; j < QK4_0 / 2; j += 2) {
632
+ group->qs[j] |= ((*quant_offset) << 4);
633
+ group->qs[j + 1] |= ((*quant_offset) & 0xF0);
634
+ quant_offset++;
635
+ }
636
+ }
637
+ }
638
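The `^= 0x88` passes in the two Q4.0 transforms above convert between unsigned 4-bit quants (0..15) and signed int4 values (-8..7) by flipping bit 3 of each packed nibble. A small self-contained check of that identity (illustrative only, not part of the patch):

    #include <cassert>

    int main() {
        // XOR-ing a packed byte with 0x88 toggles bit 3 of both nibbles, which
        // maps an unsigned 4-bit value q in [0, 15] to q - 8 in two's-complement int4.
        for (unsigned q = 0; q < 16; ++q) {
            unsigned flipped = q ^ 0x8;  // one nibble of (byte ^ 0x88)
            int as_int4 = (flipped & 0x8) ? (int)flipped - 16 : (int)flipped;
            assert(as_int4 == (int)q - 8);
        }
        return 0;
    }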
+
639
+ /**
640
+ * @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
641
+ * processing.
642
+ *
643
+ * This function transforms quantized Q8.0 tensor data into a format suitable
644
+ * for CANN processing. It extracts quantization values and scales from the
645
+ * source data and prepares them in a format expected by CANN operations.
646
+ *
647
+ * @param tensor Pointer to the tensor information.
648
+ * @param src Pointer to the source data in Q8.0 format.
649
+ * @param dst Pointer to the destination buffer where transformed data will be
650
+ * stored.
651
+ */
652
+ static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
653
+ const void* src,
654
+ void* dst) {
655
+ int64_t n_elems = ggml_nelements(tensor);
656
+ int64_t groups = n_elems / QK8_0;
657
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
658
+
659
+ uint8_t* quant_offset = (uint8_t*)dst;
660
+ uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
661
+
662
+ for (int i = 0; i < groups; i++) {
663
+ const block_q8_0* group =
664
+ (const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
665
+ *scale_offset = group->d;
666
+ scale_offset++;
667
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
668
+ memcpy(quant_offset, group->qs, group_quant_size);
669
+ quant_offset += group_quant_size;
670
+ }
671
+ }
672
+
673
+ /**
674
+ * @brief Transform CANN processed data back into quantized Q8.0 format.
675
+ *
676
+ * This function transforms CANN processed data back into quantized Q8.0 format.
677
+ * It reverses the transformation performed by
678
+ * ggml_backend_cann_transform_q8_0(), converting the data back into its
679
+ * original quantized form.
680
+ *
681
+ * @param tensor Pointer to the tensor information.
682
+ * @param src Pointer to the source buffer containing transformed data.
683
+ * @param dst Pointer to the destination buffer where the Q8.0 formatted data
684
+ * will be stored.
685
+ */
686
+ static void ggml_backend_cann_transform_back_q8_0(
687
+ const ggml_tensor* tensor, const void* src, void* dst) {
688
+ int64_t n_elems = ggml_nelements(tensor);
689
+ int64_t groups = n_elems / QK8_0;
690
+ size_t quant_bytes = n_elems * sizeof(uint8_t);
691
+
692
+ const uint8_t* quant_offset = (const uint8_t*)src;
693
+ const uint16_t* scale_offset =
694
+ (const uint16_t*)((const char*)src + quant_bytes);
695
+
696
+ for (int i = 0; i < groups; i++) {
697
+ block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
698
+ group->d = *scale_offset;
699
+ scale_offset++;
700
+ size_t group_quant_size = QK8_0 * sizeof(uint8_t);
701
+ memcpy(group->qs, quant_offset, group_quant_size);
702
+ quant_offset += group_quant_size;
703
+ }
704
+ }
705
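Both Q8.0 transforms above use the same layout: a contiguous block of int8 quants followed by one 16-bit scale per group. A hedged sketch of the resulting buffer size (QK8_0 hard-coded to 32 here only to keep the sketch standalone):

    #include <cstddef>
    #include <cstdint>

    // n quant bytes followed by one 16-bit scale per group of QK8_0 elements.
    static size_t cann_q8_0_transformed_size(int64_t n_elems) {
        const int64_t qk8_0 = 32;  // QK8_0 in ggml
        return (size_t) n_elems * sizeof(uint8_t)
             + (size_t)(n_elems / qk8_0) * sizeof(uint16_t);
    }

    int main() {
        // e.g. a 4096-element tensor needs 4096 + 128*2 = 4352 bytes after transform
        return cann_q8_0_transformed_size(4096) == 4352 ? 0 : 1;
    }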
+
706
+ /**
707
+ * @brief Transform tensor data based on its type for CANN processing.
708
+ *
709
+ * This function transforms tensor data based on its quantization type for CANN
710
+ * processing. It dispatches the transformation based on the tensor's type to
711
+ * specialized functions handling Q4.0 and Q8.0 formats.
712
+ *
713
+ * @param tensor Pointer to the tensor information.
714
+ * @param src Pointer to the source data to be transformed.
715
+ * @param dst Pointer to the destination buffer where transformed data will be
716
+ * stored.
717
+ */
718
+ static void ggml_backend_cann_transform(ggml_tensor* tensor,
719
+ const void* src, void* dst) {
720
+ switch (tensor->type) {
721
+ case GGML_TYPE_Q4_0:
722
+ ggml_backend_cann_transform_q4_0(tensor, src, dst);
723
+ break;
724
+ case GGML_TYPE_Q8_0:
725
+ ggml_backend_cann_transform_q8_0(tensor, src, dst);
726
+ break;
727
+ default:
728
+ break;
729
+ }
730
+ }
731
+
732
+ /**
733
+ * @brief Transform CANN processed data back into tensor data based on its type.
734
+ *
735
+ * This function transforms CANN processed data back into tensor data based on
736
+ * its quantization type for Q4.0 and Q8.0 formats. It dispatches the
737
+ * transformation based on the tensor's type to specialized functions.
738
+ *
739
+ * @param tensor Pointer to the tensor information.
740
+ * @param src Pointer to the source data containing CANN processed data.
741
+ * @param dst Pointer to the destination buffer where transformed tensor data
742
+ * will be stored.
743
+ */
744
+ static void ggml_backend_cann_transform_back(
745
+ const ggml_tensor* tensor, void* src, void* dst) {
746
+ switch (tensor->type) {
747
+ case GGML_TYPE_Q4_0:
748
+ ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
749
+ break;
750
+ case GGML_TYPE_Q8_0:
751
+ ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
752
+ break;
753
+ default:
754
+ break;
755
+ }
756
+ }
757
+
758
+ /**
759
+ * @brief Check if transformation is needed for a given tensor type.
760
+ *
761
+ * This function checks if transformation is needed for a given tensor type
762
+ * to prepare data for CANN processing.
763
+ *
764
+ * @param type The tensor type to check.
765
+ * @return true if transformation is needed, false otherwise.
766
+ */
767
+ static bool need_transform(ggml_type type) {
768
+ switch (type) {
769
+ case GGML_TYPE_Q4_0:
770
+ case GGML_TYPE_Q8_0:
771
+ return true;
772
+ default:
773
+ return false;
774
+ }
775
+ }
776
+
777
+ /**
778
+ * @brief Initialize a tensor using data from a CANN buffer.
779
+ *
780
+ * This function initializes a tensor using data from a CANN buffer.
781
+ * It handles special cases such as views and quantization.
782
+ *
783
+ * @param buffer The CANN buffer from which to initialize the tensor.
784
+ * @param tensor Pointer to the tensor to be initialized.
785
+ */
786
+ static void ggml_backend_cann_buffer_init_tensor(
787
+ ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
788
+ if (tensor->view_src != NULL && tensor->view_offs == 0) {
789
+ GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
790
+ return;
791
+ }
792
+
793
+ // TODO: the CANN backend does not support quantized tensors yet; keep this
794
+ // code here for now.
795
+ if (ggml_is_quantized(tensor->type)) {
796
+ // Initialize padding to 0 to avoid possible NaN values
797
+ size_t original_size = ggml_nbytes(tensor);
798
+ size_t padded_size =
799
+ ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
800
+
801
+ if (padded_size > original_size && tensor->view_src == nullptr) {
802
+ size_t memset_size = padded_size - original_size;
803
+ ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
804
+ memset_size, 0, memset_size));
805
+ }
806
+ }
807
+ }
808
+
809
+ // TODO: need to handle tensors that have padding.
810
+ /**
811
+ * @brief Set tensor data in a CANN buffer.
812
+ *
813
+ * This function sets tensor data in a CANN buffer, handling transformations
814
+ * if needed based on the tensor's type.
815
+ *
816
+ * @param buffer The CANN buffer where the tensor data will be set.
817
+ * @param tensor Pointer to the tensor whose data will be set.
818
+ * @param data Pointer to the source data to be copied into the tensor.
819
+ * @param offset Offset in the source data from where to start copying.
820
+ * @param size Size of the data to be copied, in bytes.
821
+ */
822
+ static void ggml_backend_cann_buffer_set_tensor(
823
+ ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
824
+ size_t offset, size_t size) {
825
+ ggml_backend_cann_buffer_context *ctx =
826
+ (ggml_backend_cann_buffer_context *)buffer->context;
827
+
828
+ ggml_cann_set_device(ctx->device);
829
+ // TODO: refer to CANN (#6017); it uses the thread's default stream.
830
+ // For acl, synchronous functions use this default stream.
831
+ // Why aclrtSynchronizeDevice?
832
+
833
+ if (!need_transform(tensor->type)) {
834
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
835
+ ACL_MEMCPY_HOST_TO_DEVICE));
836
+ } else {
837
+ void *transform_buffer = malloc(size);
838
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
839
+
840
+ ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
841
+ transform_buffer, size,
842
+ ACL_MEMCPY_HOST_TO_DEVICE));
843
+ free(transform_buffer);
844
+ }
845
+ }
846
+
847
+ /**
848
+ * @brief Get tensor data from a CANN buffer.
849
+ *
850
+ * This function retrieves tensor data from a CANN buffer, handling
851
+ * transformations if needed based on the tensor's type.
852
+ *
853
+ * @param buffer The CANN buffer from which to retrieve tensor data.
854
+ * @param tensor Pointer to the tensor whose data will be retrieved.
855
+ * @param data Pointer to the destination buffer where the tensor data will be
856
+ * copied.
857
+ * @param offset Offset in the destination buffer where to start copying.
858
+ * @param size Size of the data to be copied, in bytes.
859
+ */
860
+ static void ggml_backend_cann_buffer_get_tensor(
861
+ ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
862
+ size_t offset, size_t size) {
863
+ ggml_backend_cann_buffer_context* ctx =
864
+ (ggml_backend_cann_buffer_context*)buffer->context;
865
+
866
+ ggml_cann_set_device(ctx->device);
867
+
868
+ if (!need_transform(tensor->type)) {
869
+ ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
870
+ ACL_MEMCPY_DEVICE_TO_HOST));
871
+ } else {
872
+ void* transform_buffer = malloc(size);
873
+ ACL_CHECK(aclrtMemcpy(transform_buffer, size,
874
+ (char*)tensor->data + offset, size,
875
+ ACL_MEMCPY_DEVICE_TO_HOST));
876
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
877
+ free(transform_buffer);
878
+ }
879
+ }
880
+
881
+ /**
882
+ * @brief Copy tensor data between CANN buffers if possible.
883
+ *
884
+ * This function copies tensor data between CANN buffers if the source and
885
+ * destination buffers are CANN buffers and they meet the necessary conditions
886
+ * (same device or devices can access each other).
887
+ *
888
+ * @param buffer The destination CANN buffer where the tensor data will be
889
+ * copied.
890
+ * @param src Pointer to the source tensor whose data will be copied.
891
+ * @param dst Pointer to the destination tensor where the data will be copied.
892
+ * @return true if the copy operation succeeded, false otherwise.
893
+ */
894
+ static bool ggml_backend_cann_buffer_cpy_tensor(
895
+ ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
896
+ if (ggml_backend_buffer_is_cann(src->buffer)) {
897
+ ggml_backend_cann_buffer_context* src_ctx =
898
+ (ggml_backend_cann_buffer_context*)src->buffer->context;
899
+ ggml_backend_cann_buffer_context* dst_ctx =
900
+ (ggml_backend_cann_buffer_context*)buffer->context;
901
+
902
+ size_t memcpy_size = ggml_nbytes(src);
903
+ // Same device.
904
+ if (src_ctx->device == dst_ctx->device) {
905
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
906
+ (const char*)src->data, memcpy_size,
907
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
908
+ return true;
909
+ } else {
910
+ // Different device but can access by peer.
911
+ int32_t canAccessPeer = 0;
912
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
913
+ dst_ctx->device));
914
+ if (canAccessPeer) {
915
+ ggml_cann_set_device(src_ctx->device);
916
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
917
+ ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
918
+ (const char*)src->data, memcpy_size,
919
+ ACL_MEMCPY_DEVICE_TO_DEVICE));
920
+ return true;
921
+ }
922
+ }
923
+ }
924
+ return false;
925
+ }
926
+
927
+ /**
928
+ * @brief Clear a CANN buffer by setting all its memory to a specified value.
929
+ *
930
+ * This function clears a CANN buffer by setting all its memory to a specified
931
+ * value.
932
+ *
933
+ * @param buffer The CANN buffer to be cleared.
934
+ * @param value The value to which each byte in the buffer will be set.
935
+ */
936
+ static void ggml_backend_cann_buffer_clear(
937
+ ggml_backend_buffer_t buffer, uint8_t value) {
938
+ ggml_backend_cann_buffer_context* ctx =
939
+ (ggml_backend_cann_buffer_context*)buffer->context;
940
+
941
+ ggml_cann_set_device(ctx->device);
942
+ ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
943
+ }
944
+
945
+ /**
946
+ * @brief Interface for a CANN buffer in the backend.
947
+ *
948
+ * This structure defines function pointers to operations that can be performed
949
+ * on a CANN buffer within the backend.
950
+ */
951
+ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
952
+ /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer,
953
+ /* .get_base = */ ggml_backend_cann_buffer_get_base,
954
+ /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor,
955
+ /* .memset_tensor = */ NULL,
956
+ /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor,
957
+ /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor,
958
+ /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor,
959
+ /* .clear = */ ggml_backend_cann_buffer_clear,
960
+ /* .reset = */ NULL,
961
+ };
962
+
963
+ // cann buffer type
964
+ /**
965
+ * @brief Structure representing context information for a specific backend
966
+ * buffer type.
967
+ */
968
+ struct ggml_backend_cann_buffer_type_context {
969
+ int32_t
970
+ device; /**< Device identifier associated with the buffer context. */
971
+ std::string name; /**< Name associated with the buffer context. */
972
+ };
973
+
974
+ /**
975
+ * @brief Retrieves the name associated with a CANN buffer type.
976
+ *
977
+ * This function returns the descriptive name associated with the specified
978
+ * CANN buffer type context.
979
+ *
980
+ * @param buft Pointer to the buffer type context.
981
+ * @return Const pointer to the C-style string containing the name.
982
+ */
983
+ static const char* ggml_backend_cann_buffer_type_name(
984
+ ggml_backend_buffer_type_t buft) {
985
+ ggml_backend_cann_buffer_type_context* buft_ctx =
986
+ (ggml_backend_cann_buffer_type_context*)buft->context;
987
+
988
+ return buft_ctx->name.c_str();
989
+ }
990
+
991
+ /**
992
+ * @brief Allocates a new CANN buffer of the specified type and size.
993
+ *
994
+ * This function allocates a new CANN buffer on the specified device with the
995
+ * given size.
996
+ *
997
+ * @param buft Pointer to the buffer type context.
998
+ * @param size Size in bytes of the buffer to allocate.
999
+ * @return Pointer to the allocated buffer, or nullptr if allocation fails.
1000
+ */
1001
+ static ggml_backend_buffer_t
1002
+ ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
1003
+ size_t size) {
1004
+ ggml_backend_cann_buffer_type_context* buft_ctx =
1005
+ (ggml_backend_cann_buffer_type_context*)buft->context;
1006
+
1007
+ ggml_cann_set_device(buft_ctx->device);
1008
+
1009
+ size = std::max(size, (size_t)1);
1010
+
1011
+ void* dev_ptr;
1012
+ aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
1013
+ if (err != ACL_SUCCESS) {
1014
+ GGML_LOG_ERROR(
1015
+ "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
1016
+ __func__, size / 1024.0 / 1024.0, buft_ctx->device,
1017
+ aclGetRecentErrMsg());
1018
+ return nullptr;
1019
+ }
1020
+
1021
+ ggml_backend_cann_buffer_context* ctx =
1022
+ new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);
1023
+
1024
+ return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
1025
+ ctx, size);
1026
+ }
1027
+
1028
+ /**
1029
+ * @brief Retrieves the memory alignment requirement for CANN buffers of this
1030
+ * type.
1031
+ *
1032
+ * This function returns the alignment requirement in bytes for memory allocated
1033
+ * by the CANN buffer type.
1034
+ *
1035
+ * @param buft Pointer to the buffer type context (unused in this
1036
+ * implementation).
1037
+ * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
1038
+ * buffers).
1039
+ */
1040
+ static size_t ggml_backend_cann_buffer_type_get_alignment(
1041
+ ggml_backend_buffer_type_t buft) {
1042
+ return 128;
1043
+
1044
+ GGML_UNUSED(buft);
1045
+ }
1046
+
1047
+ /**
1048
+ * @brief Calculates the allocation size required for a tensor in a CANN buffer.
1049
+ *
1050
+ * Computes the total allocation size needed for storing the tensor's data in a
1051
+ * CANN buffer, considering any necessary padding or adjustments for quantized
1052
+ * types.
1053
+ *
1054
+ * @param buft Pointer to the buffer type context (unused in this
1055
+ * implementation).
1056
+ * @param tensor Pointer to the tensor for which the allocation size is
1057
+ * calculated.
1058
+ * @return The total allocation size in bytes required for the tensor in the
1059
+ * CANN buffer.
1060
+ */
1061
+ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
1062
+ ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
1063
+ size_t size = ggml_nbytes(tensor);
1064
+ int64_t ne0 = tensor->ne[0];
1065
+
1066
+ // last line must be bigger than 32, because every single op deals with at
1067
+ // least 32 bytes.
1068
+ // TODO: quantized type?
1069
+ // int64_t line_size = ne0 * ggml_element_size(tensor);
1070
+ // int64_t line_size_align_32 = (line_size + 31) & ~31;
1071
+ // size += (line_size_align_32 - line_size);
1072
+
1073
+ // TODO: quantized types are not supported yet.
1074
+ // TODO: consider non-contiguous tensors.
1075
+ if (ggml_is_quantized(tensor->type)) {
1076
+ if (ne0 % MATRIX_ROW_PADDING != 0) {
1077
+ size += ggml_row_size(
1078
+ tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
1079
+ }
1080
+ }
1081
+
1082
+ return size;
1083
+
1084
+ GGML_UNUSED(buft);
1085
+ }
1086
+
1087
+ static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
1088
+ return false;
1089
+
1090
+ GGML_UNUSED(buft);
1091
+ }
1092
+
1093
+ /**
1094
+ * @brief Interface for managing CANN buffer types in the GGML backend.
1095
+ *
1096
+ * Provides function pointers for allocating, querying properties, and managing
1097
+ * memory for CANN buffer types in the GGML backend.
1098
+ */
1099
+ static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
1100
+ /* .get_name = */ ggml_backend_cann_buffer_type_name,
1101
+ /* .alloc_buffer = */ ggml_backend_cann_buffer_type_alloc_buffer,
1102
+ /* .get_alignment = */ ggml_backend_cann_buffer_type_get_alignment,
1103
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1104
+ /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
1105
+ /* .is_host = */ ggml_backend_cann_buffer_type_is_host,
1106
+ };
1107
+
1108
+ /**
1109
+ * @brief Retrieves the CANN buffer type for a specified device.
1110
+ *
1111
+ * This function initializes and returns the buffer type interface associated
1112
+ * with the given device. It ensures thread-safe access using a mutex.
1113
+ *
1114
+ * @param device The device index for which to retrieve the buffer type.
1115
+ * @return A pointer to the buffer type interface for the specified device, or
1116
+ * nullptr if the device index is out of range.
1117
+ */
1118
+ ggml_backend_buffer_type_t
1119
+ ggml_backend_cann_buffer_type(int32_t device) {
1120
+ static std::mutex mutex;
1121
+ std::lock_guard<std::mutex> lock(mutex);
1122
+
1123
+ if (device >= ggml_backend_cann_get_device_count()) {
1124
+ return nullptr;
1125
+ }
1126
+
1127
+ static ggml_backend_buffer_type
1128
+ ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];
1129
+
1130
+ static bool ggml_backend_cann_buffer_type_initialized = false;
1131
+
1132
+ if (!ggml_backend_cann_buffer_type_initialized) {
1133
+ for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
1134
+ ggml_backend_cann_buffer_types[i] = {
1135
+ /* .iface = */ ggml_backend_cann_buffer_type_interface,
1136
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
1137
+ /* .context = */
1138
+ new ggml_backend_cann_buffer_type_context{
1139
+ i, "CANN" + std::to_string(i)},
1140
+ };
1141
+ }
1142
+ ggml_backend_cann_buffer_type_initialized = true;
1143
+ }
1144
+
1145
+ return &ggml_backend_cann_buffer_types[device];
1146
+ }
1147
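A hedged usage sketch for the buffer-type lookup above, going through the generic ggml-backend API; the device index and buffer size are illustrative, not part of the patch:

    #include "ggml-backend.h"
    #include "ggml-cann.h"

    int main() {
        // Look up the buffer type for CANN device 0 and allocate a small buffer
        // through the generic ggml-backend API (error handling kept minimal).
        ggml_backend_buffer_type_t buft = ggml_backend_cann_buffer_type(0);
        if (buft == nullptr) {
            return 1;  // no CANN device available
        }
        ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16 * 1024 * 1024);
        if (buf == nullptr) {
            return 1;
        }
        // ... place tensors in the buffer and fill them with ggml_backend_tensor_set ...
        ggml_backend_buffer_free(buf);
        return 0;
    }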
+
1148
+ /**
1149
+ * @brief Retrieves the name associated with a CANN host buffer type.
1150
+ *
1151
+ * This function returns the descriptive name associated with the specified
1152
+ * CANN host buffer type context.
1153
+ *
1154
+ * @param buft Pointer to the host buffer type context.
1155
+ * @return Const pointer to the C-style string containing the name.
1156
+ */
1157
+ static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
1158
+ return "CANN_Host";
1159
+
1160
+ GGML_UNUSED(buft);
1161
+ }
1162
+
1163
+ /**
1164
+ * @brief Retrieves the name associated with a CANN host buffer.
1165
+ *
1166
+ * This function returns the descriptive name associated with the specified
1167
+ * CANN host buffer context.
1168
+ *
1169
+ * @param buft Pointer to the host buffer context.
1170
+ * @return Const pointer to the C-style string containing the name.
1171
+ */
1172
+ static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
1173
+ return "CANN_Host";
1174
+
1175
+ GGML_UNUSED(buffer);
1176
+ }
1177
+
1178
+ /**
1179
+ * @brief Free resources associated with a CANN host buffer.
1180
+ *
1181
+ * This function frees the resources associated with a CANN host buffer, including
1182
+ * its context.
1183
+ *
1184
+ * @param buffer The CANN host buffer to free.
1185
+ */
1186
+ static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
1187
+ ACL_CHECK(aclrtFreeHost(buffer->context));
1188
+ }
1189
+
1190
+ /**
1191
+ * @brief Allocates a new CANN host buffer of the specified size.
1192
+ *
1193
+ * This function allocates a new CANN host buffer with the given size.
1194
+ * @param size Size in bytes of the host buffer to allocate.
1195
+ * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
1196
+ */
1197
+ static void * ggml_cann_host_malloc(size_t size) {
1198
+ if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
1199
+ return nullptr;
1200
+ }
1201
+
1202
+ void * hostPtr = nullptr;
1203
+ aclError err = aclrtMallocHost((void **) &hostPtr, size);
1204
+ if (err != ACL_SUCCESS) {
1205
+
1206
+ GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
1207
+ size / 1024.0 / 1024.0, aclGetRecentErrMsg());
1208
+ return nullptr;
1209
+ }
1210
+ return hostPtr;
1211
+ }
1212
+
1213
+ /**
1214
+ * @brief Allocates a new CANN host buffer of the specified type and size.
1215
+ *
1216
+ * @param buft Pointer to the host buffer type context.
1217
+ * @param size Size in bytes of the host buffer to allocate.
1218
+ * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
1219
+ */
1220
+ static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
1221
+ void * hostPtr = ggml_cann_host_malloc(size);
1222
+
1223
+ if (hostPtr == nullptr) {
1224
+ // fallback to cpu buffer
1225
+ return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
1226
+ }
1227
+
1228
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
1229
+ buffer->buft = buft;
1230
+ buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;
1231
+
1232
+ return buffer;
1233
+ }
1234
+
1235
+ /**
1236
+ * @brief Retrieves the CANN host (pinned) buffer type.
1237
+ *
1238
+ * The returned buffer type provides function pointers for allocating, querying
1239
+ * properties, and managing pinned host memory for the CANN backend.
1240
+ */
1241
+ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
1242
+ static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
1243
+ /* .iface = */ {
1244
+ /* .get_name = */ ggml_backend_cann_host_buffer_type_name,
1245
+ /* .alloc_buffer = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
1246
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
1247
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
1248
+ /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
1249
+ /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
1250
+ },
1251
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
1252
+ /* .context = */ nullptr,
1253
+ };
1254
+
1255
+ return &ggml_backend_cann_buffer_type_host;
1256
+ }
1257
+
1258
+ /**
1259
+ * @brief Computes the forward operation for a given tensor using CANN
1260
+ * operations.
1261
+ *
1262
+ * This function selects the appropriate CANN operation based on the type of
1263
+ * operation specified in the tensor and performs the computation.
1264
+ *
1265
+ * @param ctx The CANN context containing necessary resources and
1266
+ * configurations.
1267
+ * @param dst The destination tensor where the result of the computation will be
1268
+ * stored.
1269
+ * @return true if the computation was successful; false otherwise.
1270
+ */
1271
+ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
1272
+ struct ggml_tensor* dst) {
1273
+ switch (dst->op) {
1274
+ case GGML_OP_REPEAT:
1275
+ ggml_cann_repeat(ctx, dst);
1276
+ break;
1277
+ case GGML_OP_GET_ROWS:
1278
+ ggml_cann_get_rows(ctx, dst);
1279
+ break;
1280
+ case GGML_OP_DUP:
1281
+ ggml_cann_dup(ctx, dst);
1282
+ break;
1283
+ case GGML_OP_ADD:
1284
+ ggml_cann_add(ctx, dst);
1285
+ break;
1286
+ case GGML_OP_ACC:
1287
+ ggml_cann_acc(ctx, dst);
1288
+ break;
1289
+ case GGML_OP_MUL:
1290
+ ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
1291
+ break;
1292
+ case GGML_OP_DIV:
1293
+ ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
1294
+ break;
1295
+ case GGML_OP_UNARY:
1296
+ switch (ggml_get_unary_op(dst)) {
1297
+ case GGML_UNARY_OP_GELU:
1298
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1299
+ ctx, dst);
1300
+ break;
1301
+ case GGML_UNARY_OP_SILU:
1302
+ ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
1303
+ ctx, dst);
1304
+ break;
1305
+ // TODO: Use faster gelu??
1306
+ case GGML_UNARY_OP_GELU_QUICK:
1307
+ ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
1308
+ ctx, dst);
1309
+ break;
1310
+ case GGML_UNARY_OP_TANH:
1311
+ ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
1312
+ ctx, dst);
1313
+ break;
1314
+ case GGML_UNARY_OP_RELU:
1315
+ ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
1316
+ ctx, dst);
1317
+ break;
1318
+ case GGML_UNARY_OP_HARDSIGMOID:
1319
+ ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
1320
+ aclnnHardsigmoid>(ctx, dst);
1321
+ break;
1322
+ case GGML_UNARY_OP_HARDSWISH:
1323
+ ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
1324
+ aclnnHardswish>(ctx, dst);
1325
+ break;
1326
+ default:
1327
+ return false;
1328
+ }
1329
+ break;
1330
+ case GGML_OP_NORM:
1331
+ ggml_cann_norm(ctx, dst);
1332
+ break;
1333
+ case GGML_OP_GROUP_NORM:
1334
+ ggml_cann_group_norm(ctx, dst);
1335
+ break;
1336
+ case GGML_OP_CONCAT:
1337
+ ggml_cann_concat(ctx, dst);
1338
+ break;
1339
+ case GGML_OP_UPSCALE:
1340
+ ggml_cann_upsample_nearest2d(ctx, dst);
1341
+ break;
1342
+ case GGML_OP_PAD:
1343
+ ggml_cann_pad(ctx, dst);
1344
+ break;
1345
+ case GGML_OP_ARANGE:
1346
+ ggml_cann_arange(ctx, dst);
1347
+ break;
1348
+ case GGML_OP_TIMESTEP_EMBEDDING:
1349
+ ggml_cann_timestep_embedding(ctx, dst);
1350
+ break;
1351
+ case GGML_OP_LEAKY_RELU:
1352
+ ggml_cann_leaky_relu(ctx, dst);
1353
+ break;
1354
+ case GGML_OP_RMS_NORM:
1355
+ ggml_cann_rms_norm(ctx, dst);
1356
+ break;
1357
+ case GGML_OP_MUL_MAT:
1358
+ ggml_cann_mul_mat(ctx, dst);
1359
+ break;
1360
+ case GGML_OP_MUL_MAT_ID:
1361
+ return false;
1362
+ case GGML_OP_SCALE:
1363
+ ggml_cann_scale(ctx, dst);
1364
+ break;
1365
+ case GGML_OP_SQR:
1366
+ ggml_cann_sqr(ctx, dst);
1367
+ break;
1368
+ case GGML_OP_CLAMP:
1369
+ ggml_cann_clamp(ctx, dst);
1370
+ break;
1371
+ case GGML_OP_CPY:
1372
+ ggml_cann_cpy(ctx, dst);
1373
+ break;
1374
+ case GGML_OP_CONT:
1375
+ ggml_cann_dup(ctx, dst);
1376
+ break;
1377
+ case GGML_OP_NONE:
1378
+ case GGML_OP_RESHAPE:
1379
+ case GGML_OP_VIEW:
1380
+ case GGML_OP_PERMUTE:
1381
+ case GGML_OP_TRANSPOSE:
1382
+ break;
1383
+ case GGML_OP_DIAG_MASK_INF:
1384
+ ggml_cann_diag_mask(ctx, dst, -INFINITY);
1385
+ break;
1386
+ case GGML_OP_SOFT_MAX:
1387
+ ggml_cann_softmax(ctx, dst);
1388
+ break;
1389
+ case GGML_OP_ROPE:
1390
+ ggml_cann_rope(ctx, dst);
1391
+ break;
1392
+ case GGML_OP_IM2COL:
1393
+ ggml_cann_im2col(ctx, dst);
1394
+ break;
1395
+ case GGML_OP_POOL_2D:
1396
+ ggml_cann_pool2d(ctx, dst);
1397
+ break;
1398
+ case GGML_OP_SUM_ROWS:
1399
+ ggml_cann_sum_rows(ctx, dst);
1400
+ break;
1401
+ case GGML_OP_ARGSORT:
1402
+ ggml_cann_argsort(ctx, dst);
1403
+ break;
1404
+ default:
1405
+ return false;
1406
+ }
1407
+
1408
+ return true;
1409
+ }
1410
+
1411
+ // backend
1412
+ /**
1413
+ * @brief Retrieves the name associated with the CANN backend.
1414
+ *
1415
+ * This function returns the name assigned to the CANN backend, which is stored
1416
+ * in the context of the provided backend structure.
1417
+ *
1418
+ * @param backend Pointer to the CANN backend structure.
1419
+ * @return A pointer to a constant string representing the backend name.
1420
+ */
1421
+ static const char* ggml_backend_cann_name(ggml_backend_t backend) {
1422
+ ggml_backend_cann_context* cann_ctx =
1423
+ (ggml_backend_cann_context*)backend->context;
1424
+
1425
+ return cann_ctx->name.c_str();
1426
+ }
1427
+
1428
+ /**
1429
+ * @brief Frees resources associated with the CANN backend.
1430
+ *
1431
+ * This function releases resources associated with the CANN backend context
1432
+ * and resets the device associated with the backend to its initial state.
1433
+ *
1434
+ * @param backend Pointer to the CANN backend structure to be freed.
1435
+ */
1436
+ static void ggml_backend_cann_free(ggml_backend_t backend) {
1437
+ ggml_backend_cann_context* cann_ctx =
1438
+ (ggml_backend_cann_context*)backend->context;
1439
+ ACL_CHECK(aclrtSynchronizeDevice());
1440
+ ACL_CHECK(aclrtResetDevice(cann_ctx->device));
1441
+
1442
+ // finalize when the last backend is freed.
1443
+ if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
1444
+ ACL_CHECK(aclFinalize());
1445
+ }
1446
+
1447
+ delete cann_ctx;
1448
+ delete backend;
1449
+ }
1450
+
1451
+ /**
1452
+ * @brief Sets tensor data asynchronously in the CANN backend.
1453
+ *
1454
+ * This function asynchronously sets tensor data in the CANN backend. Depending
1455
+ * on the tensor type, it may perform data transformations before copying data
1456
+ * to the device.
1457
+ *
1458
+ * @param backend Pointer to the CANN backend structure.
1459
+ * @param tensor Pointer to the tensor structure to set data for.
1460
+ * @param data Pointer to the host data to copy to the tensor.
1461
+ * @param offset Offset in bytes within the host data.
1462
+ * @param size Size of the data to copy in bytes.
1463
+ */
1464
+ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
1465
+ ggml_tensor *tensor,
1466
+ const void *data,
1467
+ size_t offset,
1468
+ size_t size) {
1469
+ ggml_backend_cann_context *cann_ctx =
1470
+ (ggml_backend_cann_context *)backend->context;
1471
+
1472
+ if (!need_transform(tensor->type)) {
1473
+ ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
1474
+ size, ACL_MEMCPY_HOST_TO_DEVICE,
1475
+ cann_ctx->stream()));
1476
+ } else {
1477
+ void *transform_buffer = malloc(size);
1478
+ ggml_backend_cann_transform(tensor, data, transform_buffer);
1479
+
1480
+ ACL_CHECK(aclrtMemcpyAsync(
1481
+ (char *)tensor->data + offset, size, transform_buffer, size,
1482
+ ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
1483
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1484
+ free(transform_buffer);
1485
+ }
1486
+ }
1487
+
1488
+ static void ggml_backend_cann_get_tensor_async(
1489
+ ggml_backend_t backend, const ggml_tensor *tensor, void *data,
1490
+ size_t offset, size_t size) {
1491
+ ggml_backend_cann_context *cann_ctx =
1492
+ (ggml_backend_cann_context *)backend->context;
1493
+ ggml_backend_buffer_t buf =
1494
+ tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
1495
+
1496
+ GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
1497
+ "unsupported buffer type");
1498
+
1499
+ if (!need_transform(tensor->type)) {
1500
+ ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
1501
+ size, ACL_MEMCPY_DEVICE_TO_HOST,
1502
+ cann_ctx->stream()));
1503
+ } else {
1504
+ void *transform_buffer = malloc(size);
1505
+ ACL_CHECK(aclrtMemcpyAsync(
1506
+ transform_buffer, size, (char *)tensor->data + offset, size,
1507
+ ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
1508
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1509
+ ggml_backend_cann_transform_back(tensor, transform_buffer, data);
1510
+ free(transform_buffer);
1511
+ }
1512
+ }
1513
+
1514
+ /**
1515
+ * @brief Asynchronously copies tensor data between CANN backends.
1516
+ *
1517
+ * This function copies tensor data asynchronously between two CANN backends. It
1518
+ * checks if both tensors reside in CANN buffers and whether the devices support
1519
+ * peer-to-peer access for direct copying. If not, it returns false.
1520
+ *
1521
+ * @param backend_src Pointer to the source CANN backend structure.
1522
+ * @param backend_dst Pointer to the destination CANN backend structure.
1523
+ * @param src Pointer to the source tensor to copy data from.
1524
+ * @param dst Pointer to the destination tensor to copy data to.
1525
+ * @return true if the copy operation succeeds, false otherwise.
1526
+ */
1527
+ static bool ggml_backend_cann_cpy_tensor_async(
1528
+ ggml_backend_t backend_src, ggml_backend_t backend_dst,
1529
+ const ggml_tensor* src, ggml_tensor* dst) {
1530
+ GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
1531
+ ggml_backend_is_cann(backend_dst));
1532
+
1533
+ if (!ggml_backend_buffer_is_cann(src->buffer) ||
1534
+ !ggml_backend_buffer_is_cann(dst->buffer)) {
1535
+ return false;
1536
+ }
1537
+
1538
+ ggml_backend_buffer_t buf_src =
1539
+ src->view_src ? src->view_src->buffer : src->buffer;
1540
+ ggml_backend_buffer_t buf_dst =
1541
+ dst->view_src ? dst->view_src->buffer : dst->buffer;
1542
+
1543
+ ggml_backend_cann_context* cann_ctx_src =
1544
+ (ggml_backend_cann_context*)backend_src->context;
1545
+ ggml_backend_cann_context* cann_ctx_dst =
1546
+ (ggml_backend_cann_context*)backend_dst->context;
1547
+
1548
+ size_t copy_size = ggml_nbytes(dst);
1549
+ if (backend_src != backend_dst) {
1550
+ ggml_backend_cann_buffer_context* buf_ctx_src =
1551
+ (ggml_backend_cann_buffer_context*)buf_src->context;
1552
+ ggml_backend_cann_buffer_context* buf_ctx_dst =
1553
+ (ggml_backend_cann_buffer_context*)buf_dst->context;
1554
+
1555
+ GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
1556
+ GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);
1557
+
1558
+ int32_t canAccessPeer = 0;
1559
+ ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
1560
+ cann_ctx_dst->device));
1561
+ if (!canAccessPeer) {
1562
+ return false;
1563
+ }
1564
+
1565
+ // need to open both directions for aclrtMemcpyAsync between devices.
1566
+ ggml_cann_set_device(cann_ctx_dst->device);
1567
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
1568
+ ggml_cann_set_device(cann_ctx_src->device);
1569
+ ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
1570
+
1571
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1572
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1573
+ cann_ctx_src->stream()));
1574
+
1575
+ // TODO: workaround because events did not work here.
1576
+ aclrtSynchronizeStream(cann_ctx_src->stream());
1577
+ } else {
1578
+ // src and dst are on the same backend
1579
+ ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
1580
+ ACL_MEMCPY_DEVICE_TO_DEVICE,
1581
+ cann_ctx_dst->stream()));
1582
+ }
1583
+
1584
+ return true;
1585
+ }
1586
+
1587
+ /**
1588
+ * @brief Synchronizes a CANN backend.
1589
+ *
1590
+ * This function synchronizes the specified CANN backend by waiting for all
1591
+ * operations in its associated stream to complete.
1592
+ *
1593
+ * @param backend Pointer to the CANN backend structure to synchronize.
1594
+ */
1595
+ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
1596
+ ggml_backend_cann_context* cann_ctx =
1597
+ (ggml_backend_cann_context*)backend->context;
1598
+
1599
+ ggml_cann_set_device(cann_ctx->device);
1600
+
1601
+ ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
1602
+ }
1603
+
1604
+ /**
1605
+ * @brief Computes a computational graph using a CANN backend.
1606
+ *
1607
+ * This function computes the operations defined in the computational graph
1608
+ * using the specified CANN backend.
1609
+ *
1610
+ * @param backend Pointer to the CANN backend structure to use for computation.
1611
+ * @param cgraph Pointer to the computational graph structure containing nodes
1612
+ * representing operations to be computed.
1613
+ * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
1614
+ * completes successfully, otherwise an appropriate error status.
1615
+ */
1616
+ static enum ggml_status ggml_backend_cann_graph_compute(
1617
+ ggml_backend_t backend, ggml_cgraph* cgraph) {
1618
+ ggml_backend_cann_context* cann_ctx =
1619
+ (ggml_backend_cann_context*)backend->context;
1620
+
1621
+ ggml_cann_set_device(cann_ctx->device);
1622
+
1623
+ for (int i = 0; i < cgraph->n_nodes; i++) {
1624
+ ggml_tensor* node = cgraph->nodes[i];
1625
+
1626
+ if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
1627
+ continue;
1628
+ }
1629
+
1630
+ bool ok = ggml_cann_compute_forward(*cann_ctx, node);
1631
+
1632
+ if (!ok) {
1633
+ GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
1634
+ node->name, ggml_op_name(node->op));
1635
+ }
1636
+ GGML_ASSERT(ok);
1637
+ }
1638
+
1639
+ return GGML_STATUS_SUCCESS;
1640
+ }
1641
+
1642
+ /**
1643
+ * @brief Checks if the CANN backend supports a specific operation.
1644
+ *
1645
+ * This function checks whether the specified operation is supported by the
1646
+ * CANN backend.
1647
+ *
1648
+ * @param backend Pointer to the CANN backend structure to check support for
1649
+ * the operation.
1650
+ * @param op Pointer to the tensor representing the operation to check.
1651
+ * @return bool Returns true if the operation is supported by the backend,
1652
+ * otherwise false.
1653
+ */
1654
+ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
1655
+ const ggml_tensor* op) {
1656
+ switch (op->op) {
1657
+ case GGML_OP_UNARY:
1658
+ switch (ggml_get_unary_op(op)) {
1659
+ case GGML_UNARY_OP_GELU:
1660
+ case GGML_UNARY_OP_SILU:
1661
+ case GGML_UNARY_OP_RELU:
1662
+ case GGML_UNARY_OP_HARDSIGMOID:
1663
+ case GGML_UNARY_OP_HARDSWISH:
1664
+ case GGML_UNARY_OP_GELU_QUICK:
1665
+ case GGML_UNARY_OP_TANH:
1666
+ return true;
1667
+ default:
1668
+ return false;
1669
+ }
1670
+ case GGML_OP_MUL_MAT: {
1671
+ switch (op->src[0]->type) {
1672
+ case GGML_TYPE_F16:
1673
+ case GGML_TYPE_F32:
1674
+ case GGML_TYPE_Q8_0:
1675
+ // TODO: fix me
1676
+ // Current groupsize should not be greater than k-1 in
1677
+ // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
1678
+ case GGML_TYPE_Q4_0:
1679
+ return true;
1680
+ default:
1681
+ return false;
1682
+ }
1683
+ }
1684
+ case GGML_OP_MUL_MAT_ID:
1685
+ return false;
1686
+ // embedding
1687
+ case GGML_OP_GET_ROWS: {
1688
+ switch (op->src[0]->type) {
1689
+ case GGML_TYPE_F32:
1690
+ case GGML_TYPE_F16:
1691
+ case GGML_TYPE_Q4_0:
1692
+ case GGML_TYPE_Q8_0:
1693
+ return true;
1694
+ default:
1695
+ return false;
1696
+ }
1697
+ } break;
1698
+ case GGML_OP_CPY: {
1699
+ switch (op->type) {
1700
+ case GGML_TYPE_F32:
1701
+ case GGML_TYPE_F16:
1702
+ case GGML_TYPE_Q8_0:
1703
+ case GGML_TYPE_Q4_0:
1704
+ return true;
1705
+ default:
1706
+ return false;
1707
+ }
1708
+ }
1709
+ case GGML_OP_DUP:
1710
+ case GGML_OP_REPEAT:
1711
+ case GGML_OP_CONCAT:
1712
+ case GGML_OP_NONE:
1713
+ case GGML_OP_RESHAPE:
1714
+ case GGML_OP_VIEW:
1715
+ case GGML_OP_PERMUTE:
1716
+ case GGML_OP_TRANSPOSE:
1717
+ case GGML_OP_NORM:
1718
+ case GGML_OP_ADD:
1719
+ case GGML_OP_MUL:
1720
+ case GGML_OP_DIV:
1721
+ case GGML_OP_RMS_NORM:
1722
+ case GGML_OP_SCALE:
1723
+ case GGML_OP_SQR:
1724
+ case GGML_OP_CLAMP:
1725
+ case GGML_OP_CONT:
1726
+ case GGML_OP_DIAG_MASK_INF:
1727
+ case GGML_OP_SOFT_MAX:
1728
+ case GGML_OP_ROPE:
1729
+ case GGML_OP_IM2COL:
1730
+ case GGML_OP_POOL_2D:
1731
+ case GGML_OP_SUM_ROWS:
1732
+ case GGML_OP_ARGSORT:
1733
+ case GGML_OP_ACC:
1734
+ case GGML_OP_GROUP_NORM:
1735
+ case GGML_OP_UPSCALE:
1736
+ case GGML_OP_PAD:
1737
+ case GGML_OP_ARANGE:
1738
+ case GGML_OP_TIMESTEP_EMBEDDING:
1739
+ case GGML_OP_LEAKY_RELU:
1740
+ return true;
1741
+ default:
1742
+ return false;
1743
+ }
1744
+
1745
+ GGML_UNUSED(dev);
1746
+ }
1747
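Callers outside the backend are expected to reach the check above through the public device API rather than calling it directly; a minimal sketch, assuming a device handle obtained from the registry:

    #include "ggml-backend.h"

    // Query op support through the generic device API, which dispatches to the
    // backend-specific supports_op implementation above.
    static bool can_run_on_cann(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
        return ggml_backend_dev_supports_op(dev, op);
    }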
+
1748
+ /**
1749
+ * @brief Checks if the backend buffer type is associated with the CANN backend.
1750
+ *
1751
+ * This function checks whether the provided backend buffer type is associated
1752
+ * with the CANN backend based on the comparison of its name retrieval function
1753
+ * pointer.
1754
+ *
1755
+ * @param buft Pointer to the backend buffer type to check.
1756
+ * @return bool Returns true if the buffer type is associated with the CANN
1757
+ * backend, otherwise false.
1758
+ */
1759
+ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
1760
+ return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
1761
+ }
1762
+
1763
+ /**
1764
+ * @brief Determines if a tensor operation should be offloaded to the CANN
1765
+ * backend.
1766
+ *
1767
+ * This function checks if a given tensor operation should be offloaded to the
1768
+ * CANN backend based on the operation type and the size of the tensor. It
1769
+ * returns true if the second dimension (ne[1]) of the tensor is greater than or
1770
+ * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
1771
+ *
1772
+ * @param backend Pointer to the CANN backend.
1773
+ * @param op Pointer to the tensor operation to check.
1774
+ * @return bool Returns true if the operation should be offloaded, otherwise
1775
+ * false.
1776
+ */
1777
+ static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
1778
+ const ggml_tensor* op) {
1779
+ const int min_batch_size = 32;
1780
+ GGML_UNUSED(dev);
1781
+
1782
+ return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
1783
+ }
1784
+
1785
+ /**
1786
+ * @brief Records an event on the CANN backend stream.
1787
+ *
1788
+ * This function records the given event on the ACL runtime stream associated
1789
+ * with the backend context.
1790
+ *
1791
+ * @param event Pointer to the event structure to be recorded.
1792
+ */
1793
+ static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
1794
+ ggml_backend_cann_context* cann_ctx =
1795
+ (ggml_backend_cann_context*)backend->context;
1796
+ ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
1797
+ }
1798
+
1799
+ /**
1800
+ * @brief Waits for a recorded event to complete on the CANN backend stream.
1801
+ *
1802
+ * This function makes the given backend wait for the event to complete on its
1803
+ * ACL runtime stream.
1804
+ *
1805
+ * @param backend Pointer to the backend structure.
1806
+ * @param event Pointer to the event structure that the backend needs to wait
1807
+ * for.
1808
+ */
1809
+ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
1810
+ ggml_backend_event_t event) {
1811
+ ggml_backend_cann_context* cann_ctx =
1812
+ (ggml_backend_cann_context*)backend->context;
1813
+ if (ggml_backend_is_cann(backend)) {
1814
+ ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
1815
+ (aclrtEvent)event->context));
1816
+ } else {
1817
+ GGML_ABORT("fatal error");
1818
+ }
1819
+ }
1820
+
1821
+ /**
1822
+ * @brief Structure defining the interface for the CANN backend.
1823
+ *
1824
+ * This structure contains function pointers for various operations
1825
+ * supported by the CANN backend, including name retrieval, memory
1826
+ * management, tensor operations, synchronization, and event handling.
1827
+ */
1828
+ static const ggml_backend_i ggml_backend_cann_interface = {
1829
+ /* .get_name = */ ggml_backend_cann_name,
1830
+ /* .free = */ ggml_backend_cann_free,
1831
+ /* .set_tensor_async = */ ggml_backend_cann_set_tensor_async,
1832
+ /* .get_tensor_async = */ ggml_backend_cann_get_tensor_async,
1833
+ /* .cpy_tensor_async = */ ggml_backend_cann_cpy_tensor_async,
1834
+ /* .synchronize = */ ggml_backend_cann_synchronize,
1835
+ /* .graph_plan_create = */ NULL,
1836
+ /* .graph_plan_free = */ NULL,
1837
+ /* .graph_plan_update = */ NULL,
1838
+ /* .graph_plan_compute = */ NULL,
1839
+ /* .graph_compute = */ ggml_backend_cann_graph_compute,
1840
+ /* .event_record = */ ggml_backend_cann_event_record,
1841
+ /* .event_wait = */ ggml_backend_cann_event_wait,
1842
+ };
1843
+
1844
+ /**
1845
+ * @brief Return the hardcoded GUID for the CANN backend.
1846
+ *
1847
+ * This function returns a static GUID which uniquely identifies the CANN
1848
+ * backend.
1849
+ *
1850
+ * @return A pointer to the static GUID.
1851
+ */
1852
+ static ggml_guid_t ggml_backend_cann_guid() {
1853
+ static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
1854
+ 0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
1855
+ return &guid;
1856
+ }
1857
+
1858
+ // backend device
1859
+ struct ggml_backend_cann_device_context {
1860
+ int device;
1861
+ std::string name;
1862
+ std::string description;
1863
+ };
1864
+
1865
+ static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
1866
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1867
+ return ctx->name.c_str();
1868
+ }
1869
+
1870
+ static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
1871
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1872
+ return ctx->description.c_str();
1873
+ }
1874
+
1875
+ static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
1876
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1877
+ ggml_backend_cann_get_device_memory(ctx->device, free, total);
1878
+ }
1879
+
1880
+ static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
1881
+ GGML_UNUSED(dev);
1882
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
1883
+ }
1884
+
1885
+ static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
1886
+ props->name = ggml_backend_cann_device_get_name(dev);
1887
+ props->description = ggml_backend_cann_device_get_description(dev);
1888
+ props->type = ggml_backend_cann_device_get_type(dev);
1889
+ ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
1890
+
1891
+ bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
1892
+
1893
+ props->caps = {
1894
+ /* .async = */ false,
1895
+ /* .host_buffer = */ host_buffer,
1896
+ /* .buffer_from_host_ptr = */ false,
1897
+ /* .events = */ true,
1898
+ };
1899
+ }
1900
+
1901
+ static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
1902
+ GGML_UNUSED(params);
1903
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1904
+ return ggml_backend_cann_init(ctx->device);
1905
+ }
1906
+
1907
+ /**
1908
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
1909
+ *
1910
+ * This function determines whether the CANN backend supports the given backend
1911
+ * buffer type by comparing the device context of the backend and buffer type.
1912
+ * It returns true if the devices are same between the backend context and
1913
+ * buffer type context.
1914
+ *
1915
+ * @param backend Pointer to the CANN backend.
1916
+ * @param buft Pointer to the backend buffer type to check.
1917
+ * @return bool Returns true if the CANN backend supports the buffer type,
1918
+ * otherwise false.
1919
+ */
1920
+ static bool ggml_backend_cann_supports_buft(
1921
+ ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
1922
+ if (ggml_backend_buft_is_cann(buft)) {
1923
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1924
+ ggml_backend_cann_buffer_type_context * buft_ctx =
1925
+ (ggml_backend_cann_buffer_type_context *)buft->context;
1926
+ return buft_ctx->device == dev_ctx->device;
1927
+ }
1928
+ return false;
1929
+ }
1930
+
1931
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
1932
+ ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
1933
+ return ggml_backend_cann_buffer_type(ctx->device);
1934
+ }
1935
+
1936
+ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
1937
+ GGML_UNUSED(dev);
1938
+ return ggml_backend_cann_host_buffer_type();
1939
+ }
1940
+
1941
+ /**
1942
+ * @brief Creates a new event for the CANN backend device.
1943
+ *
1944
+ * This function initializes a new event for the CANN backend by setting the
1945
+ * device and creating an ACL runtime event. The created event is then wrapped
1946
+ * in a ggml_backend_event structure and returned.
1947
+ *
1948
+ * @param backend Pointer to the CANN backend.
1949
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
1950
+ */
1951
+ static ggml_backend_event_t ggml_backend_cann_device_event_new(
1952
+ ggml_backend_dev_t dev) {
1953
+ ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
1954
+
1955
+ ggml_cann_set_device(dev_ctx->device);
1956
+
1957
+ aclrtEvent event;
1958
+ ACL_CHECK(aclrtCreateEvent(&event));
1959
+
1960
+ return new ggml_backend_event{
1961
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
1962
+ /* .context = */ event,
1963
+ };
1964
+ }
1965
+
1966
+ /**
1967
+ * @brief Frees a CANN backend event.
1968
+ *
1969
+ * This function destroys the ACL runtime event associated with the given CANN
1970
+ * backend event and then deletes the event structure itself.
1971
+ *
1972
+ * @param event Pointer to the event structure to be freed.
1973
+ */
1974
+ static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1975
+ ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
1976
+
1977
+ delete event;
1978
+ GGML_UNUSED(dev);
1979
+ }
1980
+
1981
+ /**
1982
+ * @brief Synchronizes the given event on the CANN backend.
1983
+ *
1984
+ * This function waits for the specified event to complete on the ACL runtime.
1985
+ *
1986
+ * @param event Pointer to the event structure to be synchronized.
1987
+ */
1988
+ static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
1989
+ ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
1990
+
1991
+ GGML_UNUSED(dev);
1992
+ }
1993
+
1994
+ static const ggml_backend_device_i ggml_backend_cann_device_interface = {
1995
+ /* .get_name = */ ggml_backend_cann_device_get_name,
1996
+ /* .get_description = */ ggml_backend_cann_device_get_description,
1997
+ /* .get_memory = */ ggml_backend_cann_device_get_memory,
1998
+ /* .get_type = */ ggml_backend_cann_device_get_type,
1999
+ /* .get_props = */ ggml_backend_cann_device_get_props,
2000
+ /* .init_backend = */ ggml_backend_cann_device_init, // called for every card
2001
+ /* .get_buffer_type = */ ggml_backend_cann_device_get_buffer_type,
2002
+ /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
2003
+ /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
2004
+ /* .supports_op = */ ggml_backend_cann_supports_op,
2005
+ /* .supports_buft = */ ggml_backend_cann_supports_buft,
2006
+ /* .offload_op = */ ggml_backend_cann_offload_op,
2007
+ /* .event_new = */ ggml_backend_cann_device_event_new,
2008
+ /* .event_free = */ ggml_backend_cann_device_event_free,
2009
+ /* .event_synchronize = */ ggml_backend_cann_device_event_synchronize,
2010
+ };
2011
+
2012
+
2013
+ // backend reg
2014
+ struct ggml_backend_cann_reg_context {
2015
+ std::vector<ggml_backend_dev_t> devices;
2016
+ };
2017
+
2018
+ static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
2019
+ GGML_UNUSED(reg);
2020
+ return GGML_CANN_NAME;
2021
+ }
2022
+
2023
+ static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
2024
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2025
+ return ctx->devices.size();
2026
+ }
2027
+
2028
+ static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
2029
+ ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
2030
+ GGML_ASSERT(index < ctx->devices.size());
2031
+ return ctx->devices[index];
2032
+ }
2033
+
2034
+ static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
2035
+ GGML_UNUSED(reg);
2036
+ GGML_UNUSED(name);
2037
+ // reserved for future use
2038
+ return nullptr;
2039
+ }
2040
+
2041
+ static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
2042
+ /* .get_name = */ ggml_backend_cann_reg_get_name,
2043
+ /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
2044
+ /* .get_device = */ ggml_backend_cann_reg_get_device,
2045
+ /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
2046
+ };
2047
+
2048
+ // backend registry, called only once for cann backend
2049
+ ggml_backend_reg_t ggml_backend_cann_reg() {
2050
+ static ggml_backend_reg reg;
2051
+ static bool initialized = false;
2052
+
2053
+ {
2054
+ static std::mutex mutex;
2055
+ std::lock_guard<std::mutex> lock(mutex);
2056
+ if (!initialized) {
2057
+ aclInit(nullptr);
2058
+ ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
2059
+
2060
+ for (int i = 0; i < ggml_cann_info().device_count; i++) {
2061
+ ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
2062
+ dev_ctx->description = aclrtGetSocName();
2063
+ dev_ctx->device = i;
2064
+ dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
2065
+ ggml_cann_set_device(i);
2066
+ ggml_backend_dev_t dev = new ggml_backend_device {
2067
+ /* .interface = */ ggml_backend_cann_device_interface,
2068
+ /* .reg = */ &reg,
2069
+ /* .context = */ dev_ctx
2070
+ };
2071
+ ctx->devices.push_back(dev);
2072
+ }
2073
+
2074
+ reg = ggml_backend_reg {
2075
+ /* .interface = */ ggml_backend_cann_reg_interface,
2076
+ /* .context = */ ctx
2077
+ };
2078
+ }
2079
+
2080
+ initialized = true;
2081
+ }
2082
+
2083
+ return &reg;
2084
+ }
2085
+
2086
+ ggml_backend_t ggml_backend_cann_init(int32_t device) {
2087
+ aclInit(nullptr);
2088
+ if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
2089
+ GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
2090
+ return nullptr;
2091
+ }
2092
+
2093
+ ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
2094
+ if (ctx == nullptr) {
2095
+ GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
2096
+ return nullptr;
2097
+ }
2098
+ ggml_cann_set_device(ctx->device);
2099
+ ggml_backend_t cann_backend =
2100
+ new ggml_backend{/* .guid = */ ggml_backend_cann_guid(),
2101
+ /* .interface = */ ggml_backend_cann_interface,
2102
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
2103
+ /* .context = */ ctx};
2104
+
2105
+ return cann_backend;
2106
+ }
2107
+
2108
+ bool ggml_backend_is_cann(ggml_backend_t backend) {
2109
+ return backend != NULL &&
2110
+ ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
2111
+ }
2112
+
2113
+ int32_t ggml_backend_cann_get_device_count() {
2114
+ return ggml_cann_info().device_count;
2115
+ }
2116
+
2117
+ void ggml_backend_cann_get_device_description(
2118
+ int32_t device, char* description, size_t description_size) {
2119
+ ggml_cann_set_device(device);
2120
+ const char* soc_name = aclrtGetSocName();
2121
+ snprintf(description, description_size, "%s", soc_name);
2122
+ }
2123
+
2124
+ void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
2125
+ size_t* total) {
2126
+ ggml_cann_set_device(device);
2127
+ ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
2128
+ }
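
For orientation, here is a minimal usage sketch (not part of the diff) of the device registry that the code above adds for the CANN backend. It assumes a build with the CANN backend enabled and uses only the public `ggml-backend.h` / `ggml-cann.h` API; the `main` function and its output are illustrative only.

```cpp
// Enumerate the devices exposed by ggml_backend_cann_reg() and create a
// backend stream for the first one (equivalent to ggml_backend_cann_init(0)).
#include "ggml-backend.h"
#include "ggml-cann.h"

#include <cstdio>

int main() {
    ggml_backend_reg_t reg = ggml_backend_cann_reg();

    const size_t n_dev = ggml_backend_reg_dev_count(reg);
    std::printf("CANN devices: %zu\n", n_dev);

    for (size_t i = 0; i < n_dev; i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        std::printf("  %s - %s\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
    }

    if (n_dev > 0) {
        ggml_backend_t backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), /*params =*/ nullptr);
        if (backend != nullptr) {
            // ... build and compute graphs with this backend ...
            ggml_backend_free(backend);
        }
    }
    return 0;
}
```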
ggml/src/ggml-cpu/CMakeLists.txt ADDED
@@ -0,0 +1,244 @@
1
+ add_library(ggml-cpu
2
+ ggml-cpu.c
3
+ ggml-cpu.cpp
4
+ ggml-cpu-aarch64.c
5
+ ggml-cpu-aarch64.h
6
+ ggml-cpu-quants.c
7
+ ggml-cpu-quants.h
8
+ )
9
+
10
+ target_link_libraries(ggml-cpu PRIVATE ggml-base)
11
+ target_include_directories(ggml-cpu PRIVATE . ..)
12
+
13
+ if (APPLE AND GGML_ACCELERATE)
14
+ find_library(ACCELERATE_FRAMEWORK Accelerate)
15
+ if (ACCELERATE_FRAMEWORK)
16
+ message(STATUS "Accelerate framework found")
17
+
18
+ add_compile_definitions(GGML_USE_ACCELERATE)
19
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
20
+ add_compile_definitions(ACCELERATE_LAPACK_ILP64)
21
+
22
+ target_link_libraries(ggml-cpu PRIVATE ${ACCELERATE_FRAMEWORK})
23
+ else()
24
+ message(WARNING "Accelerate framework not found")
25
+ endif()
26
+ endif()
27
+
28
+ if (GGML_OPENMP)
29
+ find_package(OpenMP)
30
+ if (OpenMP_FOUND)
31
+ message(STATUS "OpenMP found")
32
+
33
+ add_compile_definitions(GGML_USE_OPENMP)
34
+
35
+ target_link_libraries(ggml-cpu PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
36
+
37
+ # FIXME: should be replaced with a compiler id check
38
+ #if (GGML_MUSA)
39
+ # list(APPEND GGML_CPU_EXTRA_INCLUDES "/usr/lib/llvm-14/lib/clang/14.0.0/include")
40
+ # list(APPEND GGML_CPU_EXTRA_LIBS_PRIVATE "/usr/lib/llvm-14/lib/libomp.so")
41
+ #endif()
42
+ else()
43
+ message(WARNING "OpenMP not found")
44
+ endif()
45
+ endif()
46
+
47
+ if (GGML_LLAMAFILE)
48
+ message(STATUS "Using llamafile")
49
+
50
+ add_compile_definitions(GGML_USE_LLAMAFILE)
51
+
52
+ target_sources(ggml-cpu PRIVATE
53
+ llamafile/sgemm.cpp
54
+ llamafile/sgemm.h)
55
+ endif()
56
+
57
+ if (GGML_CPU_HBM)
58
+ find_library(memkind memkind REQUIRED)
59
+
60
+ message(STATUS "Using memkind for CPU HBM")
61
+
62
+ add_compile_definitions(GGML_USE_CPU_HBM)
63
+
64
+ target_link_libraries(ggml-cpu PUBLIC memkind)
65
+ endif()
66
+
67
+ if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
68
+ CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
69
+ (NOT CMAKE_OSX_ARCHITECTURES AND
70
+ NOT CMAKE_GENERATOR_PLATFORM_LWR AND
71
+ CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))
72
+
73
+ message(STATUS "ARM detected")
74
+
75
+ if (MSVC)
76
+ add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
77
+ add_compile_definitions(__ARM_NEON)
78
+ add_compile_definitions(__ARM_FEATURE_FMA)
79
+
80
+ set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
81
+ string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")
82
+
83
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
84
+ if (GGML_COMPILER_SUPPORT_DOTPROD)
85
+ add_compile_definitions(__ARM_FEATURE_DOTPROD)
86
+ endif ()
87
+
88
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
89
+
90
+ if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
91
+ add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
92
+ endif ()
93
+
94
+ check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
95
+ if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
96
+ add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
97
+ endif ()
98
+
99
+ set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
100
+ else()
101
+ check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
102
+ if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
103
+ list(APPEND ARCH_FLAGS -mfp16-format=ieee)
104
+ endif()
105
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
106
+ # Raspberry Pi 1, Zero
107
+ list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
108
+ endif()
109
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
110
+ if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
111
+ # Android armeabi-v7a
112
+ list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
113
+ else()
114
+ # Raspberry Pi 2
115
+ list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
116
+ endif()
117
+ endif()
118
+ if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
119
+ # Android arm64-v8a
120
+ # Raspberry Pi 3, 4, Zero 2 (32-bit)
121
+ list(APPEND ARCH_FLAGS -mno-unaligned-access)
122
+ endif()
123
+ if (GGML_SVE)
124
+ list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
125
+ endif()
126
+ endif()
127
+ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
128
+ (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
129
+ CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
130
+ message(STATUS "x86 detected")
131
+ if (MSVC)
132
+ # instruction set detection for MSVC only
133
+ if (GGML_NATIVE)
134
+ # TODO: improve, should not reference files from the parent folder
135
+ include(cmake/FindSIMD.cmake)
136
+ endif ()
137
+ if (GGML_AVX512)
138
+ list(APPEND ARCH_FLAGS /arch:AVX512)
139
+ # MSVC has no compile-time flags enabling specific
140
+ # AVX512 extensions, neither it defines the
141
+ # macros corresponding to the extensions.
142
+ # Do it manually.
143
+ if (GGML_AVX512_VBMI)
144
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
145
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
146
+ endif()
147
+ if (GGML_AVX512_VNNI)
148
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
149
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
150
+ endif()
151
+ if (GGML_AVX512_BF16)
152
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
153
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
154
+ endif()
155
+ if (GGML_AMX_TILE)
156
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
157
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
158
+ endif()
159
+ if (GGML_AMX_INT8)
160
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
161
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
162
+ endif()
163
+ if (GGML_AMX_BF16)
164
+ add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
165
+ add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
166
+ endif()
167
+ elseif (GGML_AVX2)
168
+ list(APPEND ARCH_FLAGS /arch:AVX2)
169
+ elseif (GGML_AVX)
170
+ list(APPEND ARCH_FLAGS /arch:AVX)
171
+ endif()
172
+ else()
173
+ if (GGML_NATIVE)
174
+ list(APPEND ARCH_FLAGS -march=native)
175
+ endif()
176
+ if (GGML_F16C)
177
+ list(APPEND ARCH_FLAGS -mf16c)
178
+ endif()
179
+ if (GGML_FMA)
180
+ list(APPEND ARCH_FLAGS -mfma)
181
+ endif()
182
+ if (GGML_AVX)
183
+ list(APPEND ARCH_FLAGS -mavx)
184
+ endif()
185
+ if (GGML_AVX2)
186
+ list(APPEND ARCH_FLAGS -mavx2)
187
+ endif()
188
+ if (GGML_AVX512)
189
+ list(APPEND ARCH_FLAGS -mavx512f)
190
+ list(APPEND ARCH_FLAGS -mavx512dq)
191
+ list(APPEND ARCH_FLAGS -mavx512bw)
192
+ endif()
193
+ if (GGML_AVX512_VBMI)
194
+ list(APPEND ARCH_FLAGS -mavx512vbmi)
195
+ endif()
196
+ if (GGML_AVX512_VNNI)
197
+ list(APPEND ARCH_FLAGS -mavx512vnni)
198
+ endif()
199
+ if (GGML_AVX512_BF16)
200
+ list(APPEND ARCH_FLAGS -mavx512bf16)
201
+ endif()
202
+ if (GGML_AMX_TILE)
203
+ list(APPEND ARCH_FLAGS -mamx-tile)
204
+ endif()
205
+ if (GGML_AMX_INT8)
206
+ list(APPEND ARCH_FLAGS -mamx-int8)
207
+ endif()
208
+ if (GGML_AMX_BF16)
209
+ list(APPEND ARCH_FLAGS -mamx-bf16)
210
+ endif()
211
+ endif()
212
+ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
213
+ message(STATUS "PowerPC detected")
214
+ execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
215
+ OUTPUT_VARIABLE POWER10_M)
216
+ string(FIND ${POWER10_M} "POWER10" substring_index)
217
+ if(${substring_index} GREATER_EQUAL 0)
218
+ list(APPEND ARCH_FLAGS -mcpu=power10)
219
+ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
220
+ list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
221
+ else()
222
+ list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
223
+ #TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
224
+ endif()
225
+ elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
226
+ message(STATUS "loongarch64 detected")
227
+
228
+ list(APPEND ARCH_FLAGS -march=loongarch64)
229
+ if (GGML_LASX)
230
+ list(APPEND ARCH_FLAGS -mlasx)
231
+ endif()
232
+ if (GGML_LSX)
233
+ list(APPEND ARCH_FLAGS -mlsx)
234
+ endif()
235
+ else()
236
+ message(STATUS "Unknown architecture")
237
+ endif()
238
+
239
+ target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
240
+ target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
241
+
242
+ if (EMSCRIPTEN)
243
+ set_target_properties(ggml-cpu PROPERTIES COMPILE_FLAGS "-msimd128")
244
+ endif()
ggml/src/ggml-cpu/cmake/FindSIMD.cmake ADDED
@@ -0,0 +1,100 @@
1
+ include(CheckCSourceRuns)
2
+
3
+ set(AVX_CODE "
4
+ #include <immintrin.h>
5
+ int main()
6
+ {
7
+ __m256 a;
8
+ a = _mm256_set1_ps(0);
9
+ return 0;
10
+ }
11
+ ")
12
+
13
+ set(AVX512_CODE "
14
+ #include <immintrin.h>
15
+ int main()
16
+ {
17
+ __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
18
+ 0, 0, 0, 0, 0, 0, 0, 0,
19
+ 0, 0, 0, 0, 0, 0, 0, 0,
20
+ 0, 0, 0, 0, 0, 0, 0, 0,
21
+ 0, 0, 0, 0, 0, 0, 0, 0,
22
+ 0, 0, 0, 0, 0, 0, 0, 0,
23
+ 0, 0, 0, 0, 0, 0, 0, 0,
24
+ 0, 0, 0, 0, 0, 0, 0, 0);
25
+ __m512i b = a;
26
+ __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
27
+ return 0;
28
+ }
29
+ ")
30
+
31
+ set(AVX2_CODE "
32
+ #include <immintrin.h>
33
+ int main()
34
+ {
35
+ __m256i a = {0};
36
+ a = _mm256_abs_epi16(a);
37
+ __m256i x;
38
+ _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
39
+ return 0;
40
+ }
41
+ ")
42
+
43
+ set(FMA_CODE "
44
+ #include <immintrin.h>
45
+ int main()
46
+ {
47
+ __m256 acc = _mm256_setzero_ps();
48
+ const __m256 d = _mm256_setzero_ps();
49
+ const __m256 p = _mm256_setzero_ps();
50
+ acc = _mm256_fmadd_ps( d, p, acc );
51
+ return 0;
52
+ }
53
+ ")
54
+
55
+ macro(check_sse type flags)
56
+ set(__FLAG_I 1)
57
+ set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
58
+ foreach (__FLAG ${flags})
59
+ if (NOT ${type}_FOUND)
60
+ set(CMAKE_REQUIRED_FLAGS ${__FLAG})
61
+ check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
62
+ if (HAS_${type}_${__FLAG_I})
63
+ set(${type}_FOUND TRUE CACHE BOOL "${type} support")
64
+ set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
65
+ endif()
66
+ math(EXPR __FLAG_I "${__FLAG_I}+1")
67
+ endif()
68
+ endforeach()
69
+ set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
70
+
71
+ if (NOT ${type}_FOUND)
72
+ set(${type}_FOUND FALSE CACHE BOOL "${type} support")
73
+ set(${type}_FLAGS "" CACHE STRING "${type} flags")
74
+ endif()
75
+
76
+ mark_as_advanced(${type}_FOUND ${type}_FLAGS)
77
+ endmacro()
78
+
79
+ # flags are for MSVC only!
80
+ check_sse("AVX" " ;/arch:AVX")
81
+ if (NOT ${AVX_FOUND})
82
+ set(GGML_AVX OFF)
83
+ else()
84
+ set(GGML_AVX ON)
85
+ endif()
86
+
87
+ check_sse("AVX2" " ;/arch:AVX2")
88
+ check_sse("FMA" " ;/arch:AVX2")
89
+ if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
90
+ set(GGML_AVX2 OFF)
91
+ else()
92
+ set(GGML_AVX2 ON)
93
+ endif()
94
+
95
+ check_sse("AVX512" " ;/arch:AVX512")
96
+ if (NOT ${AVX512_FOUND})
97
+ set(GGML_AVX512 OFF)
98
+ else()
99
+ set(GGML_AVX512 ON)
100
+ endif()
ggml/src/ggml-cpu/ggml-cpu-aarch64.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/ggml-cpu-aarch64.h ADDED
@@ -0,0 +1,27 @@
+ #pragma once
+
+ #include "ggml.h"
+
+ // GGML internal header
+
+ #ifdef __cplusplus
+ extern "C" {
+ #endif
+
+ // Quantization
+ void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
+
+ // GEMV
+ void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ // GEMM
+ void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+
+ #ifdef __cplusplus
+ }
+ #endif
+
ggml/src/ggml-cpu/ggml-cpu-impl.h ADDED
@@ -0,0 +1,371 @@
1
+ #pragma once
2
+
3
+ // GGML CPU internal header
4
+
5
+ #include "ggml.h"
6
+ #include "ggml-impl.h"
7
+ #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
8
+ //#include <stddef.h>
9
+ #include <stdbool.h>
10
+ #include <string.h> // memcpy
11
+ #include <math.h> // fabsf
12
+
13
+
14
+ #ifdef __cplusplus
15
+ extern "C" {
16
+ #endif
17
+
18
+ #if defined(_MSC_VER)
19
+
20
+ #define m512bh(p) p
21
+ #define m512i(p) p
22
+
23
+ #else
24
+
25
+ #define m512bh(p) (__m512bh)(p)
26
+ #define m512i(p) (__m512i)(p)
27
+
28
+ #endif
29
+
30
+ // __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
31
+ #if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
32
+ #ifndef __FMA__
33
+ #define __FMA__
34
+ #endif
35
+ #ifndef __F16C__
36
+ #define __F16C__
37
+ #endif
38
+ #endif
39
+
40
+ // __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
41
+ #if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
42
+ #ifndef __SSE3__
43
+ #define __SSE3__
44
+ #endif
45
+ #ifndef __SSSE3__
46
+ #define __SSSE3__
47
+ #endif
48
+ #endif
49
+
50
+ #if defined(__ARM_FEATURE_SVE)
51
+ #include <arm_sve.h>
52
+ #include <sys/prctl.h>
53
+ #endif
54
+
55
+ // 16-bit float
56
+ // on Arm, we use __fp16
57
+ // on x86, we use uint16_t
58
+ #if defined(__ARM_NEON)
59
+
60
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
61
+ //
62
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
63
+ //
64
+ #include <arm_neon.h>
65
+
66
+ #ifdef _MSC_VER
67
+
68
+ typedef uint16_t ggml_fp16_internal_t;
69
+
70
+ #define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
71
+
72
+ #else
73
+
74
+ typedef __fp16 ggml_fp16_internal_t;
75
+
76
+ #define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
77
+
78
+ #endif // _MSC_VER
79
+
80
+ #if !defined(__aarch64__)
81
+
82
+ // 32-bit ARM compatibility
83
+
84
+ // vaddlvq_s16
85
+ // vpaddq_s16
86
+ // vpaddq_s32
87
+ // vaddvq_s32
88
+ // vaddvq_f32
89
+ // vmaxvq_f32
90
+ // vcvtnq_s32_f32
91
+ // vzip1_u8
92
+ // vzip2_u8
93
+
94
+ inline static int32_t vaddlvq_s16(int16x8_t v) {
95
+ int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
96
+ return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
97
+ }
98
+
99
+ inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
100
+ int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
101
+ int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
102
+ return vcombine_s16(a0, b0);
103
+ }
104
+
105
+ inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
106
+ int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
107
+ int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
108
+ return vcombine_s32(a0, b0);
109
+ }
110
+
111
+ inline static int32_t vaddvq_s32(int32x4_t v) {
112
+ return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
113
+ }
114
+
115
+ inline static float vaddvq_f32(float32x4_t v) {
116
+ return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
117
+ }
118
+
119
+ inline static float vmaxvq_f32(float32x4_t v) {
120
+ return
121
+ MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
122
+ MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
123
+ }
124
+
125
+ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
126
+ int32x4_t res;
127
+
128
+ res[0] = roundf(vgetq_lane_f32(v, 0));
129
+ res[1] = roundf(vgetq_lane_f32(v, 1));
130
+ res[2] = roundf(vgetq_lane_f32(v, 2));
131
+ res[3] = roundf(vgetq_lane_f32(v, 3));
132
+
133
+ return res;
134
+ }
135
+
136
+ inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
137
+ uint8x8_t res;
138
+
139
+ res[0] = a[0]; res[1] = b[0];
140
+ res[2] = a[1]; res[3] = b[1];
141
+ res[4] = a[2]; res[5] = b[2];
142
+ res[6] = a[3]; res[7] = b[3];
143
+
144
+ return res;
145
+ }
146
+
147
+ inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
148
+ uint8x8_t res;
149
+
150
+ res[0] = a[4]; res[1] = b[4];
151
+ res[2] = a[5]; res[3] = b[5];
152
+ res[4] = a[6]; res[5] = b[6];
153
+ res[6] = a[7]; res[7] = b[7];
154
+
155
+ return res;
156
+ }
157
+
158
+ // vld1q_s16_x2
159
+ // vld1q_u8_x2
160
+ // vld1q_u8_x4
161
+ // vld1q_s8_x2
162
+ // vld1q_s8_x4
163
+ // TODO: double-check these work correctly
164
+
165
+ typedef struct ggml_int16x8x2_t {
166
+ int16x8_t val[2];
167
+ } ggml_int16x8x2_t;
168
+
169
+ inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
170
+ ggml_int16x8x2_t res;
171
+
172
+ res.val[0] = vld1q_s16(ptr + 0);
173
+ res.val[1] = vld1q_s16(ptr + 8);
174
+
175
+ return res;
176
+ }
177
+
178
+ typedef struct ggml_uint8x16x2_t {
179
+ uint8x16_t val[2];
180
+ } ggml_uint8x16x2_t;
181
+
182
+ inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
183
+ ggml_uint8x16x2_t res;
184
+
185
+ res.val[0] = vld1q_u8(ptr + 0);
186
+ res.val[1] = vld1q_u8(ptr + 16);
187
+
188
+ return res;
189
+ }
190
+
191
+ typedef struct ggml_uint8x16x4_t {
192
+ uint8x16_t val[4];
193
+ } ggml_uint8x16x4_t;
194
+
195
+ inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
196
+ ggml_uint8x16x4_t res;
197
+
198
+ res.val[0] = vld1q_u8(ptr + 0);
199
+ res.val[1] = vld1q_u8(ptr + 16);
200
+ res.val[2] = vld1q_u8(ptr + 32);
201
+ res.val[3] = vld1q_u8(ptr + 48);
202
+
203
+ return res;
204
+ }
205
+
206
+ typedef struct ggml_int8x16x2_t {
207
+ int8x16_t val[2];
208
+ } ggml_int8x16x2_t;
209
+
210
+ inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
211
+ ggml_int8x16x2_t res;
212
+
213
+ res.val[0] = vld1q_s8(ptr + 0);
214
+ res.val[1] = vld1q_s8(ptr + 16);
215
+
216
+ return res;
217
+ }
218
+
219
+ typedef struct ggml_int8x16x4_t {
220
+ int8x16_t val[4];
221
+ } ggml_int8x16x4_t;
222
+
223
+ inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
224
+ ggml_int8x16x4_t res;
225
+
226
+ res.val[0] = vld1q_s8(ptr + 0);
227
+ res.val[1] = vld1q_s8(ptr + 16);
228
+ res.val[2] = vld1q_s8(ptr + 32);
229
+ res.val[3] = vld1q_s8(ptr + 48);
230
+
231
+ return res;
232
+ }
233
+
234
+ // NOTE: not tested
235
+ inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
236
+ int8x16_t res;
237
+
238
+ res[ 0] = a[b[ 0]];
239
+ res[ 1] = a[b[ 1]];
240
+ res[ 2] = a[b[ 2]];
241
+ res[ 3] = a[b[ 3]];
242
+ res[ 4] = a[b[ 4]];
243
+ res[ 5] = a[b[ 5]];
244
+ res[ 6] = a[b[ 6]];
245
+ res[ 7] = a[b[ 7]];
246
+ res[ 8] = a[b[ 8]];
247
+ res[ 9] = a[b[ 9]];
248
+ res[10] = a[b[10]];
249
+ res[11] = a[b[11]];
250
+ res[12] = a[b[12]];
251
+ res[13] = a[b[13]];
252
+ res[14] = a[b[14]];
253
+ res[15] = a[b[15]];
254
+
255
+ return res;
256
+ }
257
+
258
+ // NOTE: not tested
259
+ inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
260
+ uint8x16_t res;
261
+
262
+ res[ 0] = a[b[ 0]];
263
+ res[ 1] = a[b[ 1]];
264
+ res[ 2] = a[b[ 2]];
265
+ res[ 3] = a[b[ 3]];
266
+ res[ 4] = a[b[ 4]];
267
+ res[ 5] = a[b[ 5]];
268
+ res[ 6] = a[b[ 6]];
269
+ res[ 7] = a[b[ 7]];
270
+ res[ 8] = a[b[ 8]];
271
+ res[ 9] = a[b[ 9]];
272
+ res[10] = a[b[10]];
273
+ res[11] = a[b[11]];
274
+ res[12] = a[b[12]];
275
+ res[13] = a[b[13]];
276
+ res[14] = a[b[14]];
277
+ res[15] = a[b[15]];
278
+
279
+ return res;
280
+ }
281
+
282
+ #else
283
+
284
+ #define ggml_int16x8x2_t int16x8x2_t
285
+ #define ggml_uint8x16x2_t uint8x16x2_t
286
+ #define ggml_uint8x16x4_t uint8x16x4_t
287
+ #define ggml_int8x16x2_t int8x16x2_t
288
+ #define ggml_int8x16x4_t int8x16x4_t
289
+
290
+ #define ggml_vld1q_s16_x2 vld1q_s16_x2
291
+ #define ggml_vld1q_u8_x2 vld1q_u8_x2
292
+ #define ggml_vld1q_u8_x4 vld1q_u8_x4
293
+ #define ggml_vld1q_s8_x2 vld1q_s8_x2
294
+ #define ggml_vld1q_s8_x4 vld1q_s8_x4
295
+ #define ggml_vqtbl1q_s8 vqtbl1q_s8
296
+ #define ggml_vqtbl1q_u8 vqtbl1q_u8
297
+
298
+ #endif // !defined(__aarch64__)
299
+
300
+ #if !defined(__ARM_FEATURE_DOTPROD)
301
+
302
+ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
303
+ const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
304
+ const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
305
+
306
+ return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
307
+ }
308
+
309
+ #else
310
+
311
+ #define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
312
+
313
+ #endif // !defined(__ARM_FEATURE_DOTPROD)
314
+
315
+ #endif // defined(__ARM_NEON)
316
+
317
+ #ifdef __wasm_simd128__
318
+ #include <wasm_simd128.h>
319
+ #else
320
+ #ifdef __POWER9_VECTOR__
321
+ #include <altivec.h>
322
+ #undef bool
323
+ #define bool _Bool
324
+ #else
325
+ #if defined(_MSC_VER) || defined(__MINGW32__)
326
+ #include <intrin.h>
327
+ #else
328
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
329
+ #if !defined(__riscv)
330
+ #include <immintrin.h>
331
+ #endif
332
+ #endif
333
+ #endif
334
+ #endif
335
+ #endif
336
+
337
+ #ifdef __riscv_v_intrinsic
338
+ #include <riscv_vector.h>
339
+ #endif
340
+
341
+ #if defined(__loongarch64)
342
+ #if defined(__loongarch_asx)
343
+ #include <lasxintrin.h>
344
+ #endif
345
+ #if defined(__loongarch_sx)
346
+ #include <lsxintrin.h>
347
+ #endif
348
+ #endif
349
+
350
+ #if defined(__loongarch_asx)
351
+
352
+ typedef union {
353
+ int32_t i;
354
+ float f;
355
+ } ft_union;
356
+
357
+ /* float type data load instructions */
358
+ static __m128 __lsx_vreplfr2vr_s(float val) {
359
+ ft_union fi_tmpval = {.f = val};
360
+ return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
361
+ }
362
+
363
+ static __m256 __lasx_xvreplfr2vr_s(float val) {
364
+ ft_union fi_tmpval = {.f = val};
365
+ return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
366
+ }
367
+ #endif
368
+
369
+ #ifdef __cplusplus
370
+ }
371
+ #endif
ggml/src/ggml-cpu/ggml-cpu-quants.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/ggml-cpu-quants.h ADDED
@@ -0,0 +1,63 @@
1
+ #pragma once
2
+
3
+ #define GGML_COMMON_DECL_C
4
+ #include "ggml-common.h"
5
+
6
+ #include "ggml.h"
7
+
8
+ // GGML CPU internal header
9
+
10
+ #ifdef __cplusplus
11
+ extern "C" {
12
+ #endif
13
+
14
+ // Quantization
15
+ void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
16
+ void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
17
+ void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
18
+ void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
19
+ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
20
+ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
21
+
22
+ void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
23
+ void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
24
+ void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
25
+ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
26
+ void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
27
+ void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
28
+
29
+ void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
30
+ void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
31
+
32
+ void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
33
+ void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
34
+
35
+ // Dot product
36
+ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
37
+ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
38
+ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
39
+ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
40
+ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
41
+
42
+ void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
43
+ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
44
+ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
45
+ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
46
+ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
47
+
48
+ void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
49
+ void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
50
+
51
+ void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
52
+ void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
53
+ void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
54
+ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
55
+ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
56
+ void ggml_vec_dot_iq1_m_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
57
+ void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
58
+ void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
59
+ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
60
+
61
+ #ifdef __cplusplus
62
+ }
63
+ #endif
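
As a rough illustration of how these internal helpers fit together, the sketch below quantizes two float rows to `q8_0` and takes their dot product. This is not part of the commit: it assumes the internal `ggml-cpu-quants.h` header (and the headers it pulls in) are reachable from the include path, that `ggml_cpu_init()` is available as declared in `ggml-cpu.h`, and that the row length is a multiple of the q8_0 block size (32).

```cpp
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-cpu-quants.h"

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    ggml_cpu_init(); // set up CPU backend tables (fp16 conversion, etc.)

    const int n = 64; // must be a multiple of the q8_0 block size (32)

    std::vector<float> x(n), y(n);
    for (int i = 0; i < n; ++i) {
        x[i] = 0.01f * (float) i;
        y[i] = 1.0f - 0.02f * (float) i;
    }

    // one q8_0 row of n elements occupies ggml_row_size(GGML_TYPE_Q8_0, n) bytes
    std::vector<uint8_t> qx(ggml_row_size(GGML_TYPE_Q8_0, n));
    std::vector<uint8_t> qy(ggml_row_size(GGML_TYPE_Q8_0, n));

    quantize_row_q8_0(x.data(), qx.data(), n);
    quantize_row_q8_0(y.data(), qy.data(), n);

    float dot = 0.0f;
    ggml_vec_dot_q8_0_q8_0(n, &dot, 0, qx.data(), 0, qy.data(), 0, 1);

    std::printf("q8_0 dot product: %f\n", dot);
    return 0;
}
```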
ggml/src/ggml-cpu/ggml-cpu.c ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cpu/ggml-cpu.cpp ADDED
@@ -0,0 +1,575 @@
1
+ #include "ggml-backend.h"
2
+ #include "ggml-backend-impl.h"
3
+ #include "ggml-cpu.h"
4
+ #include "ggml-impl.h"
5
+ #include <cctype>
6
+ #include <string>
7
+ #include <vector>
8
+
9
+ #if defined(__APPLE__)
10
+ #include <sys/types.h>
11
+ #include <sys/sysctl.h>
12
+ #endif
13
+
14
+ #if defined(_WIN32)
15
+ #define WIN32_LEAN_AND_MEAN
16
+ #ifndef NOMINMAX
17
+ #define NOMINMAX
18
+ #endif
19
+ #include <windows.h>
20
+ #endif
21
+
22
+ // ggml-backend interface
23
+
24
+ #ifdef GGML_USE_CPU_HBM
25
+
26
+ // buffer type HBM
27
+
28
+ #include <hbwmalloc.h>
29
+
30
+ static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
31
+ return "CPU_HBM";
32
+
33
+ GGML_UNUSED(buft);
34
+ }
35
+
36
+ static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
37
+ hbw_free(buffer->context);
38
+ }
39
+
40
+ static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
41
+ void * ptr;
42
+ int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
43
+ if (result != 0) {
44
+ GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
45
+ return NULL;
46
+ }
47
+
48
+ ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
49
+ buffer->buft = buft;
50
+ buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
51
+
52
+ return buffer;
53
+ }
54
+
55
+ ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
56
+ static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
57
+ /* .iface = */ {
58
+ /* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
59
+ /* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
60
+ /* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
61
+ /* .get_max_size = */ NULL, // defaults to SIZE_MAX
62
+ /* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
63
+ /* .is_host = */ ggml_backend_cpu_buffer_type_is_host,
64
+ },
65
+ /* .context = */ NULL,
66
+ };
67
+
68
+ return &ggml_backend_cpu_buffer_type_hbm;
69
+ }
70
+ #endif
71
+
72
+ static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
73
+ static ggml_backend_buffer_type_t bufts[] = {
74
+ #ifdef GGML_USE_CPU_HBM
75
+ ggml_backend_cpu_hbm_buffer_type(),
76
+ #endif
77
+ NULL
78
+ };
79
+
80
+ return bufts;
81
+
82
+ GGML_UNUSED(device);
83
+ }
84
+
85
+ // CPU backend - backend (stream)
86
+
87
+ struct ggml_backend_cpu_context {
88
+ int n_threads;
89
+ ggml_threadpool_t threadpool;
90
+
91
+ uint8_t * work_data;
92
+ size_t work_size;
93
+
94
+ ggml_abort_callback abort_callback;
95
+ void * abort_callback_data;
96
+ };
97
+
98
+ static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
99
+ return "CPU";
100
+
101
+ GGML_UNUSED(backend);
102
+ }
103
+
104
+ static void ggml_backend_cpu_free(ggml_backend_t backend) {
105
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
106
+ delete[] cpu_ctx->work_data;
107
+ delete cpu_ctx;
108
+ delete backend;
109
+ }
110
+
111
+ struct ggml_backend_plan_cpu {
112
+ struct ggml_cplan cplan;
113
+ struct ggml_cgraph cgraph;
114
+ };
115
+
116
+ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
117
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
118
+
119
+ struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
120
+
121
+ cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
122
+ cpu_plan->cgraph = *cgraph; // FIXME: deep copy
123
+
124
+ if (cpu_plan->cplan.work_size > 0) {
125
+ cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
126
+ if (cpu_plan->cplan.work_data == NULL) {
127
+ delete cpu_plan;
128
+ return NULL;
129
+ }
130
+ }
131
+
132
+ cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
133
+ cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
134
+
135
+ return cpu_plan;
136
+ }
137
+
138
+ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
139
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
140
+
141
+ delete[] cpu_plan->cplan.work_data;
142
+ delete cpu_plan;
143
+
144
+ GGML_UNUSED(backend);
145
+ }
146
+
147
+ static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
148
+ struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
149
+
150
+ return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
151
+
152
+ GGML_UNUSED(backend);
153
+ }
154
+
155
+ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
156
+ struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
157
+
158
+ struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
159
+
160
+ if (cpu_ctx->work_size < cplan.work_size) {
161
+ delete[] cpu_ctx->work_data;
162
+ cpu_ctx->work_data = new uint8_t[cplan.work_size];
163
+ if (cpu_ctx->work_data == NULL) {
164
+ cpu_ctx->work_size = 0;
165
+ return GGML_STATUS_ALLOC_FAILED;
166
+ }
167
+ cpu_ctx->work_size = cplan.work_size;
168
+ }
169
+ cplan.work_data = (uint8_t *)cpu_ctx->work_data;
170
+
171
+ cplan.abort_callback = cpu_ctx->abort_callback;
172
+ cplan.abort_callback_data = cpu_ctx->abort_callback_data;
173
+
174
+ return ggml_graph_compute(cgraph, &cplan);
175
+ }
176
+
177
+ static const struct ggml_backend_i ggml_backend_cpu_i = {
178
+ /* .get_name = */ ggml_backend_cpu_get_name,
179
+ /* .free = */ ggml_backend_cpu_free,
180
+ /* .set_tensor_async = */ NULL,
181
+ /* .get_tensor_async = */ NULL,
182
+ /* .cpy_tensor_async = */ NULL,
183
+ /* .synchronize = */ NULL,
184
+ /* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
185
+ /* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
186
+ /* .graph_plan_update = */ NULL,
187
+ /* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
188
+ /* .graph_compute = */ ggml_backend_cpu_graph_compute,
189
+ /* .event_record = */ NULL,
190
+ /* .event_wait = */ NULL,
191
+ };
192
+
193
+ static ggml_guid_t ggml_backend_cpu_guid(void) {
194
+ static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
195
+ return &guid;
196
+ }
197
+
198
+ ggml_backend_t ggml_backend_cpu_init(void) {
199
+ // initialize CPU backend now to avoid slowing the first graph computation
200
+ ggml_cpu_init();
201
+
202
+ struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
203
+ if (ctx == NULL) {
204
+ return NULL;
205
+ }
206
+
207
+ ctx->n_threads = GGML_DEFAULT_N_THREADS;
208
+ ctx->threadpool = NULL;
209
+ ctx->work_data = NULL;
210
+ ctx->work_size = 0;
211
+ ctx->abort_callback = NULL;
212
+ ctx->abort_callback_data = NULL;
213
+
214
+ ggml_backend_t cpu_backend = new ggml_backend {
215
+ /* .guid = */ ggml_backend_cpu_guid(),
216
+ /* .interface = */ ggml_backend_cpu_i,
217
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
218
+ /* .context = */ ctx,
219
+ };
220
+
221
+ if (cpu_backend == NULL) {
222
+ delete ctx;
223
+ return NULL;
224
+ }
225
+
226
+ return cpu_backend;
227
+ }
228
+
229
+ bool ggml_backend_is_cpu(ggml_backend_t backend) {
230
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
231
+ }
232
+
233
+ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
234
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
235
+
236
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
237
+ ctx->n_threads = n_threads;
238
+ }
239
+
240
+ void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
241
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
242
+
243
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
244
+
245
+ if (ctx->threadpool && ctx->threadpool != threadpool) {
246
+ // already had a different threadpool, pause/suspend it before switching
247
+ ggml_threadpool_pause(ctx->threadpool);
248
+ }
249
+ ctx->threadpool = threadpool;
250
+ }
251
+
252
+ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
253
+ GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
254
+
255
+ struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
256
+ ctx->abort_callback = abort_callback;
257
+ ctx->abort_callback_data = abort_callback_data;
258
+ }
259
+
260
+ // CPU backend - device
261
+
262
+ struct ggml_backend_cpu_device_context {
263
+ std::string description = "CPU";
264
+
265
+ ggml_backend_cpu_device_context() {
266
+ #ifdef __APPLE__
267
+ size_t len = 0;
268
+ if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
269
+ description.resize(len);
270
+ sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
271
+ }
272
+ #elif defined(__linux__)
273
+ FILE * f = fopen("/proc/cpuinfo", "r");
274
+ if (f) {
275
+ char buf[1024];
276
+ while (fgets(buf, sizeof(buf), f)) {
277
+ if (strncmp(buf, "model name", 10) == 0) {
278
+ char * p = strchr(buf, ':');
279
+ if (p) {
280
+ p++;
281
+ while (std::isspace(*p)) {
282
+ p++;
283
+ }
284
+ while (std::isspace(p[strlen(p) - 1])) {
285
+ p[strlen(p) - 1] = '\0';
286
+ }
287
+ description = p;
288
+ break;
289
+ }
290
+ }
291
+ }
292
+ fclose(f);
293
+ }
294
+ #elif defined(_WIN32)
295
+ HKEY hKey;
296
+ if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
297
+ TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
298
+ 0,
299
+ KEY_READ,
300
+ &hKey) == ERROR_SUCCESS) {
301
+ DWORD cpu_brand_size = 0;
302
+ if (RegQueryValueExA(hKey,
303
+ TEXT("ProcessorNameString"),
304
+ NULL,
305
+ NULL,
306
+ NULL,
307
+ &cpu_brand_size) == ERROR_SUCCESS) {
308
+ description.resize(cpu_brand_size);
309
+ if (RegQueryValueExA(hKey,
310
+ TEXT("ProcessorNameString"),
311
+ NULL,
312
+ NULL,
313
+ (LPBYTE)&description[0], // NOLINT
314
+ &cpu_brand_size) == ERROR_SUCCESS) {
315
+ if (description.find('\0') != std::string::npos) {
316
+ description.resize(description.find('\0'));
317
+ }
318
+ }
319
+ }
320
+ RegCloseKey(hKey);
321
+ }
322
+ #endif
323
+ }
324
+ };
325
+
326
+ static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
327
+ return "CPU";
328
+
329
+ GGML_UNUSED(dev);
330
+ }
331
+
332
+ static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
333
+ struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
334
+
335
+ return ctx->description.c_str();
336
+ }
337
+
338
+ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
339
+ // TODO
340
+ *free = 0;
341
+ *total = 0;
342
+
343
+ GGML_UNUSED(dev);
344
+ }
345
+
346
+ static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
347
+ return GGML_BACKEND_DEVICE_TYPE_CPU;
348
+
349
+ GGML_UNUSED(dev);
350
+ }
351
+
352
+ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
353
+ props->name = ggml_backend_cpu_device_get_name(dev);
354
+ props->description = ggml_backend_cpu_device_get_description(dev);
355
+ props->type = ggml_backend_cpu_device_get_type(dev);
356
+ ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
357
+ props->caps = {
358
+ /* .async = */ false,
359
+ /* .host_buffer = */ false,
360
+ /* .buffer_from_host_ptr = */ true,
361
+ /* .events = */ false,
362
+ };
363
+ }
364
+
365
+ static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
366
+ return ggml_backend_cpu_init();
367
+
368
+ GGML_UNUSED(dev);
369
+ GGML_UNUSED(params);
370
+ }
371
+
372
+ static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
373
+ return ggml_backend_cpu_buffer_type();
374
+
375
+ GGML_UNUSED(dev);
376
+ }
377
+
378
+ static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
379
+ return ggml_backend_cpu_buffer_from_ptr(ptr, size);
380
+
381
+ GGML_UNUSED(dev);
382
+ GGML_UNUSED(max_tensor_size);
383
+ }
384
+
385
+ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
386
+ switch (op->op) {
387
+ case GGML_OP_CPY:
388
+ return
389
+ op->type != GGML_TYPE_IQ2_XXS &&
390
+ op->type != GGML_TYPE_IQ2_XS &&
391
+ op->type != GGML_TYPE_IQ1_S &&
392
+ op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
393
+ case GGML_OP_MUL_MAT:
394
+ return op->src[1]->type == GGML_TYPE_F32;// FIXME || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
395
+ case GGML_OP_ROPE_BACK:
396
+ return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
397
+ case GGML_OP_IM2COL_BACK:
398
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
399
+ case GGML_OP_OUT_PROD:
400
+ return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
401
+ default:
402
+ return true;
403
+ }
404
+
405
+ GGML_UNUSED(dev);
406
+ }
407
+
408
+ static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
409
+ return ggml_backend_buft_is_host(buft);
410
+
411
+ GGML_UNUSED(dev);
412
+ }
413
+
414
+ static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
415
+ /* .get_name = */ ggml_backend_cpu_device_get_name,
416
+ /* .get_description = */ ggml_backend_cpu_device_get_description,
417
+ /* .get_memory = */ ggml_backend_cpu_device_get_memory,
418
+ /* .get_type = */ ggml_backend_cpu_device_get_type,
419
+ /* .get_props = */ ggml_backend_cpu_device_get_props,
420
+ /* .init_backend = */ ggml_backend_cpu_device_init_backend,
421
+ /* .get_buffer_type = */ ggml_backend_cpu_device_get_buffer_type,
422
+ /* .get_host_buffer_type = */ NULL,
423
+ /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
424
+ /* .supports_op = */ ggml_backend_cpu_device_supports_op,
425
+ /* .supports_buft = */ ggml_backend_cpu_device_supports_buft,
426
+ /* .offload_op = */ NULL,
427
+ /* .event_new = */ NULL,
428
+ /* .event_free = */ NULL,
429
+ /* .event_synchronize = */ NULL,
430
+ };
431
+
432
+ // CPU backend - backend (reg)
433
+
434
+ static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
435
+ return "CPU";
436
+
437
+ GGML_UNUSED(reg);
438
+ }
439
+
440
+ static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
441
+ return 1;
442
+
443
+ GGML_UNUSED(reg);
444
+ }
445
+
446
+ static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
447
+ GGML_ASSERT(index == 0);
448
+
449
+ static ggml_backend_cpu_device_context ctx;
450
+ static ggml_backend_device ggml_backend_cpu_device = {
451
+ /* .iface = */ ggml_backend_cpu_device_i,
452
+ /* .reg = */ reg,
453
+ /* .context = */ &ctx,
454
+ };
455
+
456
+ return &ggml_backend_cpu_device;
457
+ }
458
+
459
+ struct ggml_backend_feature {
460
+ const char * name;
461
+ const char * value;
462
+ };
463
+
464
+ // Not used yet
465
+ // This is intended to replace the ggml_cpu_has_* functions when loading the CPU backend dynamically,
466
+ // and additionally to allow other backends to expose their own list of features that applications can query using the same API.
467
+ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {
468
+ static std::vector<ggml_backend_feature> features = []() {
469
+ std::vector<ggml_backend_feature> features;
470
+ if (ggml_cpu_has_sse3()) {
471
+ features.push_back({ "SSE3", "1" });
472
+ }
473
+ if (ggml_cpu_has_ssse3()) {
474
+ features.push_back({ "SSSE3", "1" });
475
+ }
476
+ if (ggml_cpu_has_avx()) {
477
+ features.push_back({ "AVX", "1" });
478
+ }
479
+ if (ggml_cpu_has_avx2()) {
480
+ features.push_back({ "AVX2", "1" });
481
+ }
482
+ if (ggml_cpu_has_f16c()) {
483
+ features.push_back({ "F16C", "1" });
484
+ }
485
+ if (ggml_cpu_has_fma()) {
486
+ features.push_back({ "FMA", "1" });
487
+ }
488
+ if (ggml_cpu_has_avx_vnni()) {
489
+ features.push_back({ "AVX_VNNI", "1" });
490
+ }
491
+ if (ggml_cpu_has_avx512()) {
492
+ features.push_back({ "AVX512", "1" });
493
+ }
494
+ if (ggml_cpu_has_avx512_vbmi()) {
495
+ features.push_back({ "AVX512_VBMI", "1" });
496
+ }
497
+ if (ggml_cpu_has_avx512_vnni()) {
498
+ features.push_back({ "AVX512_VNNI", "1" });
499
+ }
500
+ if (ggml_cpu_has_avx512_bf16()) {
501
+ features.push_back({ "AVX512_BF16", "1" });
502
+ }
503
+ if (ggml_cpu_has_amx_int8()) {
504
+ features.push_back({ "AMX_INT8", "1" });
505
+ }
506
+ if (ggml_cpu_has_neon()) {
507
+ features.push_back({ "NEON", "1" });
508
+ }
509
+ if (ggml_cpu_has_arm_fma()) {
510
+ features.push_back({ "ARM_FMA", "1" });
511
+ }
512
+ if (ggml_cpu_has_fp16_va()) {
513
+ features.push_back({ "FP16_VA", "1" });
514
+ }
515
+ if (ggml_cpu_has_matmul_int8()) {
516
+ features.push_back({ "MATMUL_INT8", "1" });
517
+ }
518
+ if (ggml_cpu_has_sve()) {
519
+ features.push_back({ "SVE", "1" });
520
+ }
521
+ if (ggml_cpu_get_sve_cnt() > 0) {
522
+ static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
523
+ features.push_back({ "SVE_CNT", sve_cnt.c_str() });
524
+ }
525
+ if (ggml_cpu_has_riscv_v()) {
526
+ features.push_back({ "RISCV_V", "1" });
527
+ }
528
+ if (ggml_cpu_has_vsx()) {
529
+ features.push_back({ "VSX", "1" });
530
+ }
531
+ if (ggml_cpu_has_wasm_simd()) {
532
+ features.push_back({ "WASM_SIMD", "1" });
533
+ }
534
+ if (ggml_cpu_has_llamafile()) {
535
+ features.push_back({ "LLAMAFILE", "1" });
536
+ }
537
+
538
+ features.push_back({ nullptr, nullptr });
539
+
540
+ return features;
541
+ }();
542
+
543
+ return features.data();
544
+
545
+ GGML_UNUSED(reg);
546
+ }
547
+
548
+ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
549
+ if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
550
+ return (void *)ggml_backend_cpu_set_n_threads;
551
+ }
552
+ if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
553
+ return (void *)ggml_backend_cpu_get_extra_bufts;
554
+ }
555
+
556
+ return NULL;
557
+
558
+ GGML_UNUSED(reg);
559
+ }
560
+
561
+ static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
562
+ /* .get_name = */ ggml_backend_cpu_reg_get_name,
563
+ /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
564
+ /* .get_device = */ ggml_backend_cpu_reg_get_device,
565
+ /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
566
+ };
567
+
568
+ ggml_backend_reg_t ggml_backend_cpu_reg(void) {
569
+ static struct ggml_backend_reg ggml_backend_cpu_reg = {
570
+ /* .iface = */ ggml_backend_cpu_reg_i,
571
+ /* .context = */ NULL,
572
+ };
573
+
574
+ return &ggml_backend_cpu_reg;
575
+ }
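
To close the loop on the CPU backend defined above, here is a hedged end-to-end sketch (not part of the commit) that creates the backend, allocates two tensors in a backend buffer, and computes their element-wise sum. It uses only public API from `ggml.h`, `ggml-alloc.h`, `ggml-backend.h` and `ggml-cpu.h`; header paths and the chosen thread count are illustrative.

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-cpu.h"

#include <cstdio>
#include <vector>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(backend, 4);

    // context that only holds tensor/graph metadata; tensor data lives in a backend buffer
    ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ true,
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t n = 8;
    ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n);
    ggml_tensor * c = ggml_add(ctx, a, b);

    ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, c);

    // allocate all tensors of the context in a buffer owned by the CPU backend
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    std::vector<float> va(n, 1.0f), vb(n, 2.0f), vc(n);
    ggml_backend_tensor_set(a, va.data(), 0, n * sizeof(float));
    ggml_backend_tensor_set(b, vb.data(), 0, n * sizeof(float));

    ggml_backend_graph_compute(backend, gf);

    ggml_backend_tensor_get(c, vc.data(), 0, n * sizeof(float));
    std::printf("c[0] = %f\n", vc[0]); // expected 3.0

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```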
ggml/src/ggml-cpu/llamafile/sgemm.cpp ADDED
@@ -0,0 +1,1884 @@
+ // Copyright 2024 Mozilla Foundation
+ //
+ // Permission is hereby granted, free of charge, to any person obtaining
+ // a copy of this software and associated documentation files (the
+ // "Software"), to deal in the Software without restriction, including
+ // without limitation the rights to use, copy, modify, merge, publish,
+ // distribute, sublicense, and/or sell copies of the Software, and to
+ // permit persons to whom the Software is furnished to do so, subject to
+ // the following conditions:
+ //
+ // The above copyright notice and this permission notice shall be
+ // included in all copies or substantial portions of the Software.
+ //
+ // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
+ // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
+ // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+ // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ // SOFTWARE.
+
+ //
+ //                   _   _          ___ _      _   ___
+ //                  | |_(_)_ _ _  _| _ ) |    /_\ / __|
+ //                  |  _| | ' \ || | _ \ |__ / _ \\__ \.
+ //                   \__|_|_||_\_, |___/____/_/ \_\___/
+ //                             |__/
+ //
+ //                  BASIC LINEAR ALGEBRA SUBPROGRAMS
+ //
+ //
+ // This file implements multithreaded CPU matrix multiplication for the
+ // common contiguous use case C = Aᵀ * B. These kernels are designed to
+ // have excellent performance[1] for matrices that fit in the CPU cache
+ // without imposing any overhead such as cache filling or malloc calls.
+ //
+ // This implementation does not guarantee any upper bound on rounding
+ // errors, which grow along with k. Our goal is to maximally exploit the
+ // hardware for performance, and then use whatever resources remain for
+ // improving numerical accuracy.
+ //
+ // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
+ //     Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
+
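// [editor's sketch, not part of the original file] With the conventions used
// by the kernels below (see gemm() in class tinyBLAS), "C = Aᵀ * B with
// column major ordering" works out to, for 0 <= i < m and 0 <= j < n:
//
//     C[ldc*j + i] = sum_{l=0}^{k-1} A[lda*i + l] * B[ldb*j + l]
//
// i.e. row i of the stored A is dotted with row j of the stored B, with lda
// and ldb the row strides of A and B and ldc the column stride of C.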
46
+ #if defined(__GNUC__)
47
+ #pragma GCC diagnostic ignored "-Wpedantic"
48
+ #pragma GCC diagnostic ignored "-Wignored-attributes"
49
+ #endif
50
+
51
+ #include "sgemm.h"
52
+ #include "ggml-impl.h"
53
+ #include "ggml-cpu-impl.h"
54
+ #include "ggml-quants.h"
55
+
56
+ #ifdef _MSC_VER
57
+ #define NOINLINE __declspec(noinline)
58
+ #else
59
+ #define NOINLINE __attribute__((__noinline__))
60
+ #endif
61
+
62
+ #if defined(__ARM_NEON) || defined(__AVX512F__)
63
+ #define VECTOR_REGISTERS 32
64
+ #else
65
+ #define VECTOR_REGISTERS 16
66
+ #endif
67
+
68
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
69
+
70
+ namespace {
71
+
72
+ inline float unhalf(ggml_fp16_t d) {
73
+ return GGML_FP16_TO_FP32(d);
74
+ }
75
+
76
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
77
+ // VECTORIZED ARITHMETIC OPERATIONS
78
+
79
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
80
+ inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
81
+ inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
82
+ inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
83
+ #endif // __SSE__
84
+
85
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
86
+ inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
87
+ inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
88
+ inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
89
+ #endif // __AVX__
90
+
91
+ #if defined(__AVX512F__)
92
+ inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
93
+ inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
94
+ inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
95
+ #endif // __AVX512F__
96
+
97
+ #if defined(__ARM_NEON)
98
+ inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
99
+ inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
100
+ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
101
+ #endif // __ARM_NEON
102
+
103
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
104
+ inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
105
+ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
106
+ inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
107
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
108
+
109
+ #if defined(__MMA__)
110
+ typedef vector unsigned char vec_t;
111
+ typedef __vector_quad acc_t;
112
+ #endif
113
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
114
+ // VECTORIZED FUSED MULTIPLY ADD
115
+
116
+ /**
117
+ * Computes a * b + c.
118
+ */
119
+ template <typename T, typename U>
120
+ inline U madd(T a, T b, U c) {
121
+ return add(mul(a, b), c);
122
+ }
123
+
124
+ #if defined(__FMA__)
125
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
126
+ template <>
127
+ inline __m256 madd(__m256 a, __m256 b, __m256 c) {
128
+ return _mm256_fmadd_ps(a, b, c);
129
+ }
130
+ #endif
131
+ #if defined(__AVX512F__)
132
+ template <>
133
+ inline __m512 madd(__m512 a, __m512 b, __m512 c) {
134
+ return _mm512_fmadd_ps(a, b, c);
135
+ }
136
+ #endif
137
+ #endif
138
+
139
+ #if defined(__ARM_FEATURE_FMA)
140
+ template <>
141
+ inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
142
+ return vfmaq_f32(c, b, a);
143
+ }
144
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
145
+ template <>
146
+ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
147
+ return vfmaq_f16(c, b, a);
148
+ }
149
+ #endif
150
+ #endif
151
+
152
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
153
+ // VECTORIZED HORIZONTAL SUM
154
+
155
+ #if defined(__ARM_NEON)
156
+ inline float hsum(float32x4_t x) {
157
+ return vaddvq_f32(x);
158
+ }
159
+ #endif // __ARM_NEON
160
+
161
+ #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
162
+ inline float hsum(float16x8_t x) {
163
+ return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
164
+ vcvt_f32_f16(vget_high_f16(x))));
165
+ }
166
+ #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
167
+
168
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
169
+ inline float hsum(__m128 x) {
170
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
171
+ x = _mm_add_ps(x, _mm_movehl_ps(x, x));
172
+ x = _mm_add_ss(x, _mm_movehdup_ps(x));
173
+ #else
174
+ __m128 t;
175
+ t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
176
+ x = _mm_add_ps(x, t);
177
+ t = _mm_movehl_ps(t, x);
178
+ x = _mm_add_ss(x, t);
179
+ #endif
180
+ return _mm_cvtss_f32(x);
181
+ }
182
+ #endif
183
+
184
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
185
+ inline float hsum(__m256 x) {
186
+ return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
187
+ _mm256_castps256_ps128(x)));
188
+ }
189
+ #endif // __AVX__
190
+
191
+ #if defined(__AVX512F__)
192
+ inline float hsum(__m512 x) {
193
+ return _mm512_reduce_add_ps(x);
194
+ }
195
+ #endif // __AVX512F__
196
+
197
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
198
+ // VECTORIZED MEMORY LOADING
199
+
200
+ template <typename T, typename U> T load(const U *);
201
+
202
+ #if defined(__ARM_NEON)
203
+ template <> inline float32x4_t load(const float *p) {
204
+ return vld1q_f32(p);
205
+ }
206
+ #if !defined(_MSC_VER)
207
+ template <> inline float16x8_t load(const ggml_fp16_t *p) {
208
+ return vld1q_f16((const float16_t *)p);
209
+ }
210
+ template <> inline float32x4_t load(const ggml_fp16_t *p) {
211
+ return vcvt_f32_f16(vld1_f16((const float16_t *)p));
212
+ }
213
+ #endif // _MSC_VER
214
+ #endif // __ARM_NEON
215
+
216
+ #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
217
+ template <> inline __m128 load(const float *p) {
218
+ return _mm_loadu_ps(p);
219
+ }
220
+ #endif // __SSE__
221
+
222
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
223
+ template <> inline __m256 load(const float *p) {
224
+ return _mm256_loadu_ps(p);
225
+ }
226
+ #endif // __AVX__
227
+
228
+ #if defined(__F16C__)
229
+ template <> inline __m256 load(const ggml_fp16_t *p) {
230
+ return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
231
+ }
232
+ #endif // __F16C__
233
+
234
+ #if defined(__AVX512F__)
235
+ template <> inline __m512 load(const float *p) {
236
+ return _mm512_loadu_ps(p);
237
+ }
238
+ template <> inline __m512 load(const ggml_fp16_t *p) {
239
+ return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
240
+ }
241
+ #endif // __AVX512F__
242
+
243
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
244
+ // CONSTANTS
245
+
246
+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
247
+ static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
248
+ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
249
+ #endif
250
+
251
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
252
+ // FLOATING POINT MATRIX MULTIPLICATION
253
+
254
+ template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
255
+ class tinyBLAS {
256
+ public:
257
+ tinyBLAS(int64_t k,
258
+ const TA *A, int64_t lda,
259
+ const TB *B, int64_t ldb,
260
+ TC *C, int64_t ldc,
261
+ int ith, int nth)
262
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
263
+ }
264
+
265
+ void matmul(int64_t m, int64_t n) {
266
+ mnpack(0, m, 0, n);
267
+ }
268
+
269
+ private:
270
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
271
+ int64_t mc, nc, mp, np;
272
+ switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
273
+ #if VECTOR_REGISTERS == 32
274
+ case 0x55:
275
+ mc = 5;
276
+ nc = 5;
277
+ gemm<5, 5>(m0, m, n0, n);
278
+ break;
279
+ case 0x45:
280
+ mc = 4;
281
+ nc = 5;
282
+ gemm<4, 5>(m0, m, n0, n);
283
+ break;
284
+ case 0x54:
285
+ mc = 5;
286
+ nc = 4;
287
+ gemm<5, 4>(m0, m, n0, n);
288
+ break;
289
+ case 0x44:
290
+ mc = 4;
291
+ nc = 4;
292
+ gemm<4, 4>(m0, m, n0, n);
293
+ break;
294
+ case 0x53:
295
+ mc = 5;
296
+ nc = 3;
297
+ gemm<5, 3>(m0, m, n0, n);
298
+ break;
299
+ case 0x35:
300
+ mc = 3;
301
+ nc = 5;
302
+ gemm<3, 5>(m0, m, n0, n);
303
+ break;
304
+ case 0x43:
305
+ mc = 4;
306
+ nc = 3;
307
+ gemm<4, 3>(m0, m, n0, n);
308
+ break;
309
+ #else
310
+ case 0x55:
311
+ case 0x54:
312
+ case 0x53:
313
+ case 0x45:
314
+ case 0x44:
315
+ case 0x43:
316
+ mc = 4;
317
+ nc = 3;
318
+ gemm<4, 3>(m0, m, n0, n);
319
+ break;
320
+ case 0x35:
321
+ #endif
322
+ case 0x34:
323
+ mc = 3;
324
+ nc = 4;
325
+ gemm<3, 4>(m0, m, n0, n);
326
+ break;
327
+ case 0x52:
328
+ mc = 5;
329
+ nc = 2;
330
+ gemm<5, 2>(m0, m, n0, n);
331
+ break;
332
+ case 0x33:
333
+ mc = 3;
334
+ nc = 3;
335
+ gemm<3, 3>(m0, m, n0, n);
336
+ break;
337
+ case 0x25:
338
+ mc = 2;
339
+ nc = 5;
340
+ gemm<2, 5>(m0, m, n0, n);
341
+ break;
342
+ case 0x42:
343
+ mc = 4;
344
+ nc = 2;
345
+ gemm<4, 2>(m0, m, n0, n);
346
+ break;
347
+ case 0x24:
348
+ mc = 2;
349
+ nc = 4;
350
+ gemm<2, 4>(m0, m, n0, n);
351
+ break;
352
+ case 0x32:
353
+ mc = 3;
354
+ nc = 2;
355
+ gemm<3, 2>(m0, m, n0, n);
356
+ break;
357
+ case 0x23:
358
+ mc = 2;
359
+ nc = 3;
360
+ gemm<2, 3>(m0, m, n0, n);
361
+ break;
362
+ case 0x51:
363
+ mc = 5;
364
+ nc = 1;
365
+ gemm<5, 1>(m0, m, n0, n);
366
+ break;
367
+ case 0x41:
368
+ mc = 4;
369
+ nc = 1;
370
+ gemm<4, 1>(m0, m, n0, n);
371
+ break;
372
+ case 0x22:
373
+ mc = 2;
374
+ nc = 2;
375
+ gemm<2, 2>(m0, m, n0, n);
376
+ break;
377
+ case 0x15:
378
+ mc = 1;
379
+ nc = 5;
380
+ gemm<1, 5>(m0, m, n0, n);
381
+ break;
382
+ case 0x14:
383
+ mc = 1;
384
+ nc = 4;
385
+ gemm<1, 4>(m0, m, n0, n);
386
+ break;
387
+ case 0x31:
388
+ mc = 3;
389
+ nc = 1;
390
+ gemm<3, 1>(m0, m, n0, n);
391
+ break;
392
+ case 0x13:
393
+ mc = 1;
394
+ nc = 3;
395
+ gemm<1, 3>(m0, m, n0, n);
396
+ break;
397
+ case 0x21:
398
+ mc = 2;
399
+ nc = 1;
400
+ gemm<2, 1>(m0, m, n0, n);
401
+ break;
402
+ case 0x12:
403
+ mc = 1;
404
+ nc = 2;
405
+ gemm<1, 2>(m0, m, n0, n);
406
+ break;
407
+ case 0x11:
408
+ mc = 1;
409
+ nc = 1;
410
+ gemm<1, 1>(m0, m, n0, n);
411
+ break;
412
+ default:
413
+ return;
414
+ }
415
+ mp = m0 + (m - m0) / mc * mc;
416
+ np = n0 + (n - n0) / nc * nc;
417
+ mnpack(mp, m, n0, np);
418
+ mnpack(m0, m, np, n);
419
+ }
420
+
421
+ template <int RM, int RN>
422
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
423
+ int64_t ytiles = (m - m0) / RM;
424
+ int64_t xtiles = (n - n0) / RN;
425
+ int64_t tiles = xtiles * ytiles;
426
+ int64_t duty = (tiles + nth - 1) / nth;
427
+ int64_t start = duty * ith;
428
+ int64_t end = start + duty;
429
+ if (end > tiles)
430
+ end = tiles;
431
+ for (int64_t job = start; job < end; ++job) {
432
+ int64_t ii = m0 + job / xtiles * RM;
433
+ int64_t jj = n0 + job % xtiles * RN;
434
+ D Cv[RN][RM] = {};
435
+ for (int64_t l = 0; l < k; l += KN)
436
+ for (int64_t j = 0; j < RN; ++j)
437
+ for (int64_t i = 0; i < RM; ++i)
438
+ Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
439
+ load<V>(B + ldb * (jj + j) + l),
440
+ Cv[j][i]);
441
+ for (int64_t j = 0; j < RN; ++j)
442
+ for (int64_t i = 0; i < RM; ++i)
443
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
444
+ }
445
+ }
446
+
447
+ const TA *const A;
448
+ const TB *const B;
449
+ TC *const C;
450
+ const int64_t k;
451
+ const int64_t lda;
452
+ const int64_t ldb;
453
+ const int64_t ldc;
454
+ const int ith;
455
+ const int nth;
456
+ };
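// [editor's worked example, not part of the original file] mnpack() always
// picks the largest micro-kernel that fits the remaining block: with 32
// vector registers, a 7x9 remainder maps to case 0x55 and runs the 5x5
// kernel over the 5x5-aligned region, after which the two recursive calls
// mnpack(mp, m, n0, np) and mnpack(m0, m, np, n) sweep up the 2-row and
// 4-column fringes with smaller kernels.
// Within gemm<RM, RN>(), the tile grid is split statically across threads:
//     tiles = ((m - m0) / RM) * ((n - n0) / RN);
//     duty  = (tiles + nth - 1) / nth;              // ceil(tiles / nth)
//     thread ith processes jobs [duty * ith, min(duty * (ith + 1), tiles)).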
457
+
458
+ //////////////////////////////////////////////////////////////////////////////////////////
459
+ // QUANT ZERO MATRIX MULTIPLICATION
460
+
461
+ #if defined(__ARM_FEATURE_DOTPROD)
462
+ template <typename TA>
463
+ class tinyBLAS_Q0_ARM {
464
+ public:
465
+ tinyBLAS_Q0_ARM(int64_t k,
466
+ const TA *A, int64_t lda,
467
+ const block_q8_0 *B, int64_t ldb,
468
+ float *C, int64_t ldc,
469
+ int ith, int nth)
470
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
471
+ }
472
+
473
+ void matmul(int64_t m, int64_t n) {
474
+ mnpack(0, m, 0, n);
475
+ }
476
+
477
+ private:
478
+ NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
479
+ int64_t mc, nc, mp, np;
480
+ switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
481
+ case 0x33:
482
+ mc = 3;
483
+ nc = 3;
484
+ gemm<3, 3>(m0, m, n0, n);
485
+ break;
486
+ case 0x32:
487
+ mc = 3;
488
+ nc = 2;
489
+ gemm<3, 2>(m0, m, n0, n);
490
+ break;
491
+ case 0x23:
492
+ mc = 2;
493
+ nc = 3;
494
+ gemm<2, 3>(m0, m, n0, n);
495
+ break;
496
+ case 0x22:
497
+ mc = 2;
498
+ nc = 2;
499
+ gemm<2, 2>(m0, m, n0, n);
500
+ break;
501
+ case 0x31:
502
+ mc = 3;
503
+ nc = 1;
504
+ gemm<3, 1>(m0, m, n0, n);
505
+ break;
506
+ case 0x13:
507
+ mc = 1;
508
+ nc = 3;
509
+ gemm<1, 3>(m0, m, n0, n);
510
+ break;
511
+ case 0x21:
512
+ mc = 2;
513
+ nc = 1;
514
+ gemm<2, 1>(m0, m, n0, n);
515
+ break;
516
+ case 0x12:
517
+ mc = 1;
518
+ nc = 2;
519
+ gemm<1, 2>(m0, m, n0, n);
520
+ break;
521
+ case 0x11:
522
+ mc = 1;
523
+ nc = 1;
524
+ gemm<1, 1>(m0, m, n0, n);
525
+ break;
526
+ default:
527
+ return;
528
+ }
529
+ mp = m0 + (m - m0) / mc * mc;
530
+ np = n0 + (n - n0) / nc * nc;
531
+ mnpack(mp, m, n0, np);
532
+ mnpack(m0, m, np, n);
533
+ }
534
+
535
+ template <int RM, int RN>
536
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
537
+ int64_t ytiles = (m - m0) / RM;
538
+ int64_t xtiles = (n - n0) / RN;
539
+ int64_t tiles = xtiles * ytiles;
540
+ int64_t duty = (tiles + nth - 1) / nth;
541
+ int64_t start = duty * ith;
542
+ int64_t end = start + duty;
543
+ if (end > tiles)
544
+ end = tiles;
545
+ for (int64_t job = start; job < end; ++job) {
546
+ int64_t ii = m0 + job / xtiles * RM;
547
+ int64_t jj = n0 + job % xtiles * RN;
548
+ float32x4_t Cv[RN][RM] = {};
549
+ for (int64_t l = 0; l < k; ++l)
550
+ for (int64_t j = 0; j < RN; ++j)
551
+ for (int64_t i = 0; i < RM; ++i)
552
+ Cv[j][i] = vmlaq_n_f32(Cv[j][i],
553
+ vcvtq_f32_s32(vdotq_s32(
554
+ vdotq_s32(vdupq_n_s32(0),
555
+ load_lo(A + lda * (ii + i) + l),
556
+ load_lo(B + ldb * (jj + j) + l)),
557
+ load_hi(A + lda * (ii + i) + l),
558
+ load_hi(B + ldb * (jj + j) + l))),
559
+ unhalf(A[lda * (ii + i) + l].d) *
560
+ unhalf(B[ldb * (jj + j) + l].d));
561
+ for (int64_t j = 0; j < RN; ++j)
562
+ for (int64_t i = 0; i < RM; ++i)
563
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
564
+ }
565
+ }
566
+
567
+ inline int8x16_t load_lo(const block_q8_0 *b) {
568
+ return vld1q_s8(b->qs);
569
+ }
570
+
571
+ inline int8x16_t load_hi(const block_q8_0 *b) {
572
+ return vld1q_s8(b->qs + 16);
573
+ }
574
+
575
+ inline int8x16_t load_lo(const block_q4_0 *b) {
576
+ return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
577
+ vdupq_n_u8(0x0f))),
578
+ vdupq_n_s8(0x8));
579
+ }
580
+
581
+ inline int8x16_t load_hi(const block_q4_0 *b) {
582
+ return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
583
+ vdupq_n_s8(0x8));
584
+ }
585
+
586
+ const TA *const A;
587
+ const block_q8_0 *const B;
588
+ float *const C;
589
+ const int64_t k;
590
+ const int64_t lda;
591
+ const int64_t ldb;
592
+ const int64_t ldc;
593
+ const int ith;
594
+ const int nth;
595
+ };
596
+ #endif // __ARM_FEATURE_DOTPROD
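// [editor's sketch, not part of the original file] For the quantized path
// above, every block_q8_0 / block_q4_0 stores 32 quantized values qs[] plus
// one fp16 scale d, so each (A-block, B-block) pair contributes
//
//     unhalf(dA) * unhalf(dB) * sum_{i=0}^{31} qA[i] * qB[i]
//
// to the output element. For block_q4_0, load_lo()/load_hi() unpack the
// nibbles to signed bytes (subtracting 8); the two chained vdotq_s32 calls
// accumulate the 32 products into four int32 lanes, vmlaq_n_f32 applies the
// combined scale, and hsum() collapses the four lanes when the tile is stored.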
597
+
598
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
599
+ template <typename TA, typename TB, typename TC>
600
+ class tinyBLAS_Q0_AVX {
601
+ public:
602
+ tinyBLAS_Q0_AVX(int64_t k,
603
+ const TA *A, int64_t lda,
604
+ const TB *B, int64_t ldb,
605
+ TC *C, int64_t ldc,
606
+ int ith, int nth)
607
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
608
+ }
609
+
610
+ void matmul(int64_t m, int64_t n) {
611
+ mnpack(0, m, 0, n);
612
+ }
613
+
614
+ private:
615
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
616
+ int64_t mc, nc, mp, np;
617
+ switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
618
+ #if VECTOR_REGISTERS == 32
619
+ case 0x44:
620
+ mc = 4;
621
+ nc = 4;
622
+ #if defined(__AVX2__) && defined(__F16C__)
623
+ gemm4xN<4>(m0, m, n0, n);
624
+ #else
625
+ gemm<4, 4>(m0, m, n0, n);
626
+ #endif
627
+ break;
628
+ case 0x43:
629
+ mc = 4;
630
+ nc = 3;
631
+ #if defined(__AVX2__) && defined(__F16C__)
632
+ gemm4xN<3>(m0, m, n0, n);
633
+ #else
634
+ gemm<4, 3>(m0, m, n0, n);
635
+ #endif
636
+ break;
637
+ case 0x34:
638
+ mc = 3;
639
+ nc = 4;
640
+ #if defined(__AVX2__) && defined(__F16C__)
641
+ gemmMx4<3>(m0, m, n0, n);
642
+ #else
643
+ gemm<3, 4>(m0, m, n0, n);
644
+ #endif
645
+ break;
646
+ case 0x33:
647
+ mc = 3;
648
+ nc = 3;
649
+ gemm<3, 3>(m0, m, n0, n);
650
+ break;
651
+ case 0x42:
652
+ mc = 4;
653
+ nc = 2;
654
+ #if defined(__AVX2__) && defined(__F16C__)
655
+ gemm4xN<2>(m0, m, n0, n);
656
+ #else
657
+ gemm<4, 2>(m0, m, n0, n);
658
+ #endif
659
+ break;
660
+ case 0x24:
661
+ mc = 2;
662
+ nc = 4;
663
+ #if defined(__AVX2__) && defined(__F16C__)
664
+ gemmMx4<2>(m0, m, n0, n);
665
+ #else
666
+ gemm<2, 4>(m0, m, n0, n);
667
+ #endif
668
+ break;
669
+ #else
670
+ case 0x44:
671
+ case 0x43:
672
+ case 0x42:
673
+ mc = 4;
674
+ nc = 2;
675
+ #if defined(__AVX2__) && defined(__F16C__)
676
+ gemm4xN<2>(m0, m, n0, n);
677
+ #else
678
+ gemm<4, 2>(m0, m, n0, n);
679
+ #endif
680
+ break;
681
+ case 0x34:
682
+ case 0x24:
683
+ mc = 2;
684
+ nc = 4;
685
+ #if defined(__AVX2__) && defined(__F16C__)
686
+ gemmMx4<2>(m0, m, n0, n);
687
+ #else
688
+ gemm<2, 4>(m0, m, n0, n);
689
+ #endif
690
+ break;
691
+ case 0x33:
692
+ #endif
693
+ case 0x32:
694
+ mc = 3;
695
+ nc = 2;
696
+ gemm<3, 2>(m0, m, n0, n);
697
+ break;
698
+ case 0x23:
699
+ mc = 2;
700
+ nc = 3;
701
+ gemm<2, 3>(m0, m, n0, n);
702
+ break;
703
+ case 0x41:
704
+ mc = 4;
705
+ nc = 1;
706
+ #if defined(__AVX2__) && defined(__F16C__)
707
+ gemm4xN<1>(m0, m, n0, n);
708
+ #else
709
+ gemm<4, 1>(m0, m, n0, n);
710
+ #endif
711
+ break;
712
+ case 0x22:
713
+ mc = 2;
714
+ nc = 2;
715
+ gemm<2, 2>(m0, m, n0, n);
716
+ break;
717
+ case 0x14:
718
+ mc = 1;
719
+ nc = 4;
720
+ #if defined(__AVX2__) && defined(__F16C__)
721
+ gemmMx4<1>(m0, m, n0, n);
722
+ #else
723
+ gemm<1, 4>(m0, m, n0, n);
724
+ #endif
725
+ break;
726
+ case 0x31:
727
+ mc = 3;
728
+ nc = 1;
729
+ gemm<3, 1>(m0, m, n0, n);
730
+ break;
731
+ case 0x13:
732
+ mc = 1;
733
+ nc = 3;
734
+ gemm<1, 3>(m0, m, n0, n);
735
+ break;
736
+ case 0x21:
737
+ mc = 2;
738
+ nc = 1;
739
+ gemm<2, 1>(m0, m, n0, n);
740
+ break;
741
+ case 0x12:
742
+ mc = 1;
743
+ nc = 2;
744
+ gemm<1, 2>(m0, m, n0, n);
745
+ break;
746
+ case 0x11:
747
+ mc = 1;
748
+ nc = 1;
749
+ gemm<1, 1>(m0, m, n0, n);
750
+ break;
751
+ default:
752
+ return;
753
+ }
754
+ mp = m0 + (m - m0) / mc * mc;
755
+ np = n0 + (n - n0) / nc * nc;
756
+ mnpack(mp, m, n0, np);
757
+ mnpack(m0, m, np, n);
758
+ }
759
+
760
+ #if defined(__AVX2__) && defined(__F16C__)
761
+ // Templated functions for gemm of dimensions 4xN
762
+ template <int RN>
763
+ NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
764
+ int64_t ytiles = (m - m0) / 4;
765
+ int64_t xtiles = (n - n0) / RN;
766
+ int64_t tiles = xtiles * ytiles;
767
+ int64_t duty = (tiles + nth - 1) / nth;
768
+ int64_t start = duty * ith;
769
+ int64_t end = start + duty;
770
+ if (end > tiles)
771
+ end = tiles;
772
+ for (int64_t job = start; job < end; ++job) {
773
+ int64_t ii = m0 + job / xtiles * 4;
774
+ int64_t jj = n0 + job % xtiles * RN;
775
+ __m256 Cv[RN][4] = {};
776
+ for (int64_t l = 0; l < k; ++l) {
777
+ uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
778
+ // Convert delta values for four blocks to float values
779
+ __m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
780
+ __m256i avec0 = load(A + lda * (ii + 0) + l);
781
+ __m256i avec1 = load(A + lda * (ii + 1) + l);
782
+ __m256i avec2 = load(A + lda * (ii + 2) + l);
783
+ __m256i avec3 = load(A + lda * (ii + 3) + l);
784
+ for (int64_t j = 0; j < RN; ++j) {
785
+ __m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
786
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
787
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
788
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
789
+ // Computation of dot product and multiplication with appropriate delta value products
790
+ Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
791
+ updot(_mm256_sign_epi8(avec0, avec0),
792
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
793
+ Cv[j][0]);
794
+ Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
795
+ updot(_mm256_sign_epi8(avec1, avec1),
796
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
797
+ Cv[j][1]);
798
+ Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
799
+ updot(_mm256_sign_epi8(avec2, avec2),
800
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
801
+ Cv[j][2]);
802
+ Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
803
+ updot(_mm256_sign_epi8(avec3, avec3),
804
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
805
+ Cv[j][3]);
806
+ }
807
+ }
808
+
809
+ for (int64_t j = 0; j < RN; ++j)
810
+ for (int64_t i = 0; i < 4; ++i)
811
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
812
+ }
813
+ }
814
+
815
+ // Templated functions for gemm of dimensions Mx4
816
+ template <int RM>
817
+ NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
818
+ int64_t ytiles = (m - m0) / RM;
819
+ int64_t xtiles = (n - n0) / 4;
820
+ int64_t tiles = xtiles * ytiles;
821
+ int64_t duty = (tiles + nth - 1) / nth;
822
+ int64_t start = duty * ith;
823
+ int64_t end = start + duty;
824
+ if (end > tiles)
825
+ end = tiles;
826
+ for (int64_t job = start; job < end; ++job) {
827
+ int64_t ii = m0 + job / xtiles * RM;
828
+ int64_t jj = n0 + job % xtiles * 4;
829
+ __m256 Cv[4][RM] = {};
830
+ for (int64_t l = 0; l < k; ++l) {
831
+ uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
832
+ // Convert delta values for four blocks to float values
833
+ __m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
834
+ __m256i bvec0 = load(B + ldb * (jj + 0) + l);
835
+ __m256i bvec1 = load(B + ldb * (jj + 1) + l);
836
+ __m256i bvec2 = load(B + ldb * (jj + 2) + l);
837
+ __m256i bvec3 = load(B + ldb * (jj + 3) + l);
838
+ for (int64_t i = 0; i < RM; ++i) {
839
+ __m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
840
+ // Computation of product of delta values for four blocks and replicate it across 256 bit lane
841
+ __m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
842
+ dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
843
+ // Computation of dot product and multiplication with appropriate delta value products
844
+ Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
845
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
846
+ load(A + lda * (ii + i) + l)),
847
+ _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
848
+ Cv[0][i]);
849
+ Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
850
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
851
+ load(A + lda * (ii + i) + l)),
852
+ _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
853
+ Cv[1][i]);
854
+ Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
855
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
856
+ load(A + lda * (ii + i) + l)),
857
+ _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
858
+ Cv[2][i]);
859
+ Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
860
+ updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
861
+ load(A + lda * (ii + i) + l)),
862
+ _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
863
+ Cv[3][i]);
864
+ }
865
+ }
866
+ for (int64_t j = 0; j < 4; ++j)
867
+ for (int64_t i = 0; i < RM; ++i)
868
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
869
+ }
870
+ }
871
+ #endif
872
+
873
+ template <int RM, int RN>
874
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
875
+ int64_t ytiles = (m - m0) / RM;
876
+ int64_t xtiles = (n - n0) / RN;
877
+ int64_t tiles = xtiles * ytiles;
878
+ int64_t duty = (tiles + nth - 1) / nth;
879
+ int64_t start = duty * ith;
880
+ int64_t end = start + duty;
881
+ if (end > tiles)
882
+ end = tiles;
883
+ for (int64_t job = start; job < end; ++job) {
884
+ int64_t ii = m0 + job / xtiles * RM;
885
+ int64_t jj = n0 + job % xtiles * RN;
886
+ __m256 Cv[RN][RM] = {};
887
+ for (int64_t l = 0; l < k; ++l)
888
+ for (int64_t j = 0; j < RN; ++j)
889
+ for (int64_t i = 0; i < RM; ++i) {
890
+ #if defined(__AVX2__)
891
+ __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
892
+ load(A + lda * (ii + i) + l)),
893
+ _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
894
+ load(A + lda * (ii + i) + l)));
895
+ #else
896
+ __m128i ali0 = load0(A + lda * (ii + i) + l);
897
+ __m128i ali1 = load1(A + lda * (ii + i) + l);
898
+ __m128i blj0 = load0(B + ldb * (jj + j) + l);
899
+ __m128i blj1 = load1(B + ldb * (jj + j) + l);
900
+
901
+ __m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
902
+ __m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
903
+ __m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
904
+ __m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
905
+
906
+ // updot
907
+ const __m128i oneFill = _mm_set1_epi16(1);
908
+ __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
909
+ __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
910
+ __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
911
+ #endif
912
+ Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
913
+ unhalf(B[ldb * (jj + j) + l].d)),
914
+ udTmp,
915
+ Cv[j][i]);
916
+ }
917
+ for (int64_t j = 0; j < RN; ++j)
918
+ for (int64_t i = 0; i < RM; ++i)
919
+ C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
920
+ }
921
+ }
922
+
923
+ inline __m256i load(const block_q8_0 *b) {
924
+ return _mm256_loadu_si256((const __m256i *)b->qs);
925
+ }
926
+
927
+ inline __m128i load0(const block_q8_0 *b) {
928
+ return _mm_loadu_si128((const __m128i *)b->qs);
929
+ }
930
+
931
+ inline __m128i load1(const block_q8_0 *b) {
932
+ return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
933
+ }
934
+
935
+ inline __m256i load(const block_q4_0 *b) {
936
+ return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
937
+ }
938
+
939
+ inline __m128i load0(const block_q4_0 *b) {
940
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
941
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
942
+ }
943
+
944
+ inline __m128i load1(const block_q4_0 *b) {
945
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
946
+ return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
947
+ }
948
+
949
+ inline __m256i load(const block_q5_0 *b) {
950
+ return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
951
+ }
952
+
953
+ inline __m128i load0(const block_q5_0* b) {
954
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
955
+ uint32_t x32;
956
+ memcpy(&x32, b->qh, sizeof(uint32_t));
957
+ __m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
958
+ __m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
959
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
960
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
961
+ _mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
962
+ bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
963
+ return _mm_or_si128(qxl, bytesl);
964
+ }
965
+
966
+ inline __m128i load1(const block_q5_0* b) {
967
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
968
+ uint32_t x32;
969
+ memcpy(&x32, b->qh, sizeof(uint32_t));
970
+ __m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
971
+ __m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
972
+ _mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
973
+ _mm_shuffle_epi8(_mm_set1_epi32(x32),
974
+ _mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
975
+ bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
976
+ return _mm_or_si128(qxh, bytesh);
977
+ }
978
+
979
+ inline __m256i load(const block_iq4_nl *b) {
980
+ return MM256_SET_M128I(load1(b), load0(b));
981
+ }
982
+
983
+ inline __m128i load0(const block_iq4_nl *b) {
984
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
985
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
986
+ }
987
+
988
+ inline __m128i load1(const block_iq4_nl *b) {
989
+ const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
990
+ return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
991
+ }
992
+
993
+ inline __m256 updot(__m256i u, __m256i s) {
994
+ __m256i res;
995
+ #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
996
+ res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
997
+ #else
998
+ res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
999
+ #endif
1000
+ return _mm256_cvtepi32_ps(res);
1001
+ }
1002
+
1003
+ static inline __m256i denibble(const uint8_t *p) {
1004
+ __m128i x = _mm_loadu_si128((const __m128i *)p);
1005
+ return _mm256_and_si256(_mm256_set1_epi8(15),
1006
+ _mm256_insertf128_si256(_mm256_castsi128_si256(x),
1007
+ _mm_srli_epi16(x, 4), 1));
1008
+ }
1009
+
1010
+ static inline __m256i bittobyte(const uint8_t *p) {
1011
+ uint32_t x32;
1012
+ memcpy(&x32, p, sizeof(uint32_t));
1013
+ __m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
1014
+ _mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
1015
+ _mm256_shuffle_epi8(_mm256_set1_epi32(x32),
1016
+ _mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
1017
+ 0x0101010101010101, 0x0000000000000000))));
1018
+ return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
1019
+ }
1020
+
1021
+ const TA *const A;
1022
+ const TB *const B;
1023
+ TC *const C;
1024
+ const int64_t k;
1025
+ const int64_t lda;
1026
+ const int64_t ldb;
1027
+ const int64_t ldc;
1028
+ const int ith;
1029
+ const int nth;
1030
+ };
1031
+ #endif // __AVX__
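// [editor's note, not part of the original file] The updot() helper above
// relies on a standard trick for the unsigned x signed _mm256_maddubs_epi16
// and _mm256_dpbusd_epi32 instructions: _mm256_sign_epi8(a, a) yields |a[i]|
// as the unsigned operand, and _mm256_sign_epi8(b, a) transfers the sign of
// a[i] onto b[i] (zeroing lanes where a[i] == 0), so each lane still
// multiplies out to a[i] * b[i]. The AVX-VNNI/AVX512-VNNI path then sums the
// byte products straight into 32-bit lanes with _mm256_dpbusd_epi32, while
// the fallback pairs _mm256_maddubs_epi16 with _mm256_madd_epi16(ones, ...).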
1032
+
1033
+ //PPC Implementation
1034
+ #if defined(__MMA__)
1035
+
1036
+ #define SAVE_ACC(ACC, ii, jj) \
1037
+ __builtin_mma_disassemble_acc(vec_C, ACC); \
1038
+ for (int I = 0; I < 4; I++) { \
1039
+ for (int J = 0; J < 4; J++) { \
1040
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
1041
+ } \
1042
+ } \
1043
+
1044
+ template <typename TA, typename TB, typename TC>
1045
+ class tinyBLAS_PPC {
1046
+ public:
1047
+ tinyBLAS_PPC(int64_t k,
1048
+ const TA *A, int64_t lda,
1049
+ const TB *B, int64_t ldb,
1050
+ TC *C, int64_t ldc,
1051
+ int ith, int nth)
1052
+ : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
1053
+ }
1054
+
1055
+ void matmul(int64_t m, int64_t n) {
1056
+ mnpack(0, m, 0, n);
1057
+ }
1058
+
1059
+ private:
1060
+
1061
+ void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
1062
+
1063
+ void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
1064
+ int64_t i, j;
1065
+ float *aoffset = NULL, *boffset = NULL;
1066
+ float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
1067
+ float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
1068
+
1069
+ aoffset = const_cast<float*>(a);
1070
+ boffset = vec;
1071
+ j = (rows >> 3);
1072
+ if (j > 0) {
1073
+ do {
1074
+ aoffset1 = aoffset;
1075
+ aoffset2 = aoffset1 + lda;
1076
+ aoffset3 = aoffset2 + lda;
1077
+ aoffset4 = aoffset3 + lda;
1078
+ aoffset5 = aoffset4 + lda;
1079
+ aoffset6 = aoffset5 + lda;
1080
+ aoffset7 = aoffset6 + lda;
1081
+ aoffset8 = aoffset7 + lda;
1082
+ aoffset += 8 * lda;
1083
+ i = (cols >> 3);
1084
+ if (i > 0) {
1085
+ __vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
1086
+ vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
1087
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1088
+ do {
1089
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1090
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1091
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1092
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1093
+ C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
1094
+ C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
1095
+ C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
1096
+ C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
1097
+ __builtin_vsx_disassemble_pair(c1, &C1);
1098
+ __builtin_vsx_disassemble_pair(c2, &C2);
1099
+ __builtin_vsx_disassemble_pair(c3, &C3);
1100
+ __builtin_vsx_disassemble_pair(c4, &C4);
1101
+ __builtin_vsx_disassemble_pair(c5, &C5);
1102
+ __builtin_vsx_disassemble_pair(c6, &C6);
1103
+ __builtin_vsx_disassemble_pair(c7, &C7);
1104
+ __builtin_vsx_disassemble_pair(c8, &C8);
1105
+
1106
+ t1 = vec_mergeh(c1[0], c2[0]);
1107
+ t2 = vec_mergeh(c3[0], c4[0]);
1108
+ t3 = vec_mergeh(c5[0], c6[0]);
1109
+ t4 = vec_mergeh(c7[0], c8[0]);
1110
+ t5 = vec_xxpermdi(t1, t2, 0);
1111
+ t6 = vec_xxpermdi(t3, t4, 0);
1112
+ t7 = vec_xxpermdi(t1, t2, 3);
1113
+ t8 = vec_xxpermdi(t3, t4, 3);
1114
+ vec_xst(t5, 0, boffset);
1115
+ vec_xst(t6, 0, boffset+4);
1116
+ vec_xst(t7, 0, boffset+8);
1117
+ vec_xst(t8, 0, boffset+12);
1118
+
1119
+ t1 = vec_mergel(c1[0], c2[0]);
1120
+ t2 = vec_mergel(c3[0], c4[0]);
1121
+ t3 = vec_mergel(c5[0], c6[0]);
1122
+ t4 = vec_mergel(c7[0], c8[0]);
1123
+ t5 = vec_xxpermdi(t1, t2, 0);
1124
+ t6 = vec_xxpermdi(t3, t4, 0);
1125
+ t7 = vec_xxpermdi(t1, t2, 3);
1126
+ t8 = vec_xxpermdi(t3, t4, 3);
1127
+ vec_xst(t5, 0, boffset+16);
1128
+ vec_xst(t6, 0, boffset+20);
1129
+ vec_xst(t7, 0, boffset+24);
1130
+ vec_xst(t8, 0, boffset+28);
1131
+
1132
+ t1 = vec_mergeh(c1[1], c2[1]);
1133
+ t2 = vec_mergeh(c3[1], c4[1]);
1134
+ t3 = vec_mergeh(c5[1], c6[1]);
1135
+ t4 = vec_mergeh(c7[1], c8[1]);
1136
+ t5 = vec_xxpermdi(t1, t2, 0);
1137
+ t6 = vec_xxpermdi(t3, t4, 0);
1138
+ t7 = vec_xxpermdi(t1, t2, 3);
1139
+ t8 = vec_xxpermdi(t3, t4, 3);
1140
+ vec_xst(t5, 0, boffset+32);
1141
+ vec_xst(t6, 0, boffset+36);
1142
+ vec_xst(t7, 0, boffset+40);
1143
+ vec_xst(t8, 0, boffset+44);
1144
+
1145
+ t1 = vec_mergel(c1[1], c2[1]);
1146
+ t2 = vec_mergel(c3[1], c4[1]);
1147
+ t3 = vec_mergel(c5[1], c6[1]);
1148
+ t4 = vec_mergel(c7[1], c8[1]);
1149
+ t5 = vec_xxpermdi(t1, t2, 0);
1150
+ t6 = vec_xxpermdi(t3, t4, 0);
1151
+ t7 = vec_xxpermdi(t1, t2, 3);
1152
+ t8 = vec_xxpermdi(t3, t4, 3);
1153
+ vec_xst(t5, 0, boffset+48);
1154
+ vec_xst(t6, 0, boffset+52);
1155
+ vec_xst(t7, 0, boffset+56);
1156
+ vec_xst(t8, 0, boffset+60);
1157
+
1158
+ aoffset1 += 8*lda;
1159
+ aoffset2 += 8*lda;
1160
+ aoffset3 += 8*lda;
1161
+ aoffset4 += 8*lda;
1162
+ boffset += 64;
1163
+ i--;
1164
+ } while(i > 0);
1165
+ }
1166
+ if (cols & 4) {
1167
+ vector float c1, c2, c3, c4, c5, c6, c7, c8;
1168
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1169
+ c1 = vec_xl(0, aoffset1);
1170
+ c2 = vec_xl(0, aoffset2);
1171
+ c3 = vec_xl(0, aoffset3);
1172
+ c4 = vec_xl(0, aoffset4);
1173
+ c5 = vec_xl(0, aoffset5);
1174
+ c6 = vec_xl(0, aoffset6);
1175
+ c7 = vec_xl(0, aoffset7);
1176
+ c8 = vec_xl(0, aoffset8);
1177
+
1178
+ t1 = vec_mergeh(c1, c2);
1179
+ t2 = vec_mergeh(c3, c4);
1180
+ t3 = vec_mergeh(c5, c6);
1181
+ t4 = vec_mergeh(c7, c8);
1182
+ t5 = vec_xxpermdi(t1, t2, 0);
1183
+ t6 = vec_xxpermdi(t3, t4, 0);
1184
+ t7 = vec_xxpermdi(t1, t2, 3);
1185
+ t8 = vec_xxpermdi(t3, t4, 3);
1186
+ vec_xst(t5, 0, boffset);
1187
+ vec_xst(t6, 0, boffset+4);
1188
+ vec_xst(t7, 0, boffset+8);
1189
+ vec_xst(t8, 0, boffset+12);
1190
+
1191
+ t1 = vec_mergel(c1, c2);
1192
+ t2 = vec_mergel(c3, c4);
1193
+ t3 = vec_mergel(c5, c6);
1194
+ t4 = vec_mergel(c7, c8);
1195
+ t5 = vec_xxpermdi(t1, t2, 0);
1196
+ t6 = vec_xxpermdi(t3, t4, 0);
1197
+ t7 = vec_xxpermdi(t1, t2, 3);
1198
+ t8 = vec_xxpermdi(t3, t4, 3);
1199
+ vec_xst(t5, 0, boffset+16);
1200
+ vec_xst(t6, 0, boffset+20);
1201
+ vec_xst(t7, 0, boffset+24);
1202
+ vec_xst(t8, 0, boffset+28);
1203
+ }
1204
+ j--;
1205
+ } while(j > 0);
1206
+ }
1207
+
1208
+ if (rows & 4) {
1209
+ aoffset1 = aoffset;
1210
+ aoffset2 = aoffset1 + lda;
1211
+ aoffset3 = aoffset2 + lda;
1212
+ aoffset4 = aoffset3 + lda;
1213
+ aoffset += 4 * lda;
1214
+ i = (cols >> 3);
1215
+ if (i > 0) {
1216
+ __vector_pair C1, C2, C3, C4;
1217
+ vector float c1[2], c2[2], c3[2], c4[2];
1218
+ vector float t1, t2, t3, t4, t5, t6, t7, t8;
1219
+ do {
1220
+ C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
1221
+ C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
1222
+ C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
1223
+ C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
1224
+ __builtin_vsx_disassemble_pair(c1, &C1);
1225
+ __builtin_vsx_disassemble_pair(c2, &C2);
1226
+ __builtin_vsx_disassemble_pair(c3, &C3);
1227
+ __builtin_vsx_disassemble_pair(c4, &C4);
1228
+
1229
+ t1 = vec_mergeh(c1[0], c2[0]);
1230
+ t2 = vec_mergeh(c3[0], c4[0]);
1231
+ t3 = vec_mergel(c1[0], c2[0]);
1232
+ t4 = vec_mergel(c3[0], c4[0]);
1233
+ t5 = vec_xxpermdi(t1, t2, 0);
1234
+ t6 = vec_xxpermdi(t1, t2, 3);
1235
+ t7 = vec_xxpermdi(t3, t4, 0);
1236
+ t8 = vec_xxpermdi(t3, t4, 3);
1237
+ vec_xst(t5, 0, boffset);
1238
+ vec_xst(t6, 0, boffset+4);
1239
+ vec_xst(t7, 0, boffset+8);
1240
+ vec_xst(t8, 0, boffset+12);
1241
+
1242
+ t1 = vec_mergeh(c1[1], c2[1]);
1243
+ t2 = vec_mergeh(c3[1], c4[1]);
1244
+ t3 = vec_mergel(c1[1], c2[1]);
1245
+ t4 = vec_mergel(c3[1], c4[1]);
1246
+ t5 = vec_xxpermdi(t1, t2, 0);
1247
+ t6 = vec_xxpermdi(t1, t2, 3);
1248
+ t7 = vec_xxpermdi(t3, t4, 0);
1249
+ t8 = vec_xxpermdi(t3, t4, 3);
1250
+ vec_xst(t5, 0, boffset+16);
1251
+ vec_xst(t6, 0, boffset+20);
1252
+ vec_xst(t7, 0, boffset+24);
1253
+ vec_xst(t8, 0, boffset+28);
1254
+
1255
+ aoffset1 += 8*lda;
1256
+ aoffset2 += 8*lda;
1257
+ aoffset3 += 8*lda;
1258
+ aoffset4 += 8*lda;
1259
+ boffset += 32;
1260
+ i--;
1261
+ } while(i > 0);
1262
+ }
1263
+
1264
+ if (cols & 4) {
1265
+ vector float c1, c2, c3, c4;
1266
+ vector float t1, t2, t3, t4;
1267
+ c1 = vec_xl(0, aoffset1);
1268
+ c2 = vec_xl(0, aoffset2);
1269
+ c3 = vec_xl(0, aoffset3);
1270
+ c4 = vec_xl(0, aoffset4);
1271
+
1272
+ t1 = vec_mergeh(c1, c2);
1273
+ t2 = vec_mergeh(c3, c4);
1274
+ t3 = vec_xxpermdi(t1, t2, 0);
1275
+ t4 = vec_xxpermdi(t1, t2, 3);
1276
+ vec_xst(t3, 0, boffset);
1277
+ vec_xst(t4, 0, boffset+4);
1278
+
1279
+ t1 = vec_mergel(c1, c2);
1280
+ t2 = vec_mergel(c3, c4);
1281
+ t3 = vec_xxpermdi(t1, t2, 0);
1282
+ t4 = vec_xxpermdi(t1, t2, 3);
1283
+ vec_xst(t3, 0, boffset+8);
1284
+ vec_xst(t4, 0, boffset+12);
1285
+ }
1286
+ }
1287
+ if (rows & 3) {
1288
+ aoffset1 = aoffset;
1289
+ aoffset2 = aoffset1 + lda;
1290
+ aoffset3 = aoffset2 + lda;
1291
+ if (cols & 4) {
1292
+ vector float c1, c2, c3, c4 = {0};
1293
+ vector float t1, t2, t3, t4;
1294
+ c1 = vec_xl(0, aoffset1);
1295
+ c2 = vec_xl(0, aoffset2);
1296
+ c3 = vec_xl(0, aoffset3);
1297
+
1298
+ t1 = vec_mergeh(c1, c2);
1299
+ t2 = vec_mergeh(c3, c4);
1300
+ t3 = vec_xxpermdi(t1, t2, 0);
1301
+ t4 = vec_xxpermdi(t1, t2, 3);
1302
+ vec_xst(t3, 0, boffset);
1303
+ vec_xst(t4, 0, boffset+4);
1304
+
1305
+ t1 = vec_mergel(c1, c2);
1306
+ t2 = vec_mergel(c3, c4);
1307
+ t3 = vec_xxpermdi(t1, t2, 0);
1308
+ t4 = vec_xxpermdi(t1, t2, 3);
1309
+ vec_xst(t3, 0, boffset+8);
1310
+ vec_xst(t4, 0, boffset+12);
1311
+ }
1312
+ }
1313
+ }
1314
+
1315
+ void KERNEL_4x4(int64_t ii, int64_t jj) {
1316
+ vec_t vec_A[4], vec_B[4], vec_C[4];
1317
+ acc_t acc_0;
1318
+ __builtin_mma_xxsetaccz(&acc_0);
1319
+ for (int l = 0; l < k; l+=4) {
1320
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1321
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1322
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1323
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1324
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1325
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1326
+ }
1327
+ SAVE_ACC(&acc_0, ii, jj);
1328
+ }
1329
+
1330
+ void KERNEL_4x8(int64_t ii, int64_t jj) {
1331
+ vec_t vec_A[4], vec_B[8], vec_C[4];
1332
+ acc_t acc_0, acc_1;
1333
+ __builtin_mma_xxsetaccz(&acc_0);
1334
+ __builtin_mma_xxsetaccz(&acc_1);
1335
+ for (int64_t l = 0; l < k; l+=4) {
1336
+ READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
1337
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
1338
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
1339
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
1340
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
1341
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
1342
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
1343
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
1344
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
1345
+ __builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
1346
+ }
1347
+ SAVE_ACC(&acc_0, ii, jj);
1348
+ SAVE_ACC(&acc_1, ii, jj+4);
1349
+ }
1350
+
1351
+ void KERNEL_8x4(int64_t ii, int64_t jj) {
1352
+ vec_t vec_A[8], vec_B[4], vec_C[4];
1353
+ acc_t acc_0, acc_1;
1354
+ __builtin_mma_xxsetaccz(&acc_0);
1355
+ __builtin_mma_xxsetaccz(&acc_1);
1356
+ for (int64_t l = 0; l < k; l+=4) {
1357
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
1358
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1359
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
1360
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
1361
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
1362
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
1363
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
1364
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
1365
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
1366
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
1367
+ }
1368
+ SAVE_ACC(&acc_0, ii, jj);
1369
+ SAVE_ACC(&acc_1, ii+4, jj);
1370
+ }
1371
+
1372
+ void KERNEL_8x8(int64_t ii, int64_t jj) {
1373
+ vec_t vec_A[16], vec_B[16], vec_C[4];
1374
+ acc_t acc_0, acc_1, acc_2, acc_3;
1375
+ __builtin_mma_xxsetaccz(&acc_0);
1376
+ __builtin_mma_xxsetaccz(&acc_1);
1377
+ __builtin_mma_xxsetaccz(&acc_2);
1378
+ __builtin_mma_xxsetaccz(&acc_3);
1379
+ for (int l = 0; l < k; l+=8) {
1380
+ READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
1381
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
1382
+ for(int x = 0; x < 16; x+=2) {
1383
+ __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
1384
+ __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
1385
+ __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
1386
+ __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
1387
+ }
1388
+ }
1389
+ SAVE_ACC(&acc_0, ii, jj);
1390
+ SAVE_ACC(&acc_1, ii, jj+4);
1391
+ SAVE_ACC(&acc_2, ii+4, jj);
1392
+ SAVE_ACC(&acc_3, ii+4, jj+4);
1393
+ }
1394
+
1395
+ void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1396
+ int64_t mc, nc, mp, np;
1397
+ int m_rem = MIN(m - m0, 16);
1398
+ int n_rem = MIN(n - n0, 16);
1399
+ if (m_rem >= 16 && n_rem >= 8) {
1400
+ mc = 8;
1401
+ nc = 8;
1402
+ gemm<8,8>(m0, m, n0, n);
1403
+ } else if(m_rem >= 8 && n_rem >= 16) {
1404
+ mc = 8;
1405
+ nc = 8;
1406
+ gemm<8,8>(m0, m, n0, n);
1407
+ } else if (m_rem >= 8 && n_rem >= 8) {
1408
+ mc = 8;
1409
+ nc = 8;
1410
+ gemm<8,8>(m0, m, n0, n);
1411
+ } else if (m_rem >= 4 && n_rem >= 8) {
1412
+ mc = 4;
1413
+ nc = 8;
1414
+ gemm<4,8>(m0, m, n0, n);
1415
+ } else if (m_rem >= 8 && n_rem >= 4) {
1416
+ mc = 8;
1417
+ nc = 4;
1418
+ gemm<8,4>(m0, m, n0, n);
1419
+ } else if (m_rem >= 4 && n_rem >= 4) {
1420
+ mc = 4;
1421
+ nc = 4;
1422
+ gemm<4,4>(m0, m, n0, n);
1423
+ } else if ((m_rem < 4) && (n_rem > 4)) {
1424
+ nc = 4;
1425
+ switch(m_rem) {
1426
+ case 1:
1427
+ mc = 1;
1428
+ gemm_small(m0, m, n0, n, mc, nc);
1429
+ break;
1430
+ case 2:
1431
+ mc = 2;
1432
+ gemm_small(m0, m, n0, n, mc, nc);
1433
+ break;
1434
+ case 3:
1435
+ mc = 3;
1436
+ gemm_small(m0, m, n0, n, mc, nc);
1437
+ break;
1438
+ default:
1439
+ return;
1440
+ }
1441
+ } else if ((m_rem > 4) && (n_rem < 4)) {
1442
+ mc = 4;
1443
+ switch(n_rem) {
1444
+ case 1:
1445
+ nc = 1;
1446
+ gemm_small(m0, m, n0, n, mc, nc);
1447
+ break;
1448
+ case 2:
1449
+ nc = 2;
1450
+ gemm_small(m0, m, n0, n, mc, nc);
1451
+ break;
1452
+ case 3:
1453
+ nc = 3;
1454
+ gemm_small(m0, m, n0, n, mc, nc);
1455
+ break;
1456
+ default:
1457
+ return;
1458
+ }
1459
+ } else {
1460
+ switch((m_rem << 4) | n_rem) {
1461
+ case 0x43:
1462
+ mc = 4;
1463
+ nc = 3;
1464
+ gemm_small(m0, m, n0, n, mc, nc);
1465
+ break;
1466
+ case 0x42:
1467
+ mc = 4;
1468
+ nc = 2;
1469
+ gemm_small(m0, m, n0, n, mc, nc);
1470
+ break;
1471
+ case 0x41:
1472
+ mc = 4;
1473
+ nc = 1;
1474
+ gemm_small(m0, m, n0, n, mc, nc);
1475
+ break;
1476
+ case 0x34:
1477
+ mc = 3;
1478
+ nc = 4;
1479
+ gemm_small(m0, m, n0, n, mc, nc);
1480
+ break;
1481
+ case 0x33:
1482
+ mc = 3;
1483
+ nc = 3;
1484
+ gemm_small(m0, m, n0, n, mc, nc);
1485
+ break;
1486
+ case 0x32:
1487
+ mc = 3;
1488
+ nc = 2;
1489
+ gemm_small(m0, m, n0, n, mc, nc);
1490
+ break;
1491
+ case 0x31:
1492
+ mc = 3;
1493
+ nc = 1;
1494
+ gemm_small(m0, m, n0, n, mc, nc);
1495
+ break;
1496
+ case 0x24:
1497
+ mc = 2;
1498
+ nc = 4;
1499
+ gemm_small(m0, m, n0, n, mc, nc);
1500
+ break;
1501
+ case 0x23:
1502
+ mc = 2;
1503
+ nc = 3;
1504
+ gemm_small(m0, m, n0, n, mc, nc);
1505
+ break;
1506
+ case 0x22:
1507
+ mc = 2;
1508
+ nc = 2;
1509
+ gemm_small(m0, m, n0, n, mc, nc);
1510
+ break;
1511
+ case 0x21:
1512
+ mc = 2;
1513
+ nc = 1;
1514
+ gemm_small(m0, m, n0, n, mc, nc);
1515
+ break;
1516
+ case 0x14:
1517
+ mc = 1;
1518
+ nc = 4;
1519
+ gemm_small(m0, m, n0, n, mc, nc);
1520
+ break;
1521
+ case 0x13:
1522
+ mc = 1;
1523
+ nc = 3;
1524
+ gemm_small(m0, m, n0, n, mc, nc);
1525
+ break;
1526
+ case 0x12:
1527
+ mc = 1;
1528
+ nc = 2;
1529
+ gemm_small(m0, m, n0, n, mc, nc);
1530
+ break;
1531
+ case 0x11:
1532
+ mc = 1;
1533
+ nc = 1;
1534
+ gemm_small(m0, m, n0, n, mc, nc);
1535
+ break;
1536
+ default:
1537
+ return;
1538
+ }
1539
+ }
1540
+ mp = m0 + (m - m0) / mc * mc;
1541
+ np = n0 + (n - n0) / nc * nc;
1542
+ mnpack(mp, m, n0, np);
1543
+ mnpack(m0, m, np, n);
1544
+ }
1545
+
1546
+ void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
1547
+ int64_t ytiles = (m - m0) / RM;
1548
+ int64_t xtiles = (n - n0) / RN;
1549
+ int64_t tiles = xtiles * ytiles;
1550
+ int64_t duty = (tiles + nth - 1) / nth;
1551
+ int64_t start = duty * ith;
1552
+ int64_t end = start + duty;
1553
+ if (end > tiles)
1554
+ end = tiles;
1555
+ for (int64_t job = start; job < end; ++job) {
1556
+ int64_t ii = m0 + job / xtiles * RM;
1557
+ int64_t jj = n0 + job % xtiles * RN;
1558
+ vec_t vec_C[4];
1559
+ acc_t acc_0;
1560
+ __builtin_mma_xxsetaccz(&acc_0);
1561
+ vec_t vec_A[4], vec_B[4];
1562
+ for (int l=0; l<k; l+=4) {
1563
+ if (RN >= 4 && RM == 1) {
1564
+ float* a = const_cast<float*>(A+(ii)*lda+l);
1565
+ READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
1566
+ vec_A[0] = (vec_t)vec_xl(0,a);
1567
+ vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
1568
+ vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
1569
+ vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
1570
+ } else {
1571
+ READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
1572
+ READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
1573
+ }
1574
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
1575
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
1576
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
1577
+ __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
1578
+ }
1579
+ __builtin_mma_disassemble_acc(vec_C, &acc_0);
1580
+ for (int I = 0; I < RM; I++) {
1581
+ for (int J = 0; J < RN; J++) {
1582
+ *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
1583
+ }
1584
+ }
1585
+ }
1586
+ }
1587
+
1588
+ template <int RM, int RN>
1589
+ NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
1590
+ int64_t ytiles = (m - m0) / RM;
1591
+ int64_t xtiles = (n - n0) / RN;
1592
+ int64_t tiles = xtiles * ytiles;
1593
+ int64_t duty = (tiles + nth - 1) / nth;
1594
+ int64_t start = duty * ith;
1595
+ int64_t end = start + duty;
1596
+ if (RM == 4 && RN == 4) {
1597
+ kernel = &tinyBLAS_PPC::KERNEL_4x4;
1598
+ } else if (RM == 4 && RN == 8) {
1599
+ kernel = &tinyBLAS_PPC::KERNEL_4x8;
1600
+ } else if (RM == 8 && RN == 4) {
1601
+ kernel = &tinyBLAS_PPC::KERNEL_8x4;
1602
+ } else if (RM == 8 && RN == 8) {
1603
+ kernel = &tinyBLAS_PPC::KERNEL_8x8;
1604
+ }
1605
+ if (end > tiles)
1606
+ end = tiles;
1607
+ for (int64_t job = start; job < end; ++job) {
1608
+ int64_t ii = m0 + job / xtiles * RM;
1609
+ int64_t jj = n0 + job % xtiles * RN;
1610
+ (this->*kernel)(ii, jj);
1611
+ }
1612
+ }
1613
+
1614
+ const TA *const A;
1615
+ const TB *const B;
1616
+ TC *C;
1617
+ TA *At;
1618
+ TB *Bt;
1619
+ const int64_t k;
1620
+ const int64_t lda;
1621
+ const int64_t ldb;
1622
+ const int64_t ldc;
1623
+ const int ith;
1624
+ const int nth;
1625
+ };
1626
+ #endif
1627
+ } // namespace
1628
+
+ /**
+  * Performs optimized matrix multiplication on CPU.
+  *
+  * This subroutine may compute C = Aᵀ * B with column major ordering.
+  * Despite its name, this isn't a generalized implementation. Work is
+  * only performed when a handwritten kernel is written and available.
+  * Otherwise the caller should fall back to a general matmul routine.
+  *
+  * For example, for single-threaded single-precision GEMM you can say
+  *
+  *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+  *                     0, 1,
+  *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+  *
+  * @param m is rows in `A` and `C`
+  * @param n is cols in `B` and `C`
+  * @param k is cols in `A` and rows in `B`
+  * @param A is first input matrix (always transposed)
+  * @param lda is row stride of `A`
+  * @param B is second input matrix (never transposed)
+  * @param ldb is row stride of `B`
+  * @param C is input/output array of output matrices
+  * @param ldc is row stride of `C`
+  * @param ith is thread id (must be less than `nth`)
+  * @param nth is number of threads (must be greater than zero)
+  * @param Atype is GGML data type of `A`
+  * @param Btype is GGML data type of `B`
+  * @param Ctype is GGML data type of `C`
+  * @return true if this function was able to service the matmul request
+  */
+ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
1660
+ int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
1661
+
1662
+ assert(m >= 0);
1663
+ assert(n >= 0);
1664
+ assert(k >= 0);
1665
+ assert(lda >= k);
1666
+ assert(ldb >= k);
1667
+ assert(ldc >= m);
1668
+ assert(nth > 0);
1669
+ assert(ith < nth);
1670
+
1671
+ // only enable sgemm for prompt processing
1672
+ if (n < 2)
1673
+ return false;
1674
+
1675
+ if (Ctype != GGML_TYPE_F32)
1676
+ return false;
1677
+
1678
+ switch (Atype) {
1679
+
1680
+ case GGML_TYPE_F32: {
1681
+ if (Btype != GGML_TYPE_F32)
1682
+ return false;
1683
+ #if defined(__AVX512F__)
1684
+ if (k % 16)
1685
+ return false;
1686
+ tinyBLAS<16, __m512, __m512, float, float, float> tb{
1687
+ k, (const float *)A, lda,
1688
+ (const float *)B, ldb,
1689
+ (float *)C, ldc,
1690
+ ith, nth};
1691
+ tb.matmul(m, n);
1692
+ return true;
1693
+ #elif defined(__AVX__) || defined(__AVX2__)
1694
+ if (k % 8)
1695
+ return false;
1696
+ tinyBLAS<8, __m256, __m256, float, float, float> tb{
1697
+ k, (const float *)A, lda,
1698
+ (const float *)B, ldb,
1699
+ (float *)C, ldc,
1700
+ ith, nth};
1701
+ tb.matmul(m, n);
1702
+ return true;
1703
+ #elif defined(__ARM_NEON)
1704
+ if (n < 4)
1705
+ return false;
1706
+ if (k % 4)
1707
+ return false;
1708
+ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
1709
+ k, (const float *)A, lda,
1710
+ (const float *)B, ldb,
1711
+ (float *)C, ldc,
1712
+ ith, nth};
1713
+ tb.matmul(m, n);
1714
+ return true;
1715
+ #elif defined(__MMA__)
1716
+ if (k % 8)
1717
+ return false;
1718
+ tinyBLAS_PPC<float, float, float> tb{
1719
+ k, (const float *)A, lda,
1720
+ (const float *)B, ldb,
1721
+ (float *)C, ldc,
1722
+ ith, nth};
1723
+ tb.matmul(m, n);
1724
+ return true;
1725
+ #else
1726
+ return false;
1727
+ #endif
1728
+ }
1729
+
1730
+ case GGML_TYPE_F16: {
1731
+ #if defined(__AVX512F__)
1732
+ if (k % 16)
1733
+ return false;
1734
+ if (Btype != GGML_TYPE_F32)
1735
+ return false;
1736
+ tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
1737
+ k, (const ggml_fp16_t *)A, lda,
1738
+ (const float *)B, ldb,
1739
+ (float *)C, ldc,
1740
+ ith, nth};
1741
+ tb.matmul(m, n);
1742
+ return true;
1743
+ #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
1744
+ if (k % 8)
1745
+ return false;
1746
+ if (Btype != GGML_TYPE_F32)
1747
+ return false;
1748
+ tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
1749
+ k, (const ggml_fp16_t *)A, lda,
1750
+ (const float *)B, ldb,
1751
+ (float *)C, ldc,
1752
+ ith, nth};
1753
+ tb.matmul(m, n);
1754
+ return true;
1755
+ #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
1756
+ if (n < 8)
1757
+ return false;
1758
+ if (k % 8)
1759
+ return false;
1760
+ if (Btype != GGML_TYPE_F16)
1761
+ return false;
1762
+ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
1763
+ k, (const ggml_fp16_t *)A, lda,
1764
+ (const ggml_fp16_t *)B, ldb,
1765
+ (float *)C, ldc,
1766
+ ith, nth};
1767
+ tb.matmul(m, n);
1768
+ return true;
1769
+ #elif defined(__ARM_NEON) && !defined(_MSC_VER)
1770
+ if (k % 4)
1771
+ return false;
1772
+ if (Btype != GGML_TYPE_F32)
1773
+ return false;
1774
+ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
1775
+ k, (const ggml_fp16_t *)A, lda,
1776
+ (const float *)B, ldb,
1777
+ (float *)C, ldc,
1778
+ ith, nth};
1779
+ tb.matmul(m, n);
1780
+ return true;
1781
+ #else
1782
+ return false;
1783
+ #endif
1784
+ }
1785
+
1786
+ case GGML_TYPE_Q8_0: {
1787
+ if (Btype != GGML_TYPE_Q8_0)
1788
+ return false;
1789
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1790
+ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
1791
+ k, (const block_q8_0 *)A, lda,
1792
+ (const block_q8_0 *)B, ldb,
1793
+ (float *)C, ldc,
1794
+ ith, nth};
1795
+ tb.matmul(m, n);
1796
+ return true;
1797
+ #elif defined(__ARM_FEATURE_DOTPROD)
1798
+ tinyBLAS_Q0_ARM<block_q8_0> tb{
1799
+ k, (const block_q8_0 *)A, lda,
1800
+ (const block_q8_0 *)B, ldb,
1801
+ (float *)C, ldc,
1802
+ ith, nth};
1803
+ tb.matmul(m, n);
1804
+ return true;
1805
+ #else
1806
+ return false;
1807
+ #endif
1808
+ }
1809
+
1810
+ case GGML_TYPE_Q4_0: {
1811
+ if (Btype != GGML_TYPE_Q8_0)
1812
+ return false;
1813
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1814
+ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
1815
+ k, (const block_q4_0 *)A, lda,
1816
+ (const block_q8_0 *)B, ldb,
1817
+ (float *)C, ldc,
1818
+ ith, nth};
1819
+ tb.matmul(m, n);
1820
+ return true;
1821
+ #elif defined(__ARM_FEATURE_DOTPROD)
1822
+ tinyBLAS_Q0_ARM<block_q4_0> tb{
1823
+ k, (const block_q4_0 *)A, lda,
1824
+ (const block_q8_0 *)B, ldb,
1825
+ (float *)C, ldc,
1826
+ ith, nth};
1827
+ tb.matmul(m, n);
1828
+ return true;
1829
+ #else
1830
+ return false;
1831
+ #endif
1832
+ }
1833
+
1834
+ case GGML_TYPE_Q5_0: {
1835
+ if (Btype != GGML_TYPE_Q8_0)
1836
+ return false;
1837
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1838
+ tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
1839
+ k, (const block_q5_0 *)A, lda,
1840
+ (const block_q8_0 *)B, ldb,
1841
+ (float *)C, ldc,
1842
+ ith, nth};
1843
+ tb.matmul(m, n);
1844
+ return true;
1845
+ #else
1846
+ return false;
1847
+ #endif
1848
+ }
1849
+
1850
+ case GGML_TYPE_IQ4_NL: {
1851
+ if (Btype != GGML_TYPE_Q8_0)
1852
+ return false;
1853
+ #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
1854
+ tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
1855
+ k, (const block_iq4_nl *)A, lda,
1856
+ (const block_q8_0 *)B, ldb,
1857
+ (float *)C, ldc,
1858
+ ith, nth};
1859
+ tb.matmul(m, n);
1860
+ return true;
1861
+ #else
1862
+ return false;
1863
+ #endif
1864
+ }
1865
+
1866
+ default:
1867
+ return false;
1868
+ }
1869
+
1870
+ (void)m;
1871
+ (void)n;
1872
+ (void)k;
1873
+ (void)A;
1874
+ (void)lda;
1875
+ (void)B;
1876
+ (void)ldb;
1877
+ (void)C;
1878
+ (void)ldc;
1879
+ (void)ith;
1880
+ (void)nth;
1881
+ (void)Atype;
1882
+ (void)Btype;
1883
+ (void)Ctype;
1884
+ }
ggml/src/ggml-cpu/llamafile/sgemm.h ADDED
@@ -0,0 +1,14 @@
1
+ #pragma once
2
+ #include <stdint.h>
3
+ #include <stdbool.h>
4
+ #ifdef __cplusplus
5
+ extern "C" {
6
+ #endif
7
+
8
+ bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
9
+ const void *, int64_t, void *, int64_t, int, int,
10
+ int, int, int);
11
+
12
+ #ifdef __cplusplus
13
+ }
14
+ #endif
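Editor's note: a minimal, hedged usage sketch for the C interface declared above (not part of this commit). It assumes the header is reachable as "sgemm.h" on the include path, and the sizes are chosen so that k % 16 == 0 satisfies the divisibility checks of the AVX512, AVX, NEON and MMA paths shown earlier. A false return means the caller should fall back to the generic ggml matmul.

#include <vector>
#include "ggml.h"   // GGML_TYPE_F32
#include "sgemm.h"  // llamafile_sgemm (include path assumed; adjust to the build's dirs)

int main() {
    const int64_t m = 4, n = 4, k = 16;
    std::vector<float> A(m * k, 1.0f);  // A is stored transposed: lda == k
    std::vector<float> B(n * k, 1.0f);  // ldb == k
    std::vector<float> C(m * n, 0.0f);  // column-major output: ldc == m
    const bool ok = llamafile_sgemm(m, n, k,
                                    A.data(), /*lda=*/k,
                                    B.data(), /*ldb=*/k,
                                    C.data(), /*ldc=*/m,
                                    /*ith=*/0, /*nth=*/1,
                                    GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
    return ok ? 0 : 1;  // false: no handwritten kernel serviced this shape/type combination
}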
ggml/src/ggml-cuda/common.cuh CHANGED
@@ -6,7 +6,7 @@
6
  #include <cstdint>
7
  #include <memory>
8
 
9
- #if defined(GGML_USE_HIPBLAS)
10
  #define GGML_COMMON_DECL_HIP
11
  #define GGML_COMMON_IMPL_HIP
12
  #else
@@ -26,13 +26,13 @@
26
  #include <string>
27
  #include <vector>
28
 
29
- #if defined(GGML_USE_HIPBLAS)
30
  #include "vendors/hip.h"
31
  #elif defined(GGML_USE_MUSA)
32
  #include "vendors/musa.h"
33
  #else
34
  #include "vendors/cuda.h"
35
- #endif // defined(GGML_USE_HIPBLAS)
36
 
37
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
38
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -97,7 +97,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
97
 
98
  #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
99
 
100
- #if !defined(GGML_USE_HIPBLAS)
101
  static const char * cu_get_error_str(CUresult err) {
102
  const char * err_str;
103
  cuGetErrorString(err, &err_str);
@@ -120,21 +120,21 @@ typedef float dfloat; // dequantize float
120
  typedef float2 dfloat2;
121
  #endif // GGML_CUDA_F16
122
 
123
- #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
124
  #define FP16_AVAILABLE
125
- #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
126
 
127
  #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
128
  #define FAST_FP16_AVAILABLE
129
  #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
130
 
131
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
132
  #define FP16_MMA_AVAILABLE
133
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
134
 
135
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
136
  #define INT8_MMA_AVAILABLE
137
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
138
 
139
  #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
140
  #define FLASH_ATTN_AVAILABLE
@@ -156,14 +156,14 @@ static constexpr bool int8_mma_available(const int cc) {
156
  static __device__ void no_device_code(
157
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
158
 
159
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
160
  printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
161
  file_name, line, function_name, arch);
162
  GGML_UNUSED(arch_list);
163
  #else
164
  printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
165
  file_name, line, function_name, arch, arch_list);
166
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
167
  __trap();
168
 
169
  GGML_UNUSED(no_device_code); // suppress unused function warning
@@ -176,7 +176,7 @@ static __device__ void no_device_code(
176
  #endif // __CUDA_ARCH__
177
 
178
  static __device__ __forceinline__ int warp_reduce_sum(int x) {
179
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
180
  return __reduce_add_sync(0xffffffff, x);
181
  #else
182
  #pragma unroll
@@ -184,7 +184,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
184
  x += __shfl_xor_sync(0xffffffff, x, mask, 32);
185
  }
186
  return x;
187
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
188
  }
189
 
190
  static __device__ __forceinline__ float warp_reduce_sum(float x) {
@@ -207,7 +207,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
207
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
208
  #ifdef FP16_AVAILABLE
209
 
210
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
211
  #pragma unroll
212
  for (int mask = 16; mask > 0; mask >>= 1) {
213
  const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
@@ -221,7 +221,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
221
  a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
222
  }
223
  return a;
224
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
225
 
226
  #else
227
  NO_DEVICE_CODE;
@@ -240,11 +240,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
240
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
241
  #ifdef FP16_AVAILABLE
242
 
243
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
244
  return __float2half(fmaxf(__half2float(a), __half2float(b)));
245
  #else
246
  return __hmax(a, b);
247
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
248
 
249
  #else
250
  NO_DEVICE_CODE;
@@ -254,7 +254,7 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
254
  }
255
 
256
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
257
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
258
 
259
  #if CUDART_VERSION >= CUDART_HMAX
260
  return __hmax2(a, b);
@@ -269,11 +269,11 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
269
  GGML_UNUSED(a);
270
  GGML_UNUSED(b);
271
  NO_DEVICE_CODE;
272
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
273
  }
274
 
275
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
276
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
277
  #pragma unroll
278
  for (int mask = 16; mask > 0; mask >>= 1) {
279
  x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
@@ -282,7 +282,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
282
  #else
283
  GGML_UNUSED(x);
284
  NO_DEVICE_CODE;
285
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
286
  }
287
 
288
  #if CUDART_VERSION < CUDART_HMASK
@@ -294,7 +294,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
294
  #endif // CUDART_VERSION < CUDART_HMASK
295
 
296
  static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
297
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
298
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
299
  c = __builtin_amdgcn_sdot4(a, b, c, false);
300
  #elif defined(RDNA3)
@@ -320,7 +320,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
320
  #endif
321
  return c;
322
 
323
- #else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
324
 
325
  #if __CUDA_ARCH__ >= MIN_CC_DP4A
326
  return __dp4a(a, b, c);
@@ -330,7 +330,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
330
  return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
331
  #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
332
 
333
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
334
  }
335
 
336
  // TODO: move to ggml-common.h
 
6
  #include <cstdint>
7
  #include <memory>
8
 
9
+ #if defined(GGML_USE_HIP)
10
  #define GGML_COMMON_DECL_HIP
11
  #define GGML_COMMON_IMPL_HIP
12
  #else
 
26
  #include <string>
27
  #include <vector>
28
 
29
+ #if defined(GGML_USE_HIP)
30
  #include "vendors/hip.h"
31
  #elif defined(GGML_USE_MUSA)
32
  #include "vendors/musa.h"
33
  #else
34
  #include "vendors/cuda.h"
35
+ #endif // defined(GGML_USE_HIP)
36
 
37
  #define STRINGIZE_IMPL(...) #__VA_ARGS__
38
  #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
 
97
 
98
  #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
99
 
100
+ #if !defined(GGML_USE_HIP)
101
  static const char * cu_get_error_str(CUresult err) {
102
  const char * err_str;
103
  cuGetErrorString(err, &err_str);
 
120
  typedef float2 dfloat2;
121
  #endif // GGML_CUDA_F16
122
 
123
+ #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
124
  #define FP16_AVAILABLE
125
+ #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
126
 
127
  #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
128
  #define FAST_FP16_AVAILABLE
129
  #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
130
 
131
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
132
  #define FP16_MMA_AVAILABLE
133
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
134
 
135
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
136
  #define INT8_MMA_AVAILABLE
137
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
138
 
139
  #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
140
  #define FLASH_ATTN_AVAILABLE
 
156
  static __device__ void no_device_code(
157
  const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
158
 
159
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
160
  printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
161
  file_name, line, function_name, arch);
162
  GGML_UNUSED(arch_list);
163
  #else
164
  printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
165
  file_name, line, function_name, arch, arch_list);
166
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
167
  __trap();
168
 
169
  GGML_UNUSED(no_device_code); // suppress unused function warning
 
176
  #endif // __CUDA_ARCH__
177
 
178
  static __device__ __forceinline__ int warp_reduce_sum(int x) {
179
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
180
  return __reduce_add_sync(0xffffffff, x);
181
  #else
182
  #pragma unroll
 
184
  x += __shfl_xor_sync(0xffffffff, x, mask, 32);
185
  }
186
  return x;
187
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
188
  }
189
 
190
  static __device__ __forceinline__ float warp_reduce_sum(float x) {
 
207
  static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
208
  #ifdef FP16_AVAILABLE
209
 
210
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
211
  #pragma unroll
212
  for (int mask = 16; mask > 0; mask >>= 1) {
213
  const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
 
221
  a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
222
  }
223
  return a;
224
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
225
 
226
  #else
227
  NO_DEVICE_CODE;
 
240
  static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
241
  #ifdef FP16_AVAILABLE
242
 
243
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
244
  return __float2half(fmaxf(__half2float(a), __half2float(b)));
245
  #else
246
  return __hmax(a, b);
247
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
248
 
249
  #else
250
  NO_DEVICE_CODE;
 
254
  }
255
 
256
  static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
257
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
258
 
259
  #if CUDART_VERSION >= CUDART_HMAX
260
  return __hmax2(a, b);
 
269
  GGML_UNUSED(a);
270
  GGML_UNUSED(b);
271
  NO_DEVICE_CODE;
272
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
273
  }
274
 
275
  static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
276
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
277
  #pragma unroll
278
  for (int mask = 16; mask > 0; mask >>= 1) {
279
  x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
 
282
  #else
283
  GGML_UNUSED(x);
284
  NO_DEVICE_CODE;
285
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
286
  }
287
 
288
  #if CUDART_VERSION < CUDART_HMASK
 
294
  #endif // CUDART_VERSION < CUDART_HMASK
295
 
296
  static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
297
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
298
  #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
299
  c = __builtin_amdgcn_sdot4(a, b, c, false);
300
  #elif defined(RDNA3)
 
320
  #endif
321
  return c;
322
 
323
+ #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
324
 
325
  #if __CUDA_ARCH__ >= MIN_CC_DP4A
326
  return __dp4a(a, b, c);
 
330
  return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
331
  #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
332
 
333
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
334
  }
335
 
336
  // TODO: move to ggml-common.h
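Editor's note: the hunks above rename the GGML_USE_HIPBLAS guard to GGML_USE_HIP throughout the device code while leaving the surrounding logic unchanged. A hypothetical compatibility shim (not part of this commit) for out-of-tree code that still defines or checks the retired flag could bridge the rename until callers are migrated:

// hypothetical shim: map the retired flag onto the new one before including ggml CUDA headers
#if defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_HIP)
#define GGML_USE_HIP
#endif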
ggml/src/ggml-cuda/fattn-common.cuh CHANGED
@@ -517,9 +517,9 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
517
  }
518
 
519
  template<int D, int parallel_blocks> // D == head size
520
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
521
  __launch_bounds__(D, 1)
522
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
523
  static __global__ void flash_attn_combine_results(
524
  const float * __restrict__ VKQ_parts,
525
  const float2 * __restrict__ VKQ_meta,
 
517
  }
518
 
519
  template<int D, int parallel_blocks> // D == head size
520
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
521
  __launch_bounds__(D, 1)
522
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
523
  static __global__ void flash_attn_combine_results(
524
  const float * __restrict__ VKQ_parts,
525
  const float2 * __restrict__ VKQ_meta,
ggml/src/ggml-cuda/fattn-tile-f16.cu CHANGED
@@ -5,9 +5,9 @@
5
  #define FATTN_KQ_STRIDE_TILE_F16 64
6
 
7
  template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f16(
12
  const char * __restrict__ Q,
13
  const char * __restrict__ K,
 
5
  #define FATTN_KQ_STRIDE_TILE_F16 64
6
 
7
  template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f16(
12
  const char * __restrict__ Q,
13
  const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-tile-f32.cu CHANGED
@@ -5,9 +5,9 @@
5
  #define FATTN_KQ_STRIDE_TILE_F32 32
6
 
7
  template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f32(
12
  const char * __restrict__ Q,
13
  const char * __restrict__ K,
 
5
  #define FATTN_KQ_STRIDE_TILE_F32 32
6
 
7
  template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
8
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
9
  __launch_bounds__(nwarps*WARP_SIZE, 1)
10
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  static __global__ void flash_attn_tile_ext_f32(
12
  const char * __restrict__ Q,
13
  const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-vec-f16.cuh CHANGED
@@ -2,9 +2,9 @@
2
  #include "fattn-common.cuh"
3
 
4
  template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
8
  static __global__ void flash_attn_vec_ext_f16(
9
  const char * __restrict__ Q,
10
  const char * __restrict__ K,
 
2
  #include "fattn-common.cuh"
3
 
4
  template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
8
  static __global__ void flash_attn_vec_ext_f16(
9
  const char * __restrict__ Q,
10
  const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-vec-f32.cuh CHANGED
@@ -2,9 +2,9 @@
2
  #include "fattn-common.cuh"
3
 
4
  template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
8
  static __global__ void flash_attn_vec_ext_f32(
9
  const char * __restrict__ Q,
10
  const char * __restrict__ K,
 
2
  #include "fattn-common.cuh"
3
 
4
  template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
5
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
6
  __launch_bounds__(D, 1)
7
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
8
  static __global__ void flash_attn_vec_ext_f32(
9
  const char * __restrict__ Q,
10
  const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-wmma-f16.cuh CHANGED
@@ -7,9 +7,9 @@
7
 
8
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
10
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
11
  __launch_bounds__(nwarps*WARP_SIZE, 1)
12
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
13
  static __global__ void flash_attn_ext_f16(
14
  const char * __restrict__ Q,
15
  const char * __restrict__ K,
 
7
 
8
  // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
9
  template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
10
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
11
  __launch_bounds__(nwarps*WARP_SIZE, 1)
12
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
13
  static __global__ void flash_attn_ext_f16(
14
  const char * __restrict__ Q,
15
  const char * __restrict__ K,
ggml/src/ggml-cuda/ggml-cuda.cu ADDED
The diff for this file is too large to render. See raw diff
 
ggml/src/ggml-cuda/ggml/CMakeLists.txt ADDED
@@ -0,0 +1,165 @@
1
+ cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
2
+
3
+ find_package(CUDAToolkit)
4
+
5
+ if (CUDAToolkit_FOUND)
6
+ message(STATUS "CUDA Toolkit found")
7
+
8
+ if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
9
+ # 52 == lowest CUDA 12 standard
10
+ # 60 == FP16 CUDA intrinsics
11
+ # 61 == integer CUDA intrinsics
12
+ # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
13
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
14
+ set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
15
+ else()
16
+ set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
17
+ #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
18
+ endif()
19
+ endif()
20
+ message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
21
+
22
+ enable_language(CUDA)
23
+
24
+ file(GLOB GGML_HEADERS_CUDA "*.cuh")
25
+ list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
26
+
27
+ file(GLOB GGML_SOURCES_CUDA "*.cu")
28
+ file(GLOB SRCS "template-instances/fattn-wmma*.cu")
29
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
30
+ file(GLOB SRCS "template-instances/mmq*.cu")
31
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
32
+
33
+ if (GGML_CUDA_FA_ALL_QUANTS)
34
+ file(GLOB SRCS "template-instances/fattn-vec*.cu")
35
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
36
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
37
+ else()
38
+ file(GLOB SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
39
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
40
+ file(GLOB SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
41
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
42
+ file(GLOB SRCS "template-instances/fattn-vec*f16-f16.cu")
43
+ list(APPEND GGML_SOURCES_CUDA ${SRCS})
44
+ endif()
45
+
46
+ add_library(ggml-cuda
47
+ ${GGML_HEADERS_CUDA}
48
+ ${GGML_SOURCES_CUDA}
49
+ )
50
+
51
+ target_link_libraries(ggml-cuda PRIVATE ggml-base)
52
+ target_include_directories(ggml-cuda PRIVATE . ..)
53
+
54
+ # TODO: change the definitions to this target only
55
+
56
+ add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
57
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
58
+ add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
59
+ add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
60
+
61
+ if (GGML_CUDA_GRAPHS)
62
+ add_compile_definitions(GGML_CUDA_USE_GRAPHS)
63
+ endif()
64
+
65
+ if (GGML_CUDA_FORCE_DMMV)
66
+ add_compile_definitions(GGML_CUDA_FORCE_DMMV)
67
+ endif()
68
+
69
+ if (GGML_CUDA_FORCE_MMQ)
70
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
71
+ endif()
72
+
73
+ if (GGML_CUDA_FORCE_CUBLAS)
74
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
75
+ endif()
76
+
77
+ if (GGML_CUDA_NO_VMM)
78
+ add_compile_definitions(GGML_CUDA_NO_VMM)
79
+ endif()
80
+
81
+ if (DEFINED GGML_CUDA_DMMV_Y)
82
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
83
+ endif()
84
+
85
+ if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
86
+ add_compile_definitions(GGML_CUDA_F16)
87
+ endif()
88
+
89
+ if (GGML_CUDA_NO_PEER_COPY)
90
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
91
+ endif()
92
+
93
+ if (GGML_STATIC)
94
+ if (WIN32)
95
+ # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
96
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
97
+ else ()
98
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
99
+ endif()
100
+ else()
101
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
102
+ endif()
103
+
104
+ if (GGML_CUDA_NO_VMM)
105
+ # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
106
+ else()
107
+ target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
108
+ endif()
109
+
110
+ set(CUDA_CXX_FLAGS "")
111
+
112
+ set(CUDA_FLAGS -use_fast_math)
113
+
114
+ if (GGML_FATAL_WARNINGS)
115
+ list(APPEND CUDA_FLAGS -Werror all-warnings)
116
+ endif()
117
+
118
+ if (GGML_ALL_WARNINGS AND NOT MSVC)
119
+ set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
120
+ if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
121
+ list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
122
+ endif()
123
+
124
+ execute_process(
125
+ COMMAND ${NVCC_CMD} -Xcompiler --version
126
+ OUTPUT_VARIABLE CUDA_CCFULLVER
127
+ ERROR_QUIET
128
+ )
129
+
130
+ if (NOT CUDA_CCFULLVER MATCHES clang)
131
+ set(CUDA_CCID "GNU")
132
+ execute_process(
133
+ COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
134
+ OUTPUT_VARIABLE CUDA_CCVER
135
+ ERROR_QUIET
136
+ )
137
+ else()
138
+ if (CUDA_CCFULLVER MATCHES Apple)
139
+ set(CUDA_CCID "AppleClang")
140
+ else()
141
+ set(CUDA_CCID "Clang")
142
+ endif()
143
+ string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
144
+ endif()
145
+
146
+ message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
147
+
148
+ get_flags(${CUDA_CCID} ${CUDA_CCVER})
149
+ list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
150
+ endif()
151
+
152
+ if (NOT MSVC)
153
+ list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
154
+ endif()
155
+
156
+ list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
157
+
158
+ if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
159
+ list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
160
+ endif()
161
+
162
+ add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
163
+ else()
164
+ message(FATAL_ERROR "CUDA Toolkit not found")
165
+ endif()
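Editor's note: with ggml-cuda now built as its own library target (linked against ggml-base above), an application that enables this backend links the resulting library and initializes it through the existing public entry point. A hedged sketch, assuming the public headers are on the include path and with error handling elided:

#include "ggml-cuda.h"     // ggml_backend_cuda_init
#include "ggml-backend.h"  // ggml_backend_t, ggml_backend_free

int main() {
    ggml_backend_t backend = ggml_backend_cuda_init(/*device=*/0);
    if (backend == nullptr) {
        return 1;  // no usable CUDA device, or the backend was not compiled in
    }
    // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...
    ggml_backend_free(backend);
    return 0;
}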
ggml/src/ggml-cuda/mmq.cuh CHANGED
@@ -100,9 +100,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
100
  return 128;
101
  #else // INT8_MMA_AVAILABLE
102
 
103
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
104
  return 128;
105
- #else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
106
 
107
  #if __CUDA_ARCH__ >= CC_VOLTA
108
  #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,7 +115,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
115
  return 64;
116
  #endif // __CUDA_ARCH__ >= CC_VOLTA
117
 
118
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
119
  #endif // INT8_MMA_AVAILABLE
120
  }
121
 
@@ -124,7 +124,7 @@ static constexpr int get_mmq_y_host(const int cc) {
124
  }
125
 
126
  static constexpr __device__ int get_mmq_y_device() {
127
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
128
  #if defined(RDNA1)
129
  return 64;
130
  #else
@@ -136,7 +136,7 @@ static constexpr __device__ int get_mmq_y_device() {
136
  #else
137
  return 64;
138
  #endif // __CUDA_ARCH__ >= CC_VOLTA
139
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
140
  }
141
 
142
  #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
@@ -2569,7 +2569,7 @@ static __device__ void mul_mat_q_process_tile(
2569
  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
2570
 
2571
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2572
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2573
  #if defined(RDNA3) || defined(RDNA2)
2574
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2575
  #endif // defined(RDNA3) || defined(RDNA2)
@@ -2579,7 +2579,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2579
  #else
2580
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2581
  #endif // __CUDA_ARCH__ >= CC_VOLTA
2582
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
2583
  static __global__ void mul_mat_q(
2584
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2585
  const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
@@ -2594,7 +2594,7 @@ static __global__ void mul_mat_q(
2594
  constexpr int mmq_y = get_mmq_y_device();
2595
 
2596
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2597
- #if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2598
  {
2599
  constexpr bool fixup = false;
2600
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
@@ -2602,7 +2602,7 @@ static __global__ void mul_mat_q(
2602
  blockIdx.x, blockIdx.y, 0, ne00/qk);
2603
  return;
2604
  }
2605
- #endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2606
 
2607
  const int64_t blocks_per_ne00 = ne00 / qk;
2608
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -2765,14 +2765,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
2765
 
2766
  const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
2767
 
2768
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
2769
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
2770
  if (!shmem_limit_raised[id]) {
2771
  CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2772
  CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2773
  shmem_limit_raised[id] = true;
2774
  }
2775
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
2776
 
2777
  const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
2778
  const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
 
100
  return 128;
101
  #else // INT8_MMA_AVAILABLE
102
 
103
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
104
  return 128;
105
+ #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
106
 
107
  #if __CUDA_ARCH__ >= CC_VOLTA
108
  #ifdef GGML_CUDA_FORCE_MMQ
 
115
  return 64;
116
  #endif // __CUDA_ARCH__ >= CC_VOLTA
117
 
118
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
119
  #endif // INT8_MMA_AVAILABLE
120
  }
121
 
 
124
  }
125
 
126
  static constexpr __device__ int get_mmq_y_device() {
127
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
128
  #if defined(RDNA1)
129
  return 64;
130
  #else
 
136
  #else
137
  return 64;
138
  #endif // __CUDA_ARCH__ >= CC_VOLTA
139
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
140
  }
141
 
142
  #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 
2569
  // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
2570
 
2571
  template <ggml_type type, int mmq_x, int nwarps, bool need_check>
2572
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2573
  #if defined(RDNA3) || defined(RDNA2)
2574
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2575
  #endif // defined(RDNA3) || defined(RDNA2)
 
2579
  #else
2580
  __launch_bounds__(WARP_SIZE*nwarps, 2)
2581
  #endif // __CUDA_ARCH__ >= CC_VOLTA
2582
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
2583
  static __global__ void mul_mat_q(
2584
  const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
2585
  const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
 
2594
  constexpr int mmq_y = get_mmq_y_device();
2595
 
2596
  // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
2597
+ #if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2598
  {
2599
  constexpr bool fixup = false;
2600
  mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
 
2602
  blockIdx.x, blockIdx.y, 0, ne00/qk);
2603
  return;
2604
  }
2605
+ #endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
2606
 
2607
  const int64_t blocks_per_ne00 = ne00 / qk;
2608
  constexpr int blocks_per_iter = MMQ_ITER_K / qk;
 
2765
 
2766
  const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
2767
 
2768
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
2769
  static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
2770
  if (!shmem_limit_raised[id]) {
2771
  CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2772
  CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
2773
  shmem_limit_raised[id] = true;
2774
  }
2775
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
2776
 
2777
  const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
2778
  const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
ggml/src/ggml-cuda/mmvq.cu CHANGED
@@ -48,10 +48,10 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
48
  }
49
 
50
  template <ggml_type type, int ncols_y>
51
- #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
52
  // tell the compiler to use as many registers as it wants, see nwarps definition below
53
  __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
54
- #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
55
  static __global__ void mul_mat_vec_q(
56
  const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
57
  const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
62
 
63
  constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
64
 
65
- #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
66
  constexpr int nwarps = 1;
67
  constexpr int rows_per_cuda_block = 1;
68
  #else
69
  constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
70
  constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
71
- #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
72
 
73
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
74
  const int row0 = rows_per_cuda_block*blockIdx.x;
 
48
  }
49
 
50
  template <ggml_type type, int ncols_y>
51
+ #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
52
  // tell the compiler to use as many registers as it wants, see nwarps definition below
53
  __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
54
+ #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
55
  static __global__ void mul_mat_vec_q(
56
  const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
57
  const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
 
62
 
63
  constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
64
 
65
+ #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
66
  constexpr int nwarps = 1;
67
  constexpr int rows_per_cuda_block = 1;
68
  #else
69
  constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
70
  constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
71
+ #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
72
 
73
  const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
74
  const int row0 = rows_per_cuda_block*blockIdx.x;
ggml/src/ggml-cuda/sum.cu CHANGED
@@ -1,6 +1,6 @@
1
- #if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
2
  #define USE_CUB
3
- #endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
4
 
5
  #ifdef USE_CUB
6
  // On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
 
1
+ #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
2
  #define USE_CUB
3
+ #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
4
 
5
  #ifdef USE_CUB
6
  // On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
ggml/src/ggml-hip/CMakeLists.txt ADDED
@@ -0,0 +1,113 @@
1
+ if (NOT EXISTS $ENV{ROCM_PATH})
2
+ if (NOT EXISTS /opt/rocm)
3
+ set(ROCM_PATH /usr)
4
+ else()
5
+ set(ROCM_PATH /opt/rocm)
6
+ endif()
7
+ else()
8
+ set(ROCM_PATH $ENV{ROCM_PATH})
9
+ endif()
10
+
11
+ list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
12
+ list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
13
+
14
+ # CMake on Windows doesn't support the HIP language yet
15
+ if (WIN32)
16
+ set(CXX_IS_HIPCC TRUE)
17
+ else()
18
+ string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}")
19
+ endif()
20
+
21
+ if (CXX_IS_HIPCC)
22
+ if (LINUX)
23
+ if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
24
+ message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
25
+ endif()
26
+
27
+ message(WARNING "Setting hipcc as the C++ compiler is legacy behavior."
28
+ " Prefer setting the HIP compiler directly. See README for details.")
29
+ endif()
30
+ else()
31
+ # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
32
+ if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
33
+ set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
34
+ endif()
35
+ cmake_minimum_required(VERSION 3.21)
36
+ enable_language(HIP)
37
+ endif()
38
+
39
+ find_package(hip REQUIRED)
40
+ find_package(hipblas REQUIRED)
41
+ find_package(rocblas REQUIRED)
42
+
43
+ message(STATUS "HIP and hipBLAS found")
44
+
45
+ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
46
+ list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
47
+
48
+ file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
49
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
50
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
51
+ file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
52
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
53
+
54
+ if (GGML_CUDA_FA_ALL_QUANTS)
55
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
56
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
57
+ add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
58
+ else()
59
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
60
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
61
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
62
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
63
+ file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
64
+ list(APPEND GGML_SOURCES_ROCM ${SRCS})
65
+ endif()
66
+
67
+ add_library(ggml-hip
68
+ ${GGML_HEADERS_ROCM}
69
+ ${GGML_SOURCES_ROCM})
70
+
71
+ target_link_libraries(ggml-hip PRIVATE ggml-base)
72
+ target_include_directories(ggml-hip PRIVATE . ..)
73
+
74
+ # TODO: do not use CUDA definitions for HIP
75
+ target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
76
+
77
+ add_compile_definitions(GGML_USE_HIP)
78
+ add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
79
+ add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
80
+ add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
81
+
82
+ if (GGML_HIP_UMA)
83
+ add_compile_definitions(GGML_HIP_UMA)
84
+ endif()
85
+
86
+ if (GGML_CUDA_FORCE_DMMV)
87
+ add_compile_definitions(GGML_CUDA_FORCE_DMMV)
88
+ endif()
89
+
90
+ if (GGML_CUDA_FORCE_MMQ)
91
+ add_compile_definitions(GGML_CUDA_FORCE_MMQ)
92
+ endif()
93
+
94
+ if (GGML_CUDA_FORCE_CUBLAS)
95
+ add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
96
+ endif()
97
+
98
+ if (GGML_CUDA_NO_PEER_COPY)
99
+ add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
100
+ endif()
101
+
102
+ if (CXX_IS_HIPCC)
103
+ set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
104
+ target_link_libraries(ggml-hip PRIVATE hip::device)
105
+ else()
106
+ set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
107
+ endif()
108
+
109
+ if (GGML_STATIC)
110
+ message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
111
+ endif()
112
+
113
+ target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)
ggml/src/ggml-impl.h CHANGED
@@ -3,13 +3,29 @@
3
  // GGML internal header
4
 
5
  #include "ggml.h"
6
-
7
  #include <assert.h>
 
8
  #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
9
  #include <stdbool.h>
10
  #include <stdint.h>
11
  #include <string.h>
12
 
13
  #ifdef __cplusplus
14
  extern "C" {
15
  #endif
@@ -28,13 +44,13 @@ extern "C" {
28
  // if C99 - static_assert is noop
29
  // ref: https://stackoverflow.com/a/53923785/4039976
30
  #ifndef __cplusplus
31
- #ifndef static_assert
32
- #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
33
- #define static_assert(cond, msg) _Static_assert(cond, msg)
34
- #else
35
- #define static_assert(cond, msg) struct global_scope_noop_trick
36
- #endif
37
- #endif
38
  #endif
39
 
40
  static inline int ggml_up32(int n) {
@@ -120,14 +136,12 @@ struct ggml_map_custom1_op_params {
120
  void * userdata;
121
  };
122
 
123
-
124
  struct ggml_map_custom2_op_params {
125
  ggml_custom2_op_t fun;
126
  int n_tasks;
127
  void * userdata;
128
  };
129
 
130
-
131
  struct ggml_map_custom3_op_params {
132
  ggml_custom3_op_t fun;
133
  int n_tasks;
@@ -287,9 +301,249 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
287
  void * ggml_aligned_malloc(size_t size);
288
  void ggml_aligned_free(void * ptr, size_t size);
289
 
290
- // TODO: move to threading file
291
- void ggml_critical_section_start(void);
292
- void ggml_critical_section_end(void);
 
293
 
294
  #ifdef __cplusplus
295
  }
 
3
  // GGML internal header
4
 
5
  #include "ggml.h"
 
6
  #include <assert.h>
7
+ #include <math.h>
8
  #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
9
  #include <stdbool.h>
10
  #include <stdint.h>
11
  #include <string.h>
12
 
13
+ #ifdef __ARM_FEATURE_SVE
14
+ #include <arm_sve.h>
15
+ #endif // __ARM_FEATURE_SVE
16
+
17
+ #if defined(__ARM_NEON)
18
+ // if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
19
+ //
20
+ // $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
21
+ //
22
+ #include <arm_neon.h>
23
+ #endif
24
+
25
+ #if defined(__F16C__)
26
+ #include <immintrin.h>
27
+ #endif
28
+
29
  #ifdef __cplusplus
30
  extern "C" {
31
  #endif
 
44
  // if C99 - static_assert is noop
45
  // ref: https://stackoverflow.com/a/53923785/4039976
46
  #ifndef __cplusplus
47
+ #ifndef static_assert
48
+ #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
49
+ #define static_assert(cond, msg) _Static_assert(cond, msg)
50
+ #else
51
+ #define static_assert(cond, msg) struct global_scope_noop_trick
52
+ #endif
53
+ #endif
54
  #endif
55
 
56
  static inline int ggml_up32(int n) {
 
136
  void * userdata;
137
  };
138
 
 
139
  struct ggml_map_custom2_op_params {
140
  ggml_custom2_op_t fun;
141
  int n_tasks;
142
  void * userdata;
143
  };
144
 
 
145
  struct ggml_map_custom3_op_params {
146
  ggml_custom3_op_t fun;
147
  int n_tasks;
 
301
  void * ggml_aligned_malloc(size_t size);
302
  void ggml_aligned_free(void * ptr, size_t size);
303
 
304
+ // FP16 to FP32 conversion
305
+
306
+ #if defined(__ARM_NEON)
307
+ #ifdef _MSC_VER
308
+ typedef uint16_t ggml_fp16_internal_t;
309
+ #else
310
+ typedef __fp16 ggml_fp16_internal_t;
311
+ #endif
312
+ #endif
313
+
314
+ #if defined(__ARM_NEON) && !defined(_MSC_VER)
315
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
316
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
317
+
318
+ #define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
319
+
320
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
321
+ ggml_fp16_internal_t tmp;
322
+ memcpy(&tmp, &h, sizeof(ggml_fp16_t));
323
+ return (float)tmp;
324
+ }
325
+
326
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
327
+ ggml_fp16_t res;
328
+ ggml_fp16_internal_t tmp = f;
329
+ memcpy(&res, &tmp, sizeof(ggml_fp16_t));
330
+ return res;
331
+ }
332
+
333
+ #elif defined(__F16C__)
334
+
335
+ #ifdef _MSC_VER
336
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
337
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
338
+ #else
339
+ #define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
340
+ #define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
341
+ #endif
342
+
343
+ #elif defined(__POWER9_VECTOR__)
344
+
345
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
346
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
347
+ /* the inline asm below is about 12% faster than the lookup method */
348
+ #define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
349
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
350
+
351
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
352
+ register float f;
353
+ register double d;
354
+ __asm__(
355
+ "mtfprd %0,%2\n"
356
+ "xscvhpdp %0,%0\n"
357
+ "frsp %1,%0\n" :
358
+ /* temp */ "=d"(d),
359
+ /* out */ "=f"(f):
360
+ /* in */ "r"(h));
361
+ return f;
362
+ }
363
+
364
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
365
+ register double d;
366
+ register ggml_fp16_t r;
367
+ __asm__( /* xscvdphp can work on double or single precision */
368
+ "xscvdphp %0,%2\n"
369
+ "mffprd %1,%0\n" :
370
+ /* temp */ "=d"(d),
371
+ /* out */ "=r"(r):
372
+ /* in */ "f"(f));
373
+ return r;
374
+ }
375
+
376
+ #else
377
+
378
+ // FP16 <-> FP32
379
+ // ref: https://github.com/Maratyszcza/FP16
380
+
381
+ static inline float fp32_from_bits(uint32_t w) {
382
+ union {
383
+ uint32_t as_bits;
384
+ float as_value;
385
+ } fp32;
386
+ fp32.as_bits = w;
387
+ return fp32.as_value;
388
+ }
389
+
390
+ static inline uint32_t fp32_to_bits(float f) {
391
+ union {
392
+ float as_value;
393
+ uint32_t as_bits;
394
+ } fp32;
395
+ fp32.as_value = f;
396
+ return fp32.as_bits;
397
+ }
398
+
399
+ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
400
+ const uint32_t w = (uint32_t) h << 16;
401
+ const uint32_t sign = w & UINT32_C(0x80000000);
402
+ const uint32_t two_w = w + w;
403
+
404
+ const uint32_t exp_offset = UINT32_C(0xE0) << 23;
405
+ #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
406
+ const float exp_scale = 0x1.0p-112f;
407
+ #else
408
+ const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
409
+ #endif
410
+ const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
411
+
412
+ const uint32_t magic_mask = UINT32_C(126) << 23;
413
+ const float magic_bias = 0.5f;
414
+ const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
415
+
416
+ const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
417
+ const uint32_t result = sign |
418
+ (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
419
+ return fp32_from_bits(result);
420
+ }
421
+
422
+ static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
423
+ #if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
424
+ const float scale_to_inf = 0x1.0p+112f;
425
+ const float scale_to_zero = 0x1.0p-110f;
426
+ #else
427
+ const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
428
+ const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
429
+ #endif
430
+ float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
431
+
432
+ const uint32_t w = fp32_to_bits(f);
433
+ const uint32_t shl1_w = w + w;
434
+ const uint32_t sign = w & UINT32_C(0x80000000);
435
+ uint32_t bias = shl1_w & UINT32_C(0xFF000000);
436
+ if (bias < UINT32_C(0x71000000)) {
437
+ bias = UINT32_C(0x71000000);
438
+ }
439
+
440
+ base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
441
+ const uint32_t bits = fp32_to_bits(base);
442
+ const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
443
+ const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
444
+ const uint32_t nonsign = exp_bits + mantissa_bits;
445
+ return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
446
+ }
447
+
448
+ #define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
449
+ #define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
450
+
451
+ #endif // defined(__ARM_NEON) && !defined(_MSC_VER)
452
+
453
+ // precomputed f32 table for f16 (256 KB)
454
+ // defined in ggml.c, initialized in ggml_init()
455
+ GGML_API float ggml_table_f32_f16[1 << 16];
456
+
457
+ // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
458
+ // so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
459
+ // This is also true for POWER9.
460
+ #if !defined(GGML_FP16_TO_FP32)
461
+ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
462
+ uint16_t s;
463
+ memcpy(&s, &f, sizeof(uint16_t));
464
+ return ggml_table_f32_f16[s];
465
+ }
466
+
467
+ #define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
468
+ #endif
469
+
470
+ #if !defined(GGML_FP32_TO_FP16)
471
+ #define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
472
+ #endif
473
+
474
+ /**
475
+ * Converts brain16 to float32.
476
+ *
477
+ * The bfloat16 floating point format has the following structure:
478
+ *
479
+ * ┌sign
480
+ * │
481
+ * │ ┌exponent
482
+ * │ │
483
+ * │ │ ┌mantissa
484
+ * │ │ │
485
+ * │┌──┴───┐┌─┴───┐
486
+ * 0b0000000000000000 brain16
487
+ *
488
+ * Since bf16 has the same number of exponent bits as a 32bit float,
489
+ * encoding and decoding numbers becomes relatively straightforward.
490
+ *
491
+ * ┌sign
492
+ * │
493
+ * │ ┌exponent
494
+ * │ │
495
+ * │ │ ┌mantissa
496
+ * │ │ │
497
+ * │┌──┴───┐┌─┴───────────────────┐
498
+ * 0b00000000000000000000000000000000 IEEE binary32
499
+ *
500
+ * For comparison, the standard fp16 format has fewer exponent bits.
501
+ *
502
+ * ┌sign
503
+ * │
504
+ * │ ┌exponent
505
+ * │ │
506
+ * │ │ ┌mantissa
507
+ * │ │ │
508
+ * │┌─┴─┐┌─┴──────┐
509
+ * 0b0000000000000000 IEEE binary16
510
+ *
511
+ * @see IEEE 754-2008
512
+ */
513
+ static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
514
+ union {
515
+ float f;
516
+ uint32_t i;
517
+ } u;
518
+ u.i = (uint32_t)h.bits << 16;
519
+ return u.f;
520
+ }
521
+
522
+ /**
523
+ * Converts float32 to brain16.
524
+ *
525
+ * This is bit-identical to the Google Brain float conversion.
526
+ * Floats round to nearest even, and NaNs are quieted.
527
+ * Subnormals are not flushed to zero here, though hardware may flush them when the values are used.
528
+ * This code should vectorize nicely with modern compilers.
529
+ */
530
+ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
531
+ ggml_bf16_t h;
532
+ union {
533
+ float f;
534
+ uint32_t i;
535
+ } u;
536
+ u.f = s;
537
+ if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
538
+ h.bits = (u.i >> 16) | 64; /* force to quiet */
539
+ return h;
540
+ }
541
+ h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
542
+ return h;
543
+ }
544
+
545
+ #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
546
+ #define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
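Editor's note: a small worked example of the rounding described above, assuming the surrounding internal header is includable (illustrative only). The value 1 + 2^-8 needs one more mantissa bit than bf16 keeps, and the discarded bits are exactly half an ulp, so the tie rounds to the even value 1.0f.

    #include <cstdio>
    #include "ggml-impl.h"   // assumption: internal header containing the code in this hunk

    int main() {
        const float x = 1.00390625f;               // 1 + 2^-8: one bit beyond bf16 precision
        const ggml_bf16_t b = GGML_FP32_TO_BF16(x);
        const float y = GGML_BF16_TO_FP32(b);
        // round-to-nearest-even drops the extra bit, so y comes back as 1.0f
        std::printf("x=%.8f  bf16 bits=0x%04x  y=%.8f\n", (double) x, (unsigned) b.bits, (double) y);
        return 0;
    }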
547
 
548
  #ifdef __cplusplus
549
  }
ggml/src/ggml-kompute/CMakeLists.txt ADDED
@@ -0,0 +1,162 @@
1
+
2
+ find_package(Vulkan COMPONENTS glslc REQUIRED)
3
+ find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
4
+
5
+ if (NOT glslc_executable)
6
+ message(FATAL_ERROR "glslc not found")
7
+ endif()
8
+
9
+ add_library(ggml-kompute
10
+ ggml-kompute.cpp
11
+ ../../include/ggml-kompute.h
12
+ )
13
+
14
+ target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
15
+ target_include_directories(ggml-kompute PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})
16
+
17
+ add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
18
+
19
+ function(compile_shader)
20
+ set(options)
21
+ set(oneValueArgs)
22
+ set(multiValueArgs SOURCES)
23
+ cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
24
+ foreach(source ${compile_shader_SOURCES})
25
+ get_filename_component(filename ${source} NAME)
26
+ set(spv_file ${filename}.spv)
27
+ add_custom_command(
28
+ OUTPUT ${spv_file}
29
+ DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
30
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
31
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
32
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
33
+ ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
34
+ COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
35
+ COMMENT "Compiling ${source} to ${spv_file}"
36
+ )
37
+
38
+ get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
39
+ set(FILE_NAME "shader${RAW_FILE_NAME}")
40
+ string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
41
+ string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
42
+ string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
43
+ set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
44
+ message(STATUS "Generating ${HEADER_FILE} (guard ${HEADER_FILE_DEFINE})")
45
+ if(CMAKE_GENERATOR MATCHES "Visual Studio")
46
+ add_custom_command(
47
+ OUTPUT ${OUTPUT_HEADER_FILE}
48
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
49
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
50
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
51
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
52
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
53
+ COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
54
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
55
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
56
+ DEPENDS ${spv_file} xxd
57
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
58
+ )
59
+ else()
60
+ add_custom_command(
61
+ OUTPUT ${OUTPUT_HEADER_FILE}
62
+ COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
63
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
64
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
65
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
66
+ COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
67
+ COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
68
+ COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
69
+ COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
70
+ DEPENDS ${spv_file} xxd
71
+ COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
72
+ )
73
+ endif()
74
+ endforeach()
75
+ endfunction()
76
+
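Editor's note: each generated header wraps an xxd-style byte dump of the compiled SPIR-V inside the kp::shader_data namespaces written by the echo commands above. A hypothetical consumer-side sketch follows (the real usage lives in ggml-kompute.cpp; the symbol names below follow xxd -i's filename mangling and are assumptions, not taken from this patch):

    #include <cstdint>
    #include <vector>
    #include "shaderop_scale.h"   // assumption: header generated from op_scale.comp by the rules above

    static std::vector<uint32_t> load_scale_spirv_sketch() {
        // xxd -i names the array after the input file: op_scale_comp_spv / op_scale_comp_spv_len
        const unsigned char * data = kp::shader_data::op_scale_comp_spv;
        const unsigned int    len  = kp::shader_data::op_scale_comp_spv_len;   // SPIR-V is 4-byte aligned
        return std::vector<uint32_t>(reinterpret_cast<const uint32_t *>(data),
                                     reinterpret_cast<const uint32_t *>(data + len));
    }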
77
+ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
78
+ message(STATUS "Kompute found")
79
+ set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
80
+ add_subdirectory(kompute)
81
+
82
+ # Compile our shaders
83
+ compile_shader(SOURCES
84
+ kompute-shaders/op_scale.comp
85
+ kompute-shaders/op_scale_8.comp
86
+ kompute-shaders/op_add.comp
87
+ kompute-shaders/op_addrow.comp
88
+ kompute-shaders/op_mul.comp
89
+ kompute-shaders/op_silu.comp
90
+ kompute-shaders/op_relu.comp
91
+ kompute-shaders/op_gelu.comp
92
+ kompute-shaders/op_softmax.comp
93
+ kompute-shaders/op_norm.comp
94
+ kompute-shaders/op_rmsnorm.comp
95
+ kompute-shaders/op_diagmask.comp
96
+ kompute-shaders/op_mul_mat_mat_f32.comp
97
+ kompute-shaders/op_mul_mat_f16.comp
98
+ kompute-shaders/op_mul_mat_q8_0.comp
99
+ kompute-shaders/op_mul_mat_q4_0.comp
100
+ kompute-shaders/op_mul_mat_q4_1.comp
101
+ kompute-shaders/op_mul_mat_q4_k.comp
102
+ kompute-shaders/op_mul_mat_q6_k.comp
103
+ kompute-shaders/op_getrows_f32.comp
104
+ kompute-shaders/op_getrows_f16.comp
105
+ kompute-shaders/op_getrows_q4_0.comp
106
+ kompute-shaders/op_getrows_q4_1.comp
107
+ kompute-shaders/op_getrows_q6_k.comp
108
+ kompute-shaders/op_rope_f16.comp
109
+ kompute-shaders/op_rope_f32.comp
110
+ kompute-shaders/op_cpy_f16_f16.comp
111
+ kompute-shaders/op_cpy_f16_f32.comp
112
+ kompute-shaders/op_cpy_f32_f16.comp
113
+ kompute-shaders/op_cpy_f32_f32.comp
114
+ )
115
+
116
+ # Create a custom target for our generated shaders
117
+ add_custom_target(generated_shaders DEPENDS
118
+ shaderop_scale.h
119
+ shaderop_scale_8.h
120
+ shaderop_add.h
121
+ shaderop_addrow.h
122
+ shaderop_mul.h
123
+ shaderop_silu.h
124
+ shaderop_relu.h
125
+ shaderop_gelu.h
126
+ shaderop_softmax.h
127
+ shaderop_norm.h
128
+ shaderop_rmsnorm.h
129
+ shaderop_diagmask.h
130
+ shaderop_mul_mat_mat_f32.h
131
+ shaderop_mul_mat_f16.h
132
+ shaderop_mul_mat_q8_0.h
133
+ shaderop_mul_mat_q4_0.h
134
+ shaderop_mul_mat_q4_1.h
135
+ shaderop_mul_mat_q4_k.h
136
+ shaderop_mul_mat_q6_k.h
137
+ shaderop_getrows_f32.h
138
+ shaderop_getrows_f16.h
139
+ shaderop_getrows_q4_0.h
140
+ shaderop_getrows_q4_1.h
141
+ shaderop_getrows_q6_k.h
142
+ shaderop_rope_f16.h
143
+ shaderop_rope_f32.h
144
+ shaderop_cpy_f16_f16.h
145
+ shaderop_cpy_f16_f32.h
146
+ shaderop_cpy_f32_f16.h
147
+ shaderop_cpy_f32_f32.h
148
+ )
149
+
150
+ # Create a custom command that depends on the generated_shaders
151
+ add_custom_command(
152
+ OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
153
+ COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
154
+ DEPENDS generated_shaders
155
+ COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
156
+ )
157
+
158
+ # Add the stamp to the main sources to ensure dependency tracking
159
+ target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
160
+ else()
161
+ message(WARNING "Kompute not found")
162
+ endif()