Spaces:
Running
Running
ggml : build backends as libraries (llama/10256)
Browse files* ggml : build backends as libraries
---------
Signed-off-by: Xiaodong Ye <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: R0CKSTAR <[email protected]>
This view is limited to 50 files because it contains too many changes.
See raw diff
- ggml/CMakeLists.txt +7 -3
- ggml/include/ggml-amx.h +5 -5
- ggml/include/ggml-backend.h +14 -0
- ggml/include/ggml-blas.h +4 -4
- ggml/include/ggml-cann.h +8 -8
- ggml/include/ggml-cpu.h +64 -40
- ggml/include/ggml-cuda.h +12 -12
- ggml/include/ggml-kompute.h +4 -4
- ggml/include/ggml-metal.h +8 -8
- ggml/include/ggml-rpc.h +7 -7
- ggml/include/ggml-sycl.h +13 -13
- ggml/include/ggml-vulkan.h +9 -9
- ggml/include/ggml.h +5 -38
- ggml/src/ggml-aarch64.c +0 -0
- ggml/src/ggml-aarch64.h +0 -20
- ggml/src/ggml-amx/CMakeLists.txt +107 -0
- ggml/src/ggml-amx/common.h +2 -1
- ggml/src/ggml-amx/ggml-amx.cpp +449 -0
- ggml/src/ggml-amx/mmq.cpp +4 -3
- ggml/src/ggml-backend-reg.cpp +195 -0
- ggml/src/ggml-blas/CMakeLists.txt +91 -0
- ggml/src/ggml-blas/ggml-blas.cpp +514 -0
- ggml/src/ggml-cann/CMakeLists.txt +46 -0
- ggml/src/ggml-cann/ggml-cann.cpp +2128 -0
- ggml/src/ggml-cpu/CMakeLists.txt +244 -0
- ggml/src/ggml-cpu/cmake/FindSIMD.cmake +100 -0
- ggml/src/ggml-cpu/ggml-cpu-aarch64.c +0 -0
- ggml/src/ggml-cpu/ggml-cpu-aarch64.h +27 -0
- ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- ggml/src/ggml-cpu/ggml-cpu-quants.c +0 -0
- ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- ggml/src/ggml-cpu/ggml-cpu.c +0 -0
- ggml/src/ggml-cpu/ggml-cpu.cpp +575 -0
- ggml/src/ggml-cpu/llamafile/sgemm.cpp +1884 -0
- ggml/src/ggml-cpu/llamafile/sgemm.h +14 -0
- ggml/src/ggml-cuda/common.cuh +25 -25
- ggml/src/ggml-cuda/fattn-common.cuh +2 -2
- ggml/src/ggml-cuda/fattn-tile-f16.cu +2 -2
- ggml/src/ggml-cuda/fattn-tile-f32.cu +2 -2
- ggml/src/ggml-cuda/fattn-vec-f16.cuh +2 -2
- ggml/src/ggml-cuda/fattn-vec-f32.cuh +2 -2
- ggml/src/ggml-cuda/fattn-wmma-f16.cuh +2 -2
- ggml/src/ggml-cuda/ggml-cuda.cu +0 -0
- ggml/src/ggml-cuda/ggml/CMakeLists.txt +165 -0
- ggml/src/ggml-cuda/mmq.cuh +11 -11
- ggml/src/ggml-cuda/mmvq.cu +4 -4
- ggml/src/ggml-cuda/sum.cu +2 -2
- ggml/src/ggml-hip/CMakeLists.txt +113 -0
- ggml/src/ggml-impl.h +267 -13
- ggml/src/ggml-kompute/CMakeLists.txt +162 -0
ggml/CMakeLists.txt
CHANGED
|
@@ -116,6 +116,7 @@ endif()
|
|
| 116 |
|
| 117 |
# ggml core
|
| 118 |
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
|
|
|
| 119 |
|
| 120 |
# 3rd party libs / backends
|
| 121 |
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
|
|
@@ -141,7 +142,7 @@ option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM"
|
|
| 141 |
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
| 142 |
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
| 143 |
|
| 144 |
-
option(
|
| 145 |
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
| 146 |
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
| 147 |
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
|
@@ -238,12 +239,15 @@ set_target_properties(ggml PROPERTIES PUBLIC_HEADER "${GGML_PUBLIC_HEADERS}")
|
|
| 238 |
install(TARGETS ggml PUBLIC_HEADER)
|
| 239 |
|
| 240 |
if (BUILD_SHARED_LIBS)
|
| 241 |
-
install(TARGETS ggml
|
|
|
|
| 242 |
endif()
|
| 243 |
|
|
|
|
| 244 |
if (GGML_METAL)
|
|
|
|
| 245 |
install(
|
| 246 |
-
FILES src/ggml-metal.metal
|
| 247 |
PERMISSIONS
|
| 248 |
OWNER_READ
|
| 249 |
OWNER_WRITE
|
|
|
|
| 116 |
|
| 117 |
# ggml core
|
| 118 |
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
|
| 119 |
+
option(GGML_CPU "ggml: enable CPU backend" ON)
|
| 120 |
|
| 121 |
# 3rd party libs / backends
|
| 122 |
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)
|
|
|
|
| 142 |
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
|
| 143 |
option(GGML_CUDA_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ${GGML_CUDA_GRAPHS_DEFAULT})
|
| 144 |
|
| 145 |
+
option(GGML_HIP "ggml: use HIP" OFF)
|
| 146 |
option(GGML_HIP_UMA "ggml: use HIP unified memory architecture" OFF)
|
| 147 |
option(GGML_VULKAN "ggml: use Vulkan" OFF)
|
| 148 |
option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF)
|
|
|
|
| 239 |
install(TARGETS ggml PUBLIC_HEADER)
|
| 240 |
|
| 241 |
if (BUILD_SHARED_LIBS)
|
| 242 |
+
install(TARGETS ggml LIBRARY)
|
| 243 |
+
install(TARGETS ggml-base LIBRARY)
|
| 244 |
endif()
|
| 245 |
|
| 246 |
+
# FIXME: this should be done in the backend cmake files
|
| 247 |
if (GGML_METAL)
|
| 248 |
+
# FIXME: does this need to be installed with GGML_METAL_EMBED_LIBRARY?
|
| 249 |
install(
|
| 250 |
+
FILES ggml/src/ggml-metal/ggml-metal.metal
|
| 251 |
PERMISSIONS
|
| 252 |
OWNER_READ
|
| 253 |
OWNER_WRITE
|
ggml/include/ggml-amx.h
CHANGED
|
@@ -9,16 +9,16 @@ extern "C" {
|
|
| 9 |
#endif
|
| 10 |
|
| 11 |
// buffer_type API
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
// backend API
|
| 17 |
-
|
| 18 |
|
| 19 |
-
|
| 20 |
|
| 21 |
-
|
| 22 |
|
| 23 |
#ifdef __cplusplus
|
| 24 |
}
|
|
|
|
| 9 |
#endif
|
| 10 |
|
| 11 |
// buffer_type API
|
| 12 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
|
| 13 |
|
| 14 |
+
GGML_BACKEND_API bool ggml_backend_is_amx(ggml_backend_t backend);
|
| 15 |
|
| 16 |
// backend API
|
| 17 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_amx_init(void);
|
| 18 |
|
| 19 |
+
GGML_BACKEND_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
|
| 20 |
|
| 21 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_amx_reg(void);
|
| 22 |
|
| 23 |
#ifdef __cplusplus
|
| 24 |
}
|
ggml/include/ggml-backend.h
CHANGED
|
@@ -3,6 +3,20 @@
|
|
| 3 |
#include "ggml.h"
|
| 4 |
#include "ggml-alloc.h"
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
#ifdef __cplusplus
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
|
|
|
| 3 |
#include "ggml.h"
|
| 4 |
#include "ggml-alloc.h"
|
| 5 |
|
| 6 |
+
#ifdef GGML_BACKEND_SHARED
|
| 7 |
+
# if defined(_WIN32) && !defined(__MINGW32__)
|
| 8 |
+
# ifdef GGML_BACKEND_BUILD
|
| 9 |
+
# define GGML_BACKEND_API __declspec(dllexport) extern
|
| 10 |
+
# else
|
| 11 |
+
# define GGML_BACKEND_API __declspec(dllimport) extern
|
| 12 |
+
# endif
|
| 13 |
+
# else
|
| 14 |
+
# define GGML_BACKEND_API __attribute__ ((visibility ("default"))) extern
|
| 15 |
+
# endif
|
| 16 |
+
#else
|
| 17 |
+
# define GGML_BACKEND_API extern
|
| 18 |
+
#endif
|
| 19 |
+
|
| 20 |
#ifdef __cplusplus
|
| 21 |
extern "C" {
|
| 22 |
#endif
|
ggml/include/ggml-blas.h
CHANGED
|
@@ -9,15 +9,15 @@ extern "C" {
|
|
| 9 |
#endif
|
| 10 |
|
| 11 |
// backend API
|
| 12 |
-
|
| 13 |
|
| 14 |
-
|
| 15 |
|
| 16 |
// number of threads used for conversion to float
|
| 17 |
// for openblas and blis, this will also set the number of threads used for blas operations
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
|
| 23 |
#ifdef __cplusplus
|
|
|
|
| 9 |
#endif
|
| 10 |
|
| 11 |
// backend API
|
| 12 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_blas_init(void);
|
| 13 |
|
| 14 |
+
GGML_BACKEND_API bool ggml_backend_is_blas(ggml_backend_t backend);
|
| 15 |
|
| 16 |
// number of threads used for conversion to float
|
| 17 |
// for openblas and blis, this will also set the number of threads used for blas operations
|
| 18 |
+
GGML_BACKEND_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
|
| 19 |
|
| 20 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_blas_reg(void);
|
| 21 |
|
| 22 |
|
| 23 |
#ifdef __cplusplus
|
ggml/include/ggml-cann.h
CHANGED
|
@@ -34,7 +34,7 @@ extern "C" {
|
|
| 34 |
*/
|
| 35 |
#define GGML_CANN_MAX_DEVICES 16
|
| 36 |
|
| 37 |
-
|
| 38 |
|
| 39 |
/**
|
| 40 |
* @brief Initializes the CANN backend for a specified device.
|
|
@@ -46,7 +46,7 @@ GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
|
|
| 46 |
* @param device The index of the device to initialize.
|
| 47 |
* @return A pointer to the initialized backend instance, or nullptr on failure.
|
| 48 |
*/
|
| 49 |
-
|
| 50 |
|
| 51 |
/**
|
| 52 |
* @brief Checks if a given backend is a CANN backend.
|
|
@@ -57,7 +57,7 @@ GGML_API ggml_backend_t ggml_backend_cann_init(int32_t device);
|
|
| 57 |
* @param backend The backend instance to check.
|
| 58 |
* @return True if the backend is a CANN backend, false otherwise.
|
| 59 |
*/
|
| 60 |
-
|
| 61 |
|
| 62 |
/**
|
| 63 |
* @brief Retrieves the CANN buffer type for a specified device.
|
|
@@ -69,7 +69,7 @@ GGML_API bool ggml_backend_is_cann(ggml_backend_t backend);
|
|
| 69 |
* @return A pointer to the buffer type interface for the specified device, or
|
| 70 |
* nullptr if the device index is out of range.
|
| 71 |
*/
|
| 72 |
-
|
| 73 |
ggml_backend_cann_buffer_type(int32_t device);
|
| 74 |
|
| 75 |
/**
|
|
@@ -80,14 +80,14 @@ ggml_backend_cann_buffer_type(int32_t device);
|
|
| 80 |
*
|
| 81 |
* @return The number of CANN devices available.
|
| 82 |
*/
|
| 83 |
-
|
| 84 |
|
| 85 |
/**
|
| 86 |
* @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
|
| 87 |
*
|
| 88 |
* @return A pointer to the host buffer type interface.
|
| 89 |
*/
|
| 90 |
-
|
| 91 |
|
| 92 |
/**
|
| 93 |
* @brief Retrieves the description of a specific CANN device.
|
|
@@ -99,7 +99,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
|
|
| 99 |
* @param description Pointer to a buffer where the description will be written.
|
| 100 |
* @param description_size Size of the description buffer.
|
| 101 |
*/
|
| 102 |
-
|
| 103 |
int32_t device, char* description, size_t description_size);
|
| 104 |
|
| 105 |
/**
|
|
@@ -114,7 +114,7 @@ GGML_API void ggml_backend_cann_get_device_description(
|
|
| 114 |
* @param total Pointer to a variable where the total memory size will be
|
| 115 |
* stored.
|
| 116 |
*/
|
| 117 |
-
|
| 118 |
size_t* free,
|
| 119 |
size_t* total);
|
| 120 |
|
|
|
|
| 34 |
*/
|
| 35 |
#define GGML_CANN_MAX_DEVICES 16
|
| 36 |
|
| 37 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cann_reg(void);
|
| 38 |
|
| 39 |
/**
|
| 40 |
* @brief Initializes the CANN backend for a specified device.
|
|
|
|
| 46 |
* @param device The index of the device to initialize.
|
| 47 |
* @return A pointer to the initialized backend instance, or nullptr on failure.
|
| 48 |
*/
|
| 49 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_cann_init(int32_t device);
|
| 50 |
|
| 51 |
/**
|
| 52 |
* @brief Checks if a given backend is a CANN backend.
|
|
|
|
| 57 |
* @param backend The backend instance to check.
|
| 58 |
* @return True if the backend is a CANN backend, false otherwise.
|
| 59 |
*/
|
| 60 |
+
GGML_BACKEND_API bool ggml_backend_is_cann(ggml_backend_t backend);
|
| 61 |
|
| 62 |
/**
|
| 63 |
* @brief Retrieves the CANN buffer type for a specified device.
|
|
|
|
| 69 |
* @return A pointer to the buffer type interface for the specified device, or
|
| 70 |
* nullptr if the device index is out of range.
|
| 71 |
*/
|
| 72 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t
|
| 73 |
ggml_backend_cann_buffer_type(int32_t device);
|
| 74 |
|
| 75 |
/**
|
|
|
|
| 80 |
*
|
| 81 |
* @return The number of CANN devices available.
|
| 82 |
*/
|
| 83 |
+
GGML_BACKEND_API int32_t ggml_backend_cann_get_device_count(void);
|
| 84 |
|
| 85 |
/**
|
| 86 |
* @brief pinned host buffer for use with the CPU backend for faster copies between CPU and NPU.
|
| 87 |
*
|
| 88 |
* @return A pointer to the host buffer type interface.
|
| 89 |
*/
|
| 90 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type(void);
|
| 91 |
|
| 92 |
/**
|
| 93 |
* @brief Retrieves the description of a specific CANN device.
|
|
|
|
| 99 |
* @param description Pointer to a buffer where the description will be written.
|
| 100 |
* @param description_size Size of the description buffer.
|
| 101 |
*/
|
| 102 |
+
GGML_BACKEND_API void ggml_backend_cann_get_device_description(
|
| 103 |
int32_t device, char* description, size_t description_size);
|
| 104 |
|
| 105 |
/**
|
|
|
|
| 114 |
* @param total Pointer to a variable where the total memory size will be
|
| 115 |
* stored.
|
| 116 |
*/
|
| 117 |
+
GGML_BACKEND_API void ggml_backend_cann_get_device_memory(int32_t device,
|
| 118 |
size_t* free,
|
| 119 |
size_t* total);
|
| 120 |
|
ggml/include/ggml-cpu.h
CHANGED
|
@@ -54,54 +54,77 @@ extern "C" {
|
|
| 54 |
GGML_NUMA_STRATEGY_COUNT
|
| 55 |
};
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
-
|
| 61 |
-
|
| 62 |
|
| 63 |
-
|
| 64 |
-
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
|
| 75 |
-
|
| 76 |
-
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
| 88 |
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
| 89 |
-
|
| 90 |
const struct ggml_cgraph * cgraph,
|
| 91 |
int n_threads, /* = GGML_DEFAULT_N_THREADS */
|
| 92 |
struct ggml_threadpool * threadpool /* = NULL */ );
|
| 93 |
-
|
| 94 |
|
| 95 |
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
| 96 |
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
| 97 |
-
|
| 98 |
|
| 99 |
-
//
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
//
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
// Internal types and functions exposed for tests and benchmarks
|
| 107 |
|
|
@@ -115,6 +138,7 @@ extern "C" {
|
|
| 115 |
const void * GGML_RESTRICT y, int nr, int nc);
|
| 116 |
|
| 117 |
struct ggml_type_traits_cpu {
|
|
|
|
| 118 |
ggml_from_float_to_mat_t from_float_to_mat;
|
| 119 |
ggml_vec_dot_t vec_dot;
|
| 120 |
enum ggml_type vec_dot_type;
|
|
@@ -124,25 +148,25 @@ extern "C" {
|
|
| 124 |
ggml_gemm_t gemm;
|
| 125 |
};
|
| 126 |
|
| 127 |
-
|
| 128 |
|
| 129 |
-
|
| 130 |
|
| 131 |
//
|
| 132 |
// CPU backend
|
| 133 |
//
|
| 134 |
|
| 135 |
-
|
| 136 |
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
|
| 142 |
-
|
| 143 |
|
| 144 |
#ifdef GGML_USE_CPU_HBM
|
| 145 |
-
|
| 146 |
#endif
|
| 147 |
|
| 148 |
#ifdef __cplusplus
|
|
|
|
| 54 |
GGML_NUMA_STRATEGY_COUNT
|
| 55 |
};
|
| 56 |
|
| 57 |
+
GGML_BACKEND_API void ggml_numa_init(enum ggml_numa_strategy numa); // call once for better performance on NUMA systems
|
| 58 |
+
GGML_BACKEND_API bool ggml_is_numa(void); // true if init detected that system has >1 NUMA node
|
| 59 |
|
| 60 |
+
GGML_BACKEND_API struct ggml_tensor * ggml_new_i32(struct ggml_context * ctx, int32_t value);
|
| 61 |
+
GGML_BACKEND_API struct ggml_tensor * ggml_new_f32(struct ggml_context * ctx, float value);
|
| 62 |
|
| 63 |
+
GGML_BACKEND_API struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value);
|
| 64 |
+
GGML_BACKEND_API struct ggml_tensor * ggml_set_f32 (struct ggml_tensor * tensor, float value);
|
| 65 |
|
| 66 |
+
GGML_BACKEND_API int32_t ggml_get_i32_1d(const struct ggml_tensor * tensor, int i);
|
| 67 |
+
GGML_BACKEND_API void ggml_set_i32_1d(const struct ggml_tensor * tensor, int i, int32_t value);
|
| 68 |
|
| 69 |
+
GGML_BACKEND_API int32_t ggml_get_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
| 70 |
+
GGML_BACKEND_API void ggml_set_i32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, int32_t value);
|
| 71 |
|
| 72 |
+
GGML_BACKEND_API float ggml_get_f32_1d(const struct ggml_tensor * tensor, int i);
|
| 73 |
+
GGML_BACKEND_API void ggml_set_f32_1d(const struct ggml_tensor * tensor, int i, float value);
|
| 74 |
|
| 75 |
+
GGML_BACKEND_API float ggml_get_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3);
|
| 76 |
+
GGML_BACKEND_API void ggml_set_f32_nd(const struct ggml_tensor * tensor, int i0, int i1, int i2, int i3, float value);
|
| 77 |
|
| 78 |
+
GGML_BACKEND_API struct ggml_threadpool_params ggml_threadpool_params_default(int n_threads);
|
| 79 |
+
GGML_BACKEND_API void ggml_threadpool_params_init (struct ggml_threadpool_params * p, int n_threads);
|
| 80 |
+
GGML_BACKEND_API bool ggml_threadpool_params_match (const struct ggml_threadpool_params * p0, const struct ggml_threadpool_params * p1);
|
| 81 |
+
GGML_BACKEND_API struct ggml_threadpool * ggml_threadpool_new (struct ggml_threadpool_params * params);
|
| 82 |
+
GGML_BACKEND_API void ggml_threadpool_free (struct ggml_threadpool * threadpool);
|
| 83 |
+
GGML_BACKEND_API int ggml_threadpool_get_n_threads(struct ggml_threadpool * threadpool);
|
| 84 |
+
GGML_BACKEND_API void ggml_threadpool_pause (struct ggml_threadpool * threadpool);
|
| 85 |
+
GGML_BACKEND_API void ggml_threadpool_resume (struct ggml_threadpool * threadpool);
|
| 86 |
|
| 87 |
// ggml_graph_plan() has to be called before ggml_graph_compute()
|
| 88 |
// when plan.work_size > 0, caller must allocate memory for plan.work_data
|
| 89 |
+
GGML_BACKEND_API struct ggml_cplan ggml_graph_plan(
|
| 90 |
const struct ggml_cgraph * cgraph,
|
| 91 |
int n_threads, /* = GGML_DEFAULT_N_THREADS */
|
| 92 |
struct ggml_threadpool * threadpool /* = NULL */ );
|
| 93 |
+
GGML_BACKEND_API enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
|
| 94 |
|
| 95 |
// same as ggml_graph_compute() but the work data is allocated as a part of the context
|
| 96 |
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
|
| 97 |
+
GGML_BACKEND_API enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct ggml_cgraph * cgraph, int n_threads);
|
| 98 |
|
| 99 |
+
//
|
| 100 |
+
// system info
|
| 101 |
+
//
|
| 102 |
+
|
| 103 |
+
// x86
|
| 104 |
+
GGML_BACKEND_API int ggml_cpu_has_sse3 (void);
|
| 105 |
+
GGML_BACKEND_API int ggml_cpu_has_ssse3 (void);
|
| 106 |
+
GGML_BACKEND_API int ggml_cpu_has_avx (void);
|
| 107 |
+
GGML_BACKEND_API int ggml_cpu_has_avx2 (void);
|
| 108 |
+
GGML_BACKEND_API int ggml_cpu_has_f16c (void);
|
| 109 |
+
GGML_BACKEND_API int ggml_cpu_has_fma (void);
|
| 110 |
+
GGML_BACKEND_API int ggml_cpu_has_avx_vnni (void);
|
| 111 |
+
GGML_BACKEND_API int ggml_cpu_has_avx512 (void);
|
| 112 |
+
GGML_BACKEND_API int ggml_cpu_has_avx512_vbmi(void);
|
| 113 |
+
GGML_BACKEND_API int ggml_cpu_has_avx512_vnni(void);
|
| 114 |
+
GGML_BACKEND_API int ggml_cpu_has_avx512_bf16(void);
|
| 115 |
+
GGML_BACKEND_API int ggml_cpu_has_amx_int8 (void);
|
| 116 |
+
// ARM
|
| 117 |
+
GGML_BACKEND_API int ggml_cpu_has_neon (void);
|
| 118 |
+
GGML_BACKEND_API int ggml_cpu_has_arm_fma (void);
|
| 119 |
+
GGML_BACKEND_API int ggml_cpu_has_fp16_va (void);
|
| 120 |
+
GGML_BACKEND_API int ggml_cpu_has_matmul_int8(void);
|
| 121 |
+
GGML_BACKEND_API int ggml_cpu_has_sve (void);
|
| 122 |
+
GGML_BACKEND_API int ggml_cpu_get_sve_cnt (void); // sve vector length in bytes
|
| 123 |
+
// other
|
| 124 |
+
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
|
| 125 |
+
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
|
| 126 |
+
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
|
| 127 |
+
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
|
| 128 |
|
| 129 |
// Internal types and functions exposed for tests and benchmarks
|
| 130 |
|
|
|
|
| 138 |
const void * GGML_RESTRICT y, int nr, int nc);
|
| 139 |
|
| 140 |
struct ggml_type_traits_cpu {
|
| 141 |
+
ggml_from_float_t from_float;
|
| 142 |
ggml_from_float_to_mat_t from_float_to_mat;
|
| 143 |
ggml_vec_dot_t vec_dot;
|
| 144 |
enum ggml_type vec_dot_type;
|
|
|
|
| 148 |
ggml_gemm_t gemm;
|
| 149 |
};
|
| 150 |
|
| 151 |
+
GGML_BACKEND_API const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type);
|
| 152 |
|
| 153 |
+
GGML_BACKEND_API void ggml_cpu_init(void);
|
| 154 |
|
| 155 |
//
|
| 156 |
// CPU backend
|
| 157 |
//
|
| 158 |
|
| 159 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_cpu_init(void);
|
| 160 |
|
| 161 |
+
GGML_BACKEND_API bool ggml_backend_is_cpu (ggml_backend_t backend);
|
| 162 |
+
GGML_BACKEND_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
|
| 163 |
+
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
|
| 164 |
+
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
|
| 165 |
|
| 166 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
|
| 167 |
|
| 168 |
#ifdef GGML_USE_CPU_HBM
|
| 169 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
|
| 170 |
#endif
|
| 171 |
|
| 172 |
#ifdef __cplusplus
|
ggml/include/ggml-cuda.h
CHANGED
|
@@ -7,7 +7,7 @@
|
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
| 9 |
|
| 10 |
-
#ifdef
|
| 11 |
#define GGML_CUDA_NAME "ROCm"
|
| 12 |
#define GGML_CUBLAS_NAME "hipBLAS"
|
| 13 |
#elif defined(GGML_USE_MUSA)
|
|
@@ -20,27 +20,27 @@ extern "C" {
|
|
| 20 |
#define GGML_CUDA_MAX_DEVICES 16
|
| 21 |
|
| 22 |
// backend API
|
| 23 |
-
|
| 24 |
|
| 25 |
-
|
| 26 |
|
| 27 |
// device buffer
|
| 28 |
-
|
| 29 |
|
| 30 |
// split tensor buffer that splits matrices by rows across multiple devices
|
| 31 |
-
|
| 32 |
|
| 33 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 34 |
-
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
-
|
| 41 |
-
|
| 42 |
|
| 43 |
-
|
| 44 |
|
| 45 |
#ifdef __cplusplus
|
| 46 |
}
|
|
|
|
| 7 |
extern "C" {
|
| 8 |
#endif
|
| 9 |
|
| 10 |
+
#ifdef GGML_USE_HIP
|
| 11 |
#define GGML_CUDA_NAME "ROCm"
|
| 12 |
#define GGML_CUBLAS_NAME "hipBLAS"
|
| 13 |
#elif defined(GGML_USE_MUSA)
|
|
|
|
| 20 |
#define GGML_CUDA_MAX_DEVICES 16
|
| 21 |
|
| 22 |
// backend API
|
| 23 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_cuda_init(int device);
|
| 24 |
|
| 25 |
+
GGML_BACKEND_API bool ggml_backend_is_cuda(ggml_backend_t backend);
|
| 26 |
|
| 27 |
// device buffer
|
| 28 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
|
| 29 |
|
| 30 |
// split tensor buffer that splits matrices by rows across multiple devices
|
| 31 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(int main_device, const float * tensor_split);
|
| 32 |
|
| 33 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 34 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
|
| 35 |
|
| 36 |
+
GGML_BACKEND_API int ggml_backend_cuda_get_device_count(void);
|
| 37 |
+
GGML_BACKEND_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
|
| 38 |
+
GGML_BACKEND_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
|
| 39 |
|
| 40 |
+
GGML_BACKEND_API bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size);
|
| 41 |
+
GGML_BACKEND_API void ggml_backend_cuda_unregister_host_buffer(void * buffer);
|
| 42 |
|
| 43 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cuda_reg(void);
|
| 44 |
|
| 45 |
#ifdef __cplusplus
|
| 46 |
}
|
ggml/include/ggml-kompute.h
CHANGED
|
@@ -37,13 +37,13 @@ struct ggml_vk_device ggml_vk_current_device(void);
|
|
| 37 |
// forward declaration
|
| 38 |
typedef struct ggml_backend * ggml_backend_t;
|
| 39 |
|
| 40 |
-
|
| 41 |
|
| 42 |
-
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
-
|
| 47 |
|
| 48 |
#ifdef __cplusplus
|
| 49 |
}
|
|
|
|
| 37 |
// forward declaration
|
| 38 |
typedef struct ggml_backend * ggml_backend_t;
|
| 39 |
|
| 40 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_kompute_init(int device);
|
| 41 |
|
| 42 |
+
GGML_BACKEND_API bool ggml_backend_is_kompute(ggml_backend_t backend);
|
| 43 |
|
| 44 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device);
|
| 45 |
|
| 46 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_kompute_reg(void);
|
| 47 |
|
| 48 |
#ifdef __cplusplus
|
| 49 |
}
|
ggml/include/ggml-metal.h
CHANGED
|
@@ -39,27 +39,27 @@ extern "C" {
|
|
| 39 |
// user-code should use only these functions
|
| 40 |
//
|
| 41 |
|
| 42 |
-
|
| 43 |
|
| 44 |
-
|
| 45 |
|
| 46 |
GGML_DEPRECATED(
|
| 47 |
-
|
| 48 |
"obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
| 53 |
|
| 54 |
// helper to check if the device supports a specific family
|
| 55 |
// ideally, the user code should be doing these checks
|
| 56 |
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
| 57 |
-
|
| 58 |
|
| 59 |
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
| 60 |
-
|
| 61 |
|
| 62 |
-
|
| 63 |
|
| 64 |
#ifdef __cplusplus
|
| 65 |
}
|
|
|
|
| 39 |
// user-code should use only these functions
|
| 40 |
//
|
| 41 |
|
| 42 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
|
| 43 |
|
| 44 |
+
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
| 45 |
|
| 46 |
GGML_DEPRECATED(
|
| 47 |
+
GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
|
| 48 |
"obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
|
| 49 |
|
| 50 |
+
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
| 51 |
|
| 52 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
| 53 |
|
| 54 |
// helper to check if the device supports a specific family
|
| 55 |
// ideally, the user code should be doing these checks
|
| 56 |
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
|
| 57 |
+
GGML_BACKEND_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family);
|
| 58 |
|
| 59 |
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called
|
| 60 |
+
GGML_BACKEND_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
|
| 61 |
|
| 62 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_metal_reg(void);
|
| 63 |
|
| 64 |
#ifdef __cplusplus
|
| 65 |
}
|
ggml/include/ggml-rpc.h
CHANGED
|
@@ -10,18 +10,18 @@ extern "C" {
|
|
| 10 |
#define GGML_RPC_MAX_SERVERS 16
|
| 11 |
|
| 12 |
// backend API
|
| 13 |
-
|
| 14 |
-
|
| 15 |
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
|
| 26 |
#ifdef __cplusplus
|
| 27 |
}
|
|
|
|
| 10 |
#define GGML_RPC_MAX_SERVERS 16
|
| 11 |
|
| 12 |
// backend API
|
| 13 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
|
| 14 |
+
GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend);
|
| 15 |
|
| 16 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
|
| 17 |
|
| 18 |
+
GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
|
| 19 |
|
| 20 |
+
GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
|
| 21 |
|
| 22 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
|
| 23 |
|
| 24 |
+
GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
|
| 25 |
|
| 26 |
#ifdef __cplusplus
|
| 27 |
}
|
ggml/include/ggml-sycl.h
CHANGED
|
@@ -17,32 +17,32 @@ extern "C" {
|
|
| 17 |
#endif
|
| 18 |
|
| 19 |
// backend API
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
|
| 24 |
// devide buffer
|
| 25 |
-
|
| 26 |
|
| 27 |
// split tensor buffer that splits matrices by rows across multiple devices
|
| 28 |
-
|
| 29 |
|
| 30 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 31 |
-
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
char *description,
|
| 37 |
size_t description_size);
|
| 38 |
-
|
| 39 |
-
|
| 40 |
|
| 41 |
// SYCL doesn't support registering host memory, keep here for reference
|
| 42 |
-
//
|
| 43 |
-
//
|
| 44 |
|
| 45 |
-
|
| 46 |
|
| 47 |
#ifdef __cplusplus
|
| 48 |
}
|
|
|
|
| 17 |
#endif
|
| 18 |
|
| 19 |
// backend API
|
| 20 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_sycl_init(int device);
|
| 21 |
|
| 22 |
+
GGML_BACKEND_API bool ggml_backend_is_sycl(ggml_backend_t backend);
|
| 23 |
|
| 24 |
// devide buffer
|
| 25 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
|
| 26 |
|
| 27 |
// split tensor buffer that splits matrices by rows across multiple devices
|
| 28 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split);
|
| 29 |
|
| 30 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 31 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
|
| 32 |
|
| 33 |
+
GGML_BACKEND_API void ggml_backend_sycl_print_sycl_devices(void);
|
| 34 |
+
GGML_BACKEND_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
|
| 35 |
+
GGML_BACKEND_API void ggml_backend_sycl_get_device_description(int device,
|
| 36 |
char *description,
|
| 37 |
size_t description_size);
|
| 38 |
+
GGML_BACKEND_API int ggml_backend_sycl_get_device_count();
|
| 39 |
+
GGML_BACKEND_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
|
| 40 |
|
| 41 |
// SYCL doesn't support registering host memory, keep here for reference
|
| 42 |
+
// GGML_BACKEND_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
|
| 43 |
+
// GGML_BACKEND_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
|
| 44 |
|
| 45 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
|
| 46 |
|
| 47 |
#ifdef __cplusplus
|
| 48 |
}
|
ggml/include/ggml-vulkan.h
CHANGED
|
@@ -10,21 +10,21 @@ extern "C" {
|
|
| 10 |
#define GGML_VK_NAME "Vulkan"
|
| 11 |
#define GGML_VK_MAX_DEVICES 16
|
| 12 |
|
| 13 |
-
|
| 14 |
|
| 15 |
// backend API
|
| 16 |
-
|
| 17 |
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
|
| 23 |
-
|
| 24 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 25 |
-
|
| 26 |
|
| 27 |
-
|
| 28 |
|
| 29 |
#ifdef __cplusplus
|
| 30 |
}
|
|
|
|
| 10 |
#define GGML_VK_NAME "Vulkan"
|
| 11 |
#define GGML_VK_MAX_DEVICES 16
|
| 12 |
|
| 13 |
+
GGML_BACKEND_API void ggml_vk_instance_init(void);
|
| 14 |
|
| 15 |
// backend API
|
| 16 |
+
GGML_BACKEND_API ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
| 17 |
|
| 18 |
+
GGML_BACKEND_API bool ggml_backend_is_vk(ggml_backend_t backend);
|
| 19 |
+
GGML_BACKEND_API int ggml_backend_vk_get_device_count(void);
|
| 20 |
+
GGML_BACKEND_API void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
|
| 21 |
+
GGML_BACKEND_API void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
|
| 22 |
|
| 23 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
|
| 24 |
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
| 25 |
+
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
|
| 26 |
|
| 27 |
+
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_vk_reg(void);
|
| 28 |
|
| 29 |
#ifdef __cplusplus
|
| 30 |
}
|
ggml/include/ggml.h
CHANGED
|
@@ -176,15 +176,15 @@
|
|
| 176 |
#ifdef GGML_SHARED
|
| 177 |
# if defined(_WIN32) && !defined(__MINGW32__)
|
| 178 |
# ifdef GGML_BUILD
|
| 179 |
-
# define GGML_API __declspec(dllexport)
|
| 180 |
# else
|
| 181 |
-
# define GGML_API __declspec(dllimport)
|
| 182 |
# endif
|
| 183 |
# else
|
| 184 |
-
# define GGML_API __attribute__ ((visibility ("default")))
|
| 185 |
# endif
|
| 186 |
#else
|
| 187 |
-
# define GGML_API
|
| 188 |
#endif
|
| 189 |
|
| 190 |
// TODO: support for clang
|
|
@@ -1490,7 +1490,7 @@ extern "C" {
|
|
| 1490 |
"use ggml_rope_ext_inplace instead");
|
| 1491 |
|
| 1492 |
// compute correction dims for YaRN RoPE scaling
|
| 1493 |
-
void ggml_rope_yarn_corr_dims(
|
| 1494 |
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
| 1495 |
|
| 1496 |
// rotary position embedding backward, i.e compute dx from dy
|
|
@@ -2384,38 +2384,6 @@ extern "C" {
|
|
| 2384 |
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
| 2385 |
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
| 2386 |
|
| 2387 |
-
//
|
| 2388 |
-
// system info
|
| 2389 |
-
//
|
| 2390 |
-
|
| 2391 |
-
GGML_API int ggml_cpu_has_avx (void);
|
| 2392 |
-
GGML_API int ggml_cpu_has_avx_vnni (void);
|
| 2393 |
-
GGML_API int ggml_cpu_has_avx2 (void);
|
| 2394 |
-
GGML_API int ggml_cpu_has_avx512 (void);
|
| 2395 |
-
GGML_API int ggml_cpu_has_avx512_vbmi(void);
|
| 2396 |
-
GGML_API int ggml_cpu_has_avx512_vnni(void);
|
| 2397 |
-
GGML_API int ggml_cpu_has_avx512_bf16(void);
|
| 2398 |
-
GGML_API int ggml_cpu_has_amx_int8 (void);
|
| 2399 |
-
GGML_API int ggml_cpu_has_fma (void);
|
| 2400 |
-
GGML_API int ggml_cpu_has_arm_fma (void);
|
| 2401 |
-
GGML_API int ggml_cpu_has_metal (void);
|
| 2402 |
-
GGML_API int ggml_cpu_has_f16c (void);
|
| 2403 |
-
GGML_API int ggml_cpu_has_fp16_va (void);
|
| 2404 |
-
GGML_API int ggml_cpu_has_wasm_simd (void);
|
| 2405 |
-
GGML_API int ggml_cpu_has_blas (void);
|
| 2406 |
-
GGML_API int ggml_cpu_has_cuda (void);
|
| 2407 |
-
GGML_API int ggml_cpu_has_vulkan (void);
|
| 2408 |
-
GGML_API int ggml_cpu_has_kompute (void);
|
| 2409 |
-
GGML_API int ggml_cpu_has_gpublas (void);
|
| 2410 |
-
GGML_API int ggml_cpu_has_sse3 (void);
|
| 2411 |
-
GGML_API int ggml_cpu_has_ssse3 (void);
|
| 2412 |
-
GGML_API int ggml_cpu_has_riscv_v (void);
|
| 2413 |
-
GGML_API int ggml_cpu_has_sycl (void);
|
| 2414 |
-
GGML_API int ggml_cpu_has_rpc (void);
|
| 2415 |
-
GGML_API int ggml_cpu_has_vsx (void);
|
| 2416 |
-
GGML_API int ggml_cpu_has_cann (void);
|
| 2417 |
-
GGML_API int ggml_cpu_has_llamafile (void);
|
| 2418 |
-
|
| 2419 |
#ifdef __cplusplus
|
| 2420 |
// restrict not standard in C++
|
| 2421 |
#define GGML_RESTRICT
|
|
@@ -2432,7 +2400,6 @@ extern "C" {
|
|
| 2432 |
size_t type_size;
|
| 2433 |
bool is_quantized;
|
| 2434 |
ggml_to_float_t to_float;
|
| 2435 |
-
ggml_from_float_t from_float;
|
| 2436 |
ggml_from_float_t from_float_ref;
|
| 2437 |
};
|
| 2438 |
|
|
|
|
| 176 |
#ifdef GGML_SHARED
|
| 177 |
# if defined(_WIN32) && !defined(__MINGW32__)
|
| 178 |
# ifdef GGML_BUILD
|
| 179 |
+
# define GGML_API __declspec(dllexport) extern
|
| 180 |
# else
|
| 181 |
+
# define GGML_API __declspec(dllimport) extern
|
| 182 |
# endif
|
| 183 |
# else
|
| 184 |
+
# define GGML_API __attribute__ ((visibility ("default"))) extern
|
| 185 |
# endif
|
| 186 |
#else
|
| 187 |
+
# define GGML_API extern
|
| 188 |
#endif
|
| 189 |
|
| 190 |
// TODO: support for clang
|
|
|
|
| 1490 |
"use ggml_rope_ext_inplace instead");
|
| 1491 |
|
| 1492 |
// compute correction dims for YaRN RoPE scaling
|
| 1493 |
+
GGML_API void ggml_rope_yarn_corr_dims(
|
| 1494 |
int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
|
| 1495 |
|
| 1496 |
// rotary position embedding backward, i.e compute dx from dy
|
|
|
|
| 2384 |
GGML_API size_t gguf_get_meta_size(const struct gguf_context * ctx);
|
| 2385 |
GGML_API void gguf_get_meta_data(const struct gguf_context * ctx, void * data);
|
| 2386 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2387 |
#ifdef __cplusplus
|
| 2388 |
// restrict not standard in C++
|
| 2389 |
#define GGML_RESTRICT
|
|
|
|
| 2400 |
size_t type_size;
|
| 2401 |
bool is_quantized;
|
| 2402 |
ggml_to_float_t to_float;
|
|
|
|
| 2403 |
ggml_from_float_t from_float_ref;
|
| 2404 |
};
|
| 2405 |
|
ggml/src/ggml-aarch64.c
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ggml/src/ggml-aarch64.h
CHANGED
|
@@ -1,9 +1,5 @@
|
|
| 1 |
-
// SPDX-FileCopyrightText: Copyright 2024 Arm Ltd.
|
| 2 |
#pragma once
|
| 3 |
|
| 4 |
-
#define GGML_COMMON_DECL_C
|
| 5 |
-
#include "ggml-common.h"
|
| 6 |
-
|
| 7 |
#include "ggml.h"
|
| 8 |
|
| 9 |
// GGML internal header
|
|
@@ -12,27 +8,11 @@
|
|
| 12 |
extern "C" {
|
| 13 |
#endif
|
| 14 |
|
| 15 |
-
// Quantization
|
| 16 |
-
void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
| 17 |
-
void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
| 18 |
-
|
| 19 |
-
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);
|
| 20 |
-
|
| 21 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
| 22 |
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 23 |
size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 24 |
size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 25 |
|
| 26 |
-
// GEMV
|
| 27 |
-
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 28 |
-
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 29 |
-
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 30 |
-
|
| 31 |
-
// GEMM
|
| 32 |
-
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 33 |
-
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 34 |
-
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
|
| 35 |
-
|
| 36 |
#ifdef __cplusplus
|
| 37 |
}
|
| 38 |
#endif
|
|
|
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
#include "ggml.h"
|
| 4 |
|
| 5 |
// GGML internal header
|
|
|
|
| 8 |
extern "C" {
|
| 9 |
#endif
|
| 10 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
|
| 12 |
size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 13 |
size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 14 |
size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
#ifdef __cplusplus
|
| 17 |
}
|
| 18 |
#endif
|
ggml/src/ggml-amx/CMakeLists.txt
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
|
| 2 |
+
(NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
|
| 3 |
+
CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$") AND
|
| 4 |
+
CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.0)
|
| 5 |
+
message(STATUS "Using AMX")
|
| 6 |
+
|
| 7 |
+
file(GLOB GGML_HEADERS_AMX "*.h")
|
| 8 |
+
list(APPEND GGML_HEADERS_AMX "../../include/ggml-amx.h")
|
| 9 |
+
|
| 10 |
+
file(GLOB GGML_SOURCES_AMX "*.cpp")
|
| 11 |
+
|
| 12 |
+
add_library(ggml-amx
|
| 13 |
+
${GGML_HEADERS_AMX}
|
| 14 |
+
${GGML_SOURCES_AMX})
|
| 15 |
+
|
| 16 |
+
target_link_libraries(ggml-amx PRIVATE ggml-base)
|
| 17 |
+
target_include_directories(ggml-amx PRIVATE . ..)
|
| 18 |
+
|
| 19 |
+
# this is duplicated from the CPU backend, since the AMX backend also depends on the architecture flags
|
| 20 |
+
# TODO: integrate AMX backend into the CPU backend
|
| 21 |
+
if (MSVC)
|
| 22 |
+
# instruction set detection for MSVC only
|
| 23 |
+
if (GGML_NATIVE)
|
| 24 |
+
# TODO: improve, should not reference files from the parent folder
|
| 25 |
+
include(../ggml-cpu/cmake/FindSIMD.cmake)
|
| 26 |
+
endif ()
|
| 27 |
+
if (GGML_AVX512)
|
| 28 |
+
list(APPEND ARCH_FLAGS /arch:AVX512)
|
| 29 |
+
# MSVC has no compile-time flags enabling specific
|
| 30 |
+
# AVX512 extensions, neither it defines the
|
| 31 |
+
# macros corresponding to the extensions.
|
| 32 |
+
# Do it manually.
|
| 33 |
+
if (GGML_AVX512_VBMI)
|
| 34 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
|
| 35 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
|
| 36 |
+
endif()
|
| 37 |
+
if (GGML_AVX512_VNNI)
|
| 38 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
|
| 39 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
|
| 40 |
+
endif()
|
| 41 |
+
if (GGML_AVX512_BF16)
|
| 42 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
|
| 43 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
|
| 44 |
+
endif()
|
| 45 |
+
if (GGML_AMX_TILE)
|
| 46 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
|
| 47 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
|
| 48 |
+
endif()
|
| 49 |
+
if (GGML_AMX_INT8)
|
| 50 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
|
| 51 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
|
| 52 |
+
endif()
|
| 53 |
+
if (GGML_AMX_BF16)
|
| 54 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
|
| 55 |
+
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
|
| 56 |
+
endif()
|
| 57 |
+
elseif (GGML_AVX2)
|
| 58 |
+
list(APPEND ARCH_FLAGS /arch:AVX2)
|
| 59 |
+
elseif (GGML_AVX)
|
| 60 |
+
list(APPEND ARCH_FLAGS /arch:AVX)
|
| 61 |
+
endif()
|
| 62 |
+
else()
|
| 63 |
+
if (GGML_NATIVE)
|
| 64 |
+
list(APPEND ARCH_FLAGS -march=native)
|
| 65 |
+
endif()
|
| 66 |
+
if (GGML_F16C)
|
| 67 |
+
list(APPEND ARCH_FLAGS -mf16c)
|
| 68 |
+
endif()
|
| 69 |
+
if (GGML_FMA)
|
| 70 |
+
list(APPEND ARCH_FLAGS -mfma)
|
| 71 |
+
endif()
|
| 72 |
+
if (GGML_AVX)
|
| 73 |
+
list(APPEND ARCH_FLAGS -mavx)
|
| 74 |
+
endif()
|
| 75 |
+
if (GGML_AVX2)
|
| 76 |
+
list(APPEND ARCH_FLAGS -mavx2)
|
| 77 |
+
endif()
|
| 78 |
+
if (GGML_AVX512)
|
| 79 |
+
list(APPEND ARCH_FLAGS -mavx512f)
|
| 80 |
+
list(APPEND ARCH_FLAGS -mavx512dq)
|
| 81 |
+
list(APPEND ARCH_FLAGS -mavx512bw)
|
| 82 |
+
endif()
|
| 83 |
+
if (GGML_AVX512_VBMI)
|
| 84 |
+
list(APPEND ARCH_FLAGS -mavx512vbmi)
|
| 85 |
+
endif()
|
| 86 |
+
if (GGML_AVX512_VNNI)
|
| 87 |
+
list(APPEND ARCH_FLAGS -mavx512vnni)
|
| 88 |
+
endif()
|
| 89 |
+
if (GGML_AVX512_BF16)
|
| 90 |
+
list(APPEND ARCH_FLAGS -mavx512bf16)
|
| 91 |
+
endif()
|
| 92 |
+
if (GGML_AMX_TILE)
|
| 93 |
+
list(APPEND ARCH_FLAGS -mamx-tile)
|
| 94 |
+
endif()
|
| 95 |
+
if (GGML_AMX_INT8)
|
| 96 |
+
list(APPEND ARCH_FLAGS -mamx-int8)
|
| 97 |
+
endif()
|
| 98 |
+
if (GGML_AMX_BF16)
|
| 99 |
+
list(APPEND ARCH_FLAGS -mamx-bf16)
|
| 100 |
+
endif()
|
| 101 |
+
endif()
|
| 102 |
+
|
| 103 |
+
target_compile_options(ggml-amx PRIVATE ${ARCH_FLAGS})
|
| 104 |
+
else()
|
| 105 |
+
set(GGML_AMX OFF PARENT_SCOPE)
|
| 106 |
+
message(WARNING "AMX requires x86 and gcc version > 11.0. Turning off GGML_AMX.")
|
| 107 |
+
endif()
|
ggml/src/ggml-amx/common.h
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "ggml.h"
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
#include <algorithm>
|
| 7 |
#include <memory>
|
|
|
|
| 1 |
#pragma once
|
| 2 |
|
| 3 |
#include "ggml.h"
|
| 4 |
+
// hack until AMX is moved into the CPU backend
|
| 5 |
+
#include "../ggml-cpu/ggml-cpu-impl.h" // <immintrin.h>
|
| 6 |
|
| 7 |
#include <algorithm>
|
| 8 |
#include <memory>
|
ggml/src/ggml-amx/ggml-amx.cpp
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include "ggml-amx.h"
|
| 2 |
+
#include "ggml-amx/common.h"
|
| 3 |
+
#include "ggml-amx/mmq.h"
|
| 4 |
+
#include "ggml-backend-impl.h"
|
| 5 |
+
#include "ggml-impl.h"
|
| 6 |
+
|
| 7 |
+
#if defined(__gnu_linux__)
|
| 8 |
+
#include <sys/syscall.h>
|
| 9 |
+
#include <unistd.h>
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
#include <cstdlib>
|
| 13 |
+
#include <cstring>
|
| 14 |
+
#include <memory>
|
| 15 |
+
|
| 16 |
+
#if defined(__AMX_INT8__)
|
| 17 |
+
|
| 18 |
+
// AMX buffer interface
|
| 19 |
+
static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
|
| 20 |
+
free(buffer->context);
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
|
| 24 |
+
return (void *)(buffer->context);
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
|
| 28 |
+
memset((char *)tensor->data + offset, value, size);
|
| 29 |
+
|
| 30 |
+
GGML_UNUSED(buffer);
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
|
| 34 |
+
if (qtype_has_amx_kernels(tensor->type)) {
|
| 35 |
+
ggml_backend_amx_convert_weight(tensor, data, offset, size);
|
| 36 |
+
} else {
|
| 37 |
+
memcpy((char *)tensor->data + offset, data, size);
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
GGML_UNUSED(buffer);
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
|
| 44 |
+
GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
|
| 45 |
+
memcpy(data, (const char *)tensor->data + offset, size);
|
| 46 |
+
|
| 47 |
+
GGML_UNUSED(buffer);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
|
| 51 |
+
if (ggml_backend_buffer_is_host(src->buffer)) {
|
| 52 |
+
if (qtype_has_amx_kernels(src->type)) {
|
| 53 |
+
ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
|
| 54 |
+
} else {
|
| 55 |
+
memcpy(dst->data, src->data, ggml_nbytes(src));
|
| 56 |
+
}
|
| 57 |
+
return true;
|
| 58 |
+
}
|
| 59 |
+
return false;
|
| 60 |
+
|
| 61 |
+
GGML_UNUSED(buffer);
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
|
| 65 |
+
memset(buffer->context, value, buffer->size);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
|
| 69 |
+
/* .free_buffer = */ ggml_backend_amx_buffer_free_buffer,
|
| 70 |
+
/* .get_base = */ ggml_backend_amx_buffer_get_base,
|
| 71 |
+
/* .init_tensor = */ NULL, // no initialization required
|
| 72 |
+
/* .memset_tensor = */ ggml_backend_amx_buffer_memset_tensor,
|
| 73 |
+
/* .set_tensor = */ ggml_backend_amx_buffer_set_tensor,
|
| 74 |
+
/* .get_tensor = */ ggml_backend_amx_buffer_get_tensor,
|
| 75 |
+
/* .cpy_tensor = */ ggml_backend_amx_buffer_cpy_tensor,
|
| 76 |
+
/* .clear = */ ggml_backend_amx_buffer_clear,
|
| 77 |
+
/* .reset = */ NULL,
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
|
| 81 |
+
return "AMX";
|
| 82 |
+
|
| 83 |
+
GGML_UNUSED(buft);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
|
| 87 |
+
void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
|
| 88 |
+
if (data == NULL) {
|
| 89 |
+
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
|
| 90 |
+
return NULL;
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
|
| 97 |
+
return TENSOR_ALIGNMENT;
|
| 98 |
+
|
| 99 |
+
GGML_UNUSED(buft);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
|
| 103 |
+
return ggml_backend_amx_get_alloc_size(tensor);
|
| 104 |
+
|
| 105 |
+
GGML_UNUSED(buft);
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
|
| 109 |
+
return false;
|
| 110 |
+
|
| 111 |
+
GGML_UNUSED(buft);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
|
| 115 |
+
static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
|
| 116 |
+
/* .iface = */ {
|
| 117 |
+
/* .get_name = */ ggml_backend_amx_buffer_type_get_name,
|
| 118 |
+
/* .alloc_buffer = */ ggml_backend_amx_buffer_type_alloc_buffer,
|
| 119 |
+
/* .get_alignment = */ ggml_backend_amx_buffer_type_get_alignment,
|
| 120 |
+
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
|
| 121 |
+
/* .get_alloc_size = */ ggml_backend_amx_buffer_type_get_alloc_size,
|
| 122 |
+
+            /* .is_host          = */ ggml_backend_amx_buffer_type_is_host,
+        },
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_type_amx;
+}
+
+// backend interface
+
+static const char * ggml_backend_amx_name(ggml_backend_t backend) {
+    return "AMX";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_amx_free(ggml_backend_t backend) {
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_backend_amx_mul_mat(ctx, node);
+                break;
+
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+                GGML_ASSERT(false);
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i ggml_backend_amx_i = {
+    /* .get_name           = */ ggml_backend_amx_name,
+    /* .free               = */ ggml_backend_amx_free,
+    /* .set_tensor_async   = */ NULL,
+    /* .get_tensor_async   = */ NULL,
+    /* .cpy_tensor_async   = */ NULL,
+    /* .synchronize        = */ NULL,
+    /* .graph_plan_create  = */ NULL,
+    /* .graph_plan_free    = */ NULL,
+    /* .graph_plan_update  = */ NULL,
+    /* .graph_plan_compute = */ NULL,
+    /* .graph_compute      = */ ggml_backend_amx_graph_compute,
+    /* .event_record       = */ NULL,
+    /* .event_wait         = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_amx_guid() {
+    static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
+    return &guid;
+}
+
+#define ARCH_GET_XCOMP_PERM     0x1022
+#define ARCH_REQ_XCOMP_PERM     0x1023
+#define XFEATURE_XTILECFG       17
+#define XFEATURE_XTILEDATA      18
+
+static bool ggml_amx_init() {
+#if defined(__gnu_linux__)
+    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+        fprintf(stderr, "AMX is not ready to be used!\n");
+        return false;
+    }
+    return true;
+#elif defined(_WIN32)
+    return true;
+#endif
+}
+
+ggml_backend_t ggml_backend_amx_init() {
+
+    // invoke a Linux system call to request access to AMX features
+    ggml_amx_init();
+
+    // backend context
+    ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
+
+    // ggml amx backend
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_amx_guid(),
+        /* .interface = */ ggml_backend_amx_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+    return backend;
+}
+
+bool ggml_backend_is_amx(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
+}
+
+void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_amx(backend_amx));
+
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
+    ctx->n_threads = n_threads;
+}
+
+// device interface
+
+static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
+    return "AMX";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
+    return "Intel Advanced Matrix Extensions";
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_amx_device_get_name(dev);
+    props->description = ggml_backend_amx_device_get_description(dev);
+    props->type        = ggml_backend_amx_device_get_type(dev);
+    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    // `buffer_from_host_ptr` is intended to be used in mmap, when memory layout unchanged
+    props->caps = {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ false,
+        /* .events               = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_amx_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_amx_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+
+    // handle only 2d gemm for now
+    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+    };
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT: {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const enum ggml_type type = src0->type;
+            const int64_t ne0 = op->ne[0];
+
+            bool is_training = src0->grad || src1->grad;
+
+            // amx kernels enables for Q4_0, Q4_1, Q8_0, F16
+            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+
+            bool can_use_amx =
+                is_contiguous_2d(src0) &&       // src0 must be contiguous
+                is_contiguous_2d(src1) &&       // src1 must be contiguous
+                !is_training &&                 // inference only
+                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
+                has_amx_kernels &&              // with amx kernel impls
+                ne0 % (TILE_N * 2) == 0;        // out_features is 32x
+
+            return can_use_amx;
+        }
+        default:
+            return false;
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
+    /* .get_name             = */ ggml_backend_amx_device_get_name,
+    /* .get_description      = */ ggml_backend_amx_device_get_description,
+    /* .get_memory           = */ ggml_backend_amx_device_get_memory,
+    /* .get_type             = */ ggml_backend_amx_device_get_type,
+    /* .get_props            = */ ggml_backend_amx_device_get_props,
+    /* .init_backend         = */ ggml_backend_amx_device_init,
+    /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_amx_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
+    return "AMX";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_amx_device = {
+        /* .iface   = */ ggml_backend_amx_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_amx_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_amx_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
+    /* .get_name         = */ ggml_backend_amx_reg_get_name,
+    /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_amx_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_amx_reg(void) {
+    static struct ggml_backend_reg ggml_backend_amx_reg = {
+        /* .iface   = */ ggml_backend_amx_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_amx_reg;
+}
+
+#else // if defined(__AMX_INT8__)
+
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void) {
+    return nullptr;
+}
+
+bool ggml_backend_is_amx(ggml_backend_t backend) {
+    GGML_UNUSED(backend);
+    return false;
+}
+
+ggml_backend_t ggml_backend_amx_init(void) {
+    fprintf(stderr, "GGML is not compiled with AMX support!\n");
+    return nullptr;
+}
+
+void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+    fprintf(stderr, "GGML is not compiled with AMX support!\n");
+
+    GGML_UNUSED(backend_amx);
+    GGML_UNUSED(n_threads);
+}
+
+ggml_backend_reg_t ggml_backend_amx_reg(void) {
+    return nullptr;
+}
+
+#endif
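
For orientation, here is a minimal sketch of driver code using the AMX entry points defined above. The snippet is illustrative and not part of this commit; it assumes a build where __AMX_INT8__ is defined so the real implementations (not the stubs) are compiled in.

    // illustrative only: initialize the AMX backend directly and size its thread pool
    #include "ggml-amx.h"
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        ggml_backend_t backend = ggml_backend_amx_init();
        if (backend == NULL || !ggml_backend_is_amx(backend)) {
            fprintf(stderr, "AMX backend unavailable\n");
            return 1;
        }
        ggml_backend_amx_set_n_threads(backend, 8); // threads used by ggml_backend_amx_mul_mat
        // ... build a ggml_cgraph and run it via ggml_backend_graph_compute(backend, graph) ...
        ggml_backend_free(backend);
        return 0;
    }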
ggml/src/ggml-amx/mmq.cpp
CHANGED
@@ -496,19 +496,20 @@ inline void from_float(const float * x, char * vy, int64_t k);
 
 template <>
 inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
-
+    // FIXME: using unoptimized reference impl until moved to CPU backend
+    quantize_row_q8_0_ref(x, (block_q8_0 *)vy, k);
 }
 
 template <>
 inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
-
+    quantize_row_q8_1_ref(x, (block_q8_1 *)vy, k);
 }
 
 template <>
 inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
 #if 1
     // TODO: this is reference impl!
-
+    quantize_row_q8_K_ref(x, (block_q8_K *)vy, k);
 #else
     quantize_row_q8_K_vnni(x, vy, k);
 #endif
ggml/src/ggml-backend-reg.cpp
ADDED
@@ -0,0 +1,195 @@
+#include "ggml-backend-impl.h"
+#include "ggml-backend.h"
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include <cstring>
+#include <vector>
+
+// Backend registry
+
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
+#ifdef GGML_USE_AMX
+#  include "ggml-amx.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
+#ifdef GGML_USE_KOMPUTE
+#include "ggml-kompute.h"
+#endif
+
+struct ggml_backend_registry {
+    std::vector<ggml_backend_reg_t> backends;
+    std::vector<ggml_backend_dev_t> devices;
+
+    ggml_backend_registry() {
+#ifdef GGML_USE_CUDA
+        register_backend(ggml_backend_cuda_reg());
+#endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif
+#ifdef GGML_USE_SYCL
+        register_backend(ggml_backend_sycl_reg());
+#endif
+#ifdef GGML_USE_VULKAN
+        register_backend(ggml_backend_vk_reg());
+#endif
+#ifdef GGML_USE_CANN
+        register_backend(ggml_backend_cann_reg());
+#endif
+#ifdef GGML_USE_BLAS
+        register_backend(ggml_backend_blas_reg());
+#endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
+#ifdef GGML_USE_AMX
+        register_backend(ggml_backend_amx_reg());
+#endif
+#ifdef GGML_USE_KOMPUTE
+        register_backend(ggml_backend_kompute_reg());
+#endif
+
+        register_backend(ggml_backend_cpu_reg());
+    }
+
+    void register_backend(ggml_backend_reg_t reg) {
+        if (!reg) {
+            return;
+        }
+
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
+            __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
+#endif
+        backends.push_back(reg);
+        for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
+            register_device(ggml_backend_reg_dev_get(reg, i));
+        }
+    }
+
+    void register_device(ggml_backend_dev_t device) {
+#ifndef NDEBUG
+        GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
+#endif
+        devices.push_back(device);
+    }
+};
+
+static ggml_backend_registry & get_reg() {
+    static ggml_backend_registry reg;
+    return reg;
+}
+
+// Internal API
+void ggml_backend_register(ggml_backend_reg_t reg) {
+    get_reg().register_backend(reg);
+}
+
+void ggml_backend_device_register(ggml_backend_dev_t device) {
+    get_reg().register_device(device);
+}
+
+// Backend (reg) enumeration
+size_t ggml_backend_reg_count() {
+    return get_reg().backends.size();
+}
+
+ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_reg_count());
+    return get_reg().backends[index];
+}
+
+ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
+        ggml_backend_reg_t reg = ggml_backend_reg_get(i);
+        if (std::strcmp(ggml_backend_reg_name(reg), name) == 0) {
+            return reg;
+        }
+    }
+    return NULL;
+}
+
+// Device enumeration
+size_t ggml_backend_dev_count() {
+    return get_reg().devices.size();
+}
+
+ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
+    GGML_ASSERT(index < ggml_backend_dev_count());
+    return get_reg().devices[index];
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
+            return dev;
+        }
+    }
+    return NULL;
+}
+
+ggml_backend_dev_t ggml_backend_dev_by_type(enum ggml_backend_dev_type type) {
+    for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
+        ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+        if (ggml_backend_dev_type(dev) == type) {
+            return dev;
+        }
+    }
+    return NULL;
+}
+
+// Convenience functions
+ggml_backend_t ggml_backend_init_by_name(const char * name, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_name(name);
+    if (!dev) {
+        return NULL;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const char * params) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(type);
+    if (!dev) {
+        return NULL;
+    }
+    return ggml_backend_dev_init(dev, params);
+}
+
+ggml_backend_t ggml_backend_init_best(void) {
+    ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
+    if (!dev) {
+        dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+    }
+    if (!dev) {
+        return NULL;
+    }
+    return ggml_backend_dev_init(dev, NULL);
+}
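
A short sketch of how callers consume the registry above (illustrative, not part of the commit; every function used is either defined in this file or a public accessor from ggml-backend.h):

    // illustrative only: list every registered device, then pick the best backend
    #include "ggml-backend.h"
    #include <stdio.h>

    int main(void) {
        for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
            printf("device %zu: %s (%s)\n", i, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev));
        }

        // GPU if one was compiled in and registered, otherwise the CPU device
        ggml_backend_t backend = ggml_backend_init_best();
        if (backend != NULL) {
            // ... use the backend ...
            ggml_backend_free(backend);
        }
        return 0;
    }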
ggml/src/ggml-blas/CMakeLists.txt
ADDED
@@ -0,0 +1,91 @@
+if (GGML_STATIC)
+    set(BLA_STATIC ON)
+endif()
+#if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.22)
+#    set(BLA_SIZEOF_INTEGER 8)
+#endif()
+
+set(BLA_VENDOR ${GGML_BLAS_VENDOR})
+find_package(BLAS)
+
+if (BLAS_FOUND)
+    message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
+
+    add_library(ggml-blas
+                ggml-blas.cpp
+                )
+
+    target_link_libraries(ggml-blas PRIVATE ggml-base)
+    target_include_directories(ggml-blas PRIVATE . ..)
+
+    if (${GGML_BLAS_VENDOR} MATCHES "Apple")
+        add_compile_definitions(ACCELERATE_NEW_LAPACK)
+        add_compile_definitions(ACCELERATE_LAPACK_ILP64)
+        add_compile_definitions(GGML_BLAS_USE_ACCELERATE)
+    elseif ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
+        # BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
+        # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
+        find_package(PkgConfig REQUIRED)
+        if (${GGML_BLAS_VENDOR} MATCHES "Generic")
+            pkg_check_modules(DepBLAS blas)
+        elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
+            # As of openblas v0.3.22, the 64-bit is named openblas64.pc
+            pkg_check_modules(DepBLAS openblas64)
+            if (NOT DepBLAS_FOUND)
+                pkg_check_modules(DepBLAS openblas)
+            endif()
+        elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
+            add_compile_definitions(GGML_BLAS_USE_BLIS)
+            pkg_check_modules(DepBLAS blis)
+        elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
+            pkg_check_modules(DepBLAS blas-atlas)
+        elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
+            pkg_check_modules(DepBLAS flexiblas_api)
+        elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+            add_compile_definitions(GGML_BLAS_USE_MKL)
+            # all Intel* libraries share the same include path
+            pkg_check_modules(DepBLAS mkl-sdl)
+        elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
+            # this doesn't provide pkg-config
+            # suggest to assign BLAS_INCLUDE_DIRS on your own
+            if ("${NVHPC_VERSION}" STREQUAL "")
+                message(WARNING "Better to set NVHPC_VERSION")
+            else()
+                set(DepBLAS_FOUND ON)
+                set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
+            endif()
+        endif()
+        if (DepBLAS_FOUND)
+            set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
+        else()
+            message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
+                            " detected by pkgconfig, trying to find cblas.h from possible paths...")
+            find_path(BLAS_INCLUDE_DIRS
+                NAMES cblas.h
+                HINTS
+                    /usr/include
+                    /usr/local/include
+                    /usr/include/openblas
+                    /opt/homebrew/opt/openblas/include
+                    /usr/local/opt/openblas/include
+                    /usr/include/x86_64-linux-gnu/openblas/include
+            )
+        endif()
+    endif()
+
+    message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
+
+    #add_compile_options(${BLAS_LINKER_FLAGS})
+    target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS})
+
+    if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel"))
+        add_compile_definitions(GGML_BLAS_USE_MKL)
+    endif()
+
+    target_link_libraries     (ggml-blas PRIVATE ${BLAS_LIBRARIES})
+    target_include_directories(ggml-blas PRIVATE ${BLAS_INCLUDE_DIRS})
+else()
+    message(ERROR "BLAS not found, please refer to "
+                  "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
+                  " to set correct GGML_BLAS_VENDOR")
+endif()
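
Usage note (illustrative, not part of the diff): with this file in place the backend is enabled at configure time roughly as follows, where GGML_BLAS_VENDOR takes the FindBLAS vendor names matched above; the exact option names follow ggml's top-level CMakeLists:

    cmake -B build -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
    cmake --build build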
ggml/src/ggml-blas/ggml-blas.cpp
ADDED
@@ -0,0 +1,514 @@
+#include "ggml-impl.h"
+#include "ggml-blas.h"
+#include "ggml-backend-impl.h"
+
+#include <future>
+#include <vector>
+#include <cstring>
+
+#if defined(GGML_BLAS_USE_ACCELERATE)
+#   include <Accelerate/Accelerate.h>
+#elif defined(GGML_BLAS_USE_MKL)
+#   include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#   include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#   include <nvpl_blas.h>
+#else
+#   include <cblas.h>
+#endif
+
+struct ggml_backend_blas_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char[]> work_data;
+    size_t work_size = 0;
+#ifndef GGML_USE_OPENMP
+    std::vector<std::future<void>> tasks;
+#endif
+};
+
+static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    GGML_ASSERT(ne0 == ne01);
+    GGML_ASSERT(ne1 == ne11);
+    GGML_ASSERT(ne2 == ne12);
+    GGML_ASSERT(ne3 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == ggml_type_size(type));
+    GGML_ASSERT(nb10 == ggml_type_size(src1->type));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const int64_t ne_plane      = ne01*ne00;
+    const size_t  desired_wsize = type == GGML_TYPE_F32 ? 0 : ne03*ne02*ne_plane*sizeof(float);
+
+    if (ctx->work_size < desired_wsize) {
+        ctx->work_data.reset(new char[desired_wsize]);
+        ctx->work_size = desired_wsize;
+    }
+    void * wdata = ctx->work_data.get();
+
+    // convert src0 to float
+    if (type != GGML_TYPE_F32) {
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits->to_float;
+
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const void  *       x      = (char *)  src0->data + i02*nb02 + i03*nb03;
+                float       * const wplane = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+
+                const int min_cols_per_thread = 4096;
+                const int min_rows_per_thread = std::max((int)(min_cols_per_thread/ne00), 1);
+                const int n_threads = std::max(std::min(ctx->n_threads, (int)(ne01/min_rows_per_thread)), 1);
+
+#ifdef GGML_USE_OPENMP
+                #pragma omp parallel for num_threads(n_threads)
+                for (int64_t i01 = 0; i01 < ne01; i01++) {
+                    to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                }
+#else
+                for (int i = 1; i < n_threads; i++) {
+                    const int64_t start =       i*ne01/n_threads;
+                    const int64_t end   = (i + 1)*ne01/n_threads;
+                    if (start < end) {
+                        ctx->tasks.push_back(std::async(std::launch::async, [=]() {
+                            for (int64_t i01 = start; i01 < end; i01++) {
+                                to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                            }
+                        }));
+                    }
+                }
+                {
+                    // reuse the current thread for the first task
+                    const int64_t start = 0;
+                    const int64_t end   = ne01/n_threads;
+                    for (int64_t i01 = start; i01 < end; i01++) {
+                        to_float((const char *) x + i01*nb01, wplane + i01*ne00, ne00);
+                    }
+                }
+#endif
+            }
+        }
+
+#ifndef GGML_USE_OPENMP
+        // wait for all tasks to finish
+        for (auto & task : ctx->tasks) {
+            task.get();
+        }
+        ctx->tasks.clear();
+#endif
+    }
+
+#if defined(OPENBLAS_VERSION)
+    openblas_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_BLIS)
+    bli_thread_set_num_threads(ctx->n_threads);
+#endif
+
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
+    for (int64_t i13 = 0; i13 < ne13; i13++) {
+        for (int64_t i12 = 0; i12 < ne12; i12++) {
+            const int64_t i03 = i13/r3;
+            const int64_t i02 = i12/r2;
+
+            const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
+            const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
+            float       * d = (float *) ((char *)  dst->data + i12*nb2  + i13*nb3);
+
+            if (type != GGML_TYPE_F32) {
+                x = (float *) wdata + i02*ne_plane + i03*ne02*ne_plane;
+            }
+
+            cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne1, ne01, ne10,
+                        1.0f,   y, ne10,
+                                x, ne00,
+                        0.0f,   d, ne01);
+        }
+    }
+}
+
+static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    GGML_ASSERT(ne0  == ne00);
+    GGML_ASSERT(ne1  == ne10);
+    GGML_ASSERT(ne2  == ne02);
+    GGML_ASSERT(ne02 == ne12);
+    GGML_ASSERT(ne3  == ne13);
+    GGML_ASSERT(ne03 == ne13);
+
+    // we don't support permuted src0 or src1
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    // GGML_ASSERT(nb0 <= nb1);
+    // GGML_ASSERT(nb1 <= nb2);
+    // GGML_ASSERT(nb2 <= nb3);
+
+    // Arguments to ggml_compute_forward_out_prod (expressed as major,minor)
+    // src0: (k,n)
+    // src1: (k,m)
+    // dst:  (m,n)
+    //
+    // Arguments to sgemm (see https://github.com/Reference-LAPACK/lapack/blob/master/BLAS/SRC/sgemm.f)
+    // Also expressed as (major,minor)
+    // a: (m,k): so src1 transposed
+    // b: (k,n): so src0
+    // c: (m,n)
+    //
+    // However, if ggml_is_transposed(src1) is true, then
+    // src1->data already contains a transposed version, so sgemm mustn't
+    // transpose it further.
+
+    int n = src0->ne[0];
+    int k = src0->ne[1];
+    int m = src1->ne[0];
+
+    CBLAS_TRANSPOSE transposeA;
+    int lda;
+
+    if (!ggml_is_transposed(src1)) {
+        transposeA = CblasTrans;
+        lda = m;
+    } else {
+        transposeA = CblasNoTrans;
+        lda = k;
+    }
+
+    float * a = (float *) ((char *) src1->data);
+    float * b = (float *) ((char *) src0->data);
+    float * c = (float *) ((char *) dst->data);
+
+    cblas_sgemm(CblasRowMajor, transposeA, CblasNoTrans, m, n, k, 1.0, a, lda, b, n, 0.0, c, n);
+
+    GGML_UNUSED(ctx);
+}
+
+// backend interface
+
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
+    return "BLAS";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_blas_free(ggml_backend_t backend) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+            case GGML_OP_MUL_MAT:
+                ggml_backend_blas_mul_mat(ctx, node);
+                break;
+
+            case GGML_OP_OUT_PROD:
+                ggml_backend_blas_out_prod(ctx, node);
+                break;
+
+            case GGML_OP_NONE:
+            case GGML_OP_RESHAPE:
+            case GGML_OP_VIEW:
+            case GGML_OP_PERMUTE:
+            case GGML_OP_TRANSPOSE:
+                break;
+
+            default:
+                GGML_ABORT("%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i blas_backend_i = {
+    /* .get_name           = */ ggml_backend_blas_get_name,
+    /* .free               = */ ggml_backend_blas_free,
+    /* .set_tensor_async   = */ NULL,
+    /* .get_tensor_async   = */ NULL,
+    /* .cpy_tensor_async   = */ NULL,
+    /* .synchronize        = */ NULL,
+    /* .graph_plan_create  = */ NULL,
+    /* .graph_plan_free    = */ NULL,
+    /* .graph_plan_update  = */ NULL,
+    /* .graph_plan_compute = */ NULL,
+    /* .graph_compute      = */ ggml_backend_blas_graph_compute,
+    /* .event_record       = */ NULL,
+    /* .event_wait         = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_blas_guid(void) {
+    static ggml_guid guid = { 0x12, 0xa8, 0xae, 0xf4, 0xc0, 0x1e, 0x61, 0x97, 0x8f, 0xeb, 0x33, 0x04, 0xa1, 0x33, 0x51, 0x2d };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_blas_init(void) {
+    ggml_backend_blas_context * ctx = new ggml_backend_blas_context;
+
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_blas_guid(),
+        /* .interface = */ blas_backend_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+    if (openblas_get_parallel() != OPENBLAS_OPENMP) {
+        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+    }
+#endif
+
+#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#endif
+
+    return backend;
+}
+
+bool ggml_backend_is_blas(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_blas_guid());
+}
+
+void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_blas(backend_blas));
+
+    ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
+    ctx->n_threads = n_threads;
+}
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+#if defined(GGML_BLAS_USE_ACCELERATE)
+    return "Accelerate";
+#elif defined(GGML_BLAS_USE_MKL)
+    return "MKL";
+#elif defined(GGML_BLAS_USE_BLIS)
+    return "BLIS";
+#elif defined(GGML_BLAS_USE_NVPL)
+    return "NVPL";
+#elif defined(OPENBLAS_VERSION)
+    return "OpenBLAS";
+#else
+    return "BLAS";
+#endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_ACCEL;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                = */ false,
+        /* .host_buffer          = */ false,
+        /* .buffer_from_host_ptr = */ true,
+        /* .events               = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return ggml_is_contiguous(src0) &&
+                   ggml_is_contiguous(src1) &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+        }
+
+        case GGML_OP_OUT_PROD:
+            return op->src[0]->type == GGML_TYPE_F32 &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_is_matrix(src0) &&
+                   ggml_is_matrix(src1) &&
+                   ggml_is_contiguous(src0) &&
+                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .iface   = */ ggml_backend_blas_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
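
The "ggml_backend_set_n_threads" proc address exposed above lets generic code adjust the thread count without linking against the BLAS backend directly. A minimal sketch (illustrative; the function-pointer typedef is local to this example, not an API introduced by this commit):

    // illustrative only: resolve and call the generic n_threads setter
    #include "ggml-backend.h"

    typedef void (*set_n_threads_fn_t)(ggml_backend_t backend, int n_threads);

    static void set_backend_n_threads(ggml_backend_reg_t reg, ggml_backend_t backend, int n_threads) {
        set_n_threads_fn_t fn = (set_n_threads_fn_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (fn != NULL) {
            fn(backend, n_threads); // resolves to ggml_backend_blas_set_n_threads for this backend
        }
    }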
ggml/src/ggml-cann/CMakeLists.txt
ADDED
@@ -0,0 +1,46 @@
+if ("cann${CANN_INSTALL_DIR}" STREQUAL "cann" AND DEFINED ENV{ASCEND_TOOLKIT_HOME})
+    set(CANN_INSTALL_DIR $ENV{ASCEND_TOOLKIT_HOME})
+    message(STATUS "CANN: updated CANN_INSTALL_DIR from ASCEND_TOOLKIT_HOME=$ENV{ASCEND_TOOLKIT_HOME}")
+endif()
+
+if (CANN_INSTALL_DIR)
+    # Only Support Linux.
+    if (NOT UNIX)
+        message(FATAL_ERROR "CANN: CANN toolkit supports unix but not ${CMAKE_SYSTEM_NAME}")
+    endif()
+
+    # Supported platforms: x86-64, arm64
+    if (CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64")
+    elseif (CMAKE_SYSTEM_PROCESSOR STREQUAL "x86_64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "amd64")
+    else()
+        message(FATAL_ERROR "CANN: CANN toolkit supports x86-64 and arm64 but not ${CMAKE_SYSTEM_PROCESSOR}")
+    endif()
+
+    # Set header and libs
+    set(CANN_INCLUDE_DIRS
+        ${CANN_INSTALL_DIR}/include
+        ${CANN_INSTALL_DIR}/include/aclnn
+        ${CANN_INSTALL_DIR}/acllib/include
+    )
+
+    add_subdirectory(kernels)
+    list(APPEND CANN_LIBRARIES
+        ascendcl
+        nnopbase
+        opapi
+        acl_op_compiler
+        ascendc_kernels
+    )
+
+    file(GLOB GGML_SOURCES_CANN "*.cpp")
+
+    add_library(ggml-cann ${GGML_SOURCES_CANN})
+    target_link_libraries(ggml-cann PRIVATE ggml-base ${CANN_LIBRARIES})
+    target_include_directories(ggml-cann PRIVATE . .. ${CANN_INCLUDE_DIRS})
+    target_link_directories(ggml-cann PRIVATE ${CANN_INSTALL_DIR}/lib64)
+
+    message(STATUS "CANN: CANN_INCLUDE_DIRS = ${CANN_INCLUDE_DIRS}")
+    message(STATUS "CANN: CANN_LIBRARIES = ${CANN_LIBRARIES}")
+else()
+    message(FATAL_ERROR "CANN: Can't find CANN_INSTALL_DIR, did you forget to source set_var.sh?")
+endif()
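
Usage note (illustrative, not part of the diff): CANN_INSTALL_DIR can be passed on the configure line or inherited from ASCEND_TOOLKIT_HOME as handled above; assuming ggml's GGML_CANN option, a configure invocation looks roughly like:

    cmake -B build -DGGML_CANN=ON -DCANN_INSTALL_DIR=$ASCEND_TOOLKIT_HOME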
ggml/src/ggml-cann/ggml-cann.cpp
ADDED
@@ -0,0 +1,2128 @@
+(2128 added lines; the file contents are not rendered in this view — see the raw diff)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ggml/src/ggml-cann/ggml-cann.cpp
ADDED

/*
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 */

#include "ggml-cann.h"

#include <acl/acl.h>
#include <stdarg.h>

#include <cmath>
#include <cstdio>
#include <cstring>
#include <mutex>

#include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include "ggml-cann/aclnn_ops.h"
#include "ggml-cann/common.h"

#define GGML_COMMON_DECL_C

#include "ggml-common.h"

#define GGML_CANN_NAME "CANN"

/**
 * @brief Handles CANN errors by printing an error message and aborting.
 *
 * @param stmt The statement that caused the error.
 * @param func The function in which the error occurred.
 * @param file The file in which the error occurred.
 * @param line The line number where the error occurred.
 * @param msg The error message.
 */
[[noreturn]] void ggml_cann_error(const char* stmt, const char* func,
                                  const char* file, int line, const char* msg) {
    int32_t id = -1;
    aclrtGetDevice(&id);

    GGML_LOG_ERROR("CANN error: %s\n", msg);
    GGML_LOG_ERROR("  current device: %d, in function %s at %s:%d\n", id, func,
                   file, line);
    GGML_LOG_ERROR("  %s\n", stmt);
    // abort with GGML_ASSERT to get a stack trace
    GGML_ABORT("CANN error");
}

/**
 * @brief Sets the device to be used by CANN.
 *
 * @param device The device ID to set.
 */
void ggml_cann_set_device(const int32_t device) {
    // TODO: uncomment these lines once the empty-context issue is fixed.
    // int current_device;
    // ACL_CHECK(aclrtGetDevice(&current_device));

    // if (device == current_device) {
    //     return;
    // }
    ACL_CHECK(aclrtSetDevice(device));
}

/**
 * @brief Retrieves the current device ID.
 *
 * @return The current device ID.
 */
int32_t ggml_cann_get_device() {
    int32_t id;
    ACL_CHECK(aclrtGetDevice(&id));
    return id;
}

/**
 * @brief Initialize the CANN device information.
 *
 * This function initializes the CANN device information by obtaining the
 * device count and setting the memory allocation granularity for each device.
 *
 * @return A structure containing the device information.
 */
static ggml_cann_device_info ggml_cann_init() {
    ggml_cann_device_info info = {};

    aclError err = aclrtGetDeviceCount((uint32_t*)&info.device_count);

    if (err != ACL_SUCCESS) {
        GGML_LOG_ERROR("%s: failed to initialize CANN: %s\n",
                       __func__, aclGetRecentErrMsg());
        return info;
    }

    GGML_ASSERT(info.device_count <= GGML_CANN_MAX_DEVICES);

    for (int id = 0; id < info.device_count; ++id) {
        aclrtPhysicalMemProp prop = {};
        prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
        prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
        prop.memAttr = ACL_HBM_MEM_HUGE;
        prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
        prop.location.id = id;
        prop.reserve = 0;
        ACL_CHECK(aclrtMemGetAllocationGranularity(
            &prop, ACL_RT_MEM_ALLOC_GRANULARITY_RECOMMENDED,
            &info.devices[id].vmm_granularity));
    }

    // TODO: add more device info later.
    return info;
}

/**
 * @brief Retrieve the CANN device information.
 *
 * This function returns a reference to a structure containing the CANN device
 * information. The device information is initialized once and reused on
 * subsequent calls.
 *
 * @return A reference to the structure containing the device information.
 */
const ggml_cann_device_info& ggml_cann_info() {
    static ggml_cann_device_info info = ggml_cann_init();
    return info;
}
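
ggml_cann_info() caches the probe: C++ guarantees thread-safe, one-time initialization of a function-local static, so aclrtGetDeviceCount() and the per-device granularity queries run exactly once, on first use. A minimal standalone sketch of the same pattern (illustrative only, not part of this diff; the device count is faked so it runs without a CANN toolkit):

#include <cstdio>

struct device_info {
    int device_count = 0;
};

static device_info init_info() {
    device_info info;
    info.device_count = 2;  // stand-in for the aclrtGetDeviceCount() probe
    return info;
}

// Initialized once, on the first call; later calls return the cached struct.
static const device_info & get_info() {
    static device_info info = init_info();
    return info;
}

int main() {
    printf("devices: %d\n", get_info().device_count);
}
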
//#define DEBUG_CANN_MALLOC
/**
 * @brief A pool of CANN buffers (legacy).
 *
 * This class manages a pool of CANN buffers for a specific device.
 */
struct ggml_cann_pool_leg : public ggml_cann_pool {
    /**
     * @brief The maximum number of buffers in the pool.
     */
    static const int MAX_BUFFERS = 256;

    /**
     * @brief The device ID associated with this buffer pool.
     */
    int device;

    /**
     * @brief Structure representing a CANN buffer.
     */
    struct ggml_cann_buffer {
        void* ptr = nullptr;  ///< Pointer to the buffer memory.
        size_t size = 0;      ///< Size of the buffer.
    };

    /**
     * @brief Array of CANN buffers in the pool.
     */
    ggml_cann_buffer buffer_pool[MAX_BUFFERS] = {};

    /**
     * @brief Total size of all buffers in the pool.
     */
    size_t pool_size = 0;

    /**
     * @brief Constructor to initialize the buffer pool for a specific device.
     *
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_leg(int device) : device(device) {}

    /**
     * @brief Destructor to free all buffers in the pool.
     */
    ~ggml_cann_pool_leg() {
        ggml_cann_set_device(device);
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cann_buffer& b = buffer_pool[i];
            if (b.ptr != nullptr) {
                ACL_CHECK(aclrtFree(b.ptr));
                pool_size -= b.size;
            }
        }
        GGML_ASSERT(pool_size == 0);
    }

    /**
     * @brief Allocate a buffer of the given size.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void* alloc(size_t size, size_t* actual_size) override {
#ifdef DEBUG_CANN_MALLOC
        int nnz = 0;
        size_t max_size = 0;
#endif
        size_t best_diff = 1ull << 36;
        int ibest = -1;
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cann_buffer& b = buffer_pool[i];
            if (b.ptr != nullptr) {
#ifdef DEBUG_CANN_MALLOC
                ++nnz;
                if (b.size > max_size) max_size = b.size;
#endif
                if (b.size >= size) {
                    size_t diff = b.size - size;
                    if (diff < best_diff) {
                        best_diff = diff;
                        ibest = i;
                        if (!best_diff) {
                            void* ptr = b.ptr;
                            *actual_size = b.size;
                            b.ptr = nullptr;
                            b.size = 0;
                            return ptr;
                        }
                    }
                }
            }
        }
        if (ibest >= 0) {
            ggml_cann_buffer& b = buffer_pool[ibest];
            void* ptr = b.ptr;
            *actual_size = b.size;
            b.ptr = nullptr;
            b.size = 0;
            return ptr;
        }
        void* ptr;
        size_t look_ahead_size = (size_t)(1.05 * size);
        look_ahead_size = 256 * ((look_ahead_size + 255) / 256);
        ggml_cann_set_device(device);
        ACL_CHECK(
            aclrtMalloc(&ptr, look_ahead_size, ACL_MEM_MALLOC_HUGE_FIRST));
        *actual_size = look_ahead_size;
        pool_size += look_ahead_size;
#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO(
            "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, "
            "requested %u MB\n",
            __func__, device, nnz, (uint32_t)(max_size / 1024 / 1024),
            (uint32_t)(pool_size / 1024 / 1024),
            (uint32_t)(size / 1024 / 1024));
#endif
        return ptr;
    }

    /**
     * @brief Free a buffer and return it to the pool.
     *
     * @param ptr Pointer to the buffer to free.
     * @param size Size of the buffer to free.
     */
    void free(void* ptr, size_t size) override {
        for (int i = 0; i < MAX_BUFFERS; ++i) {
            ggml_cann_buffer& b = buffer_pool[i];
            if (b.ptr == nullptr) {
                b.ptr = ptr;
                b.size = size;
                return;
            }
        }
        // The memory should always stay buffered: it may still be needed by
        // tasks in flight on the stream.
        // TODO: fix me.
        GGML_ABORT("CANN buffer pool full, increase MAX_BUFFERS\n");
    }
};
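
When the legacy pool misses, alloc() over-provisions by 5% and rounds the request up to a 256-byte multiple, so slightly larger follow-up requests can reuse the freed block instead of triggering another aclrtMalloc. The sizing arithmetic, pulled out as a standalone check (illustrative, not part of the diff):

#include <cstdio>
#include <cstddef>

// Mirrors the look-ahead sizing in ggml_cann_pool_leg::alloc().
static size_t look_ahead(size_t size) {
    size_t s = (size_t)(1.05 * size);  // 5% headroom
    return 256 * ((s + 255) / 256);    // round up to a 256-byte multiple
}

int main() {
    // 1000000 -> 1050000 after headroom -> 1050112 after rounding
    printf("%zu -> %zu\n", (size_t)1000000, look_ahead(1000000));
}
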
/**
 * @brief A pool of CANN buffers with virtual memory.
 *
 * This class manages a pool of CANN buffers with virtual memory for a specific
 * device.
 */
struct ggml_cann_pool_vmm : public ggml_cann_pool {
    /**
     * @brief The maximum size of the virtual memory pool (32 GB).
     */
    static const size_t CANN_POOL_VMM_MAX_SIZE = 1ull << 35;  // 32 GB

    /**
     * @brief The device ID associated with this buffer pool.
     */
    int device;

    /**
     * @brief Pointer to the start of the virtual memory pool.
     */
    void* pool_addr = 0;

    /**
     * @brief Amount of virtual memory used in the pool.
     */
    size_t pool_used = 0;

    /**
     * @brief Total size of the virtual memory pool.
     */
    size_t pool_size = 0;

    /**
     * @brief Allocation granularity for the virtual memory pool.
     */
    size_t granularity;

    /**
     * @brief Handles for the physical memory allocated.
     */
    std::vector<aclrtDrvMemHandle> handles;

    /**
     * @brief Offsets for the mapped memory regions.
     */
    std::vector<void*> map_offsets;

    /**
     * @brief Constructor to initialize the buffer pool with virtual memory for
     * a specific device.
     *
     * @param device The device ID to associate with this buffer pool.
     */
    explicit ggml_cann_pool_vmm(int device)
        : device(device),
          granularity(ggml_cann_info().devices[device].vmm_granularity) {}

    /**
     * @brief Destructor to free all buffers in the virtual memory pool.
     */
    ~ggml_cann_pool_vmm() {
        if (pool_addr != 0) {
            for (auto& offset : map_offsets) {
                ACL_CHECK(aclrtUnmapMem(offset));
            }
            for (auto& handle : handles) {
                ACL_CHECK(aclrtFreePhysical(handle));
            }
            ACL_CHECK(aclrtReleaseMemAddress(pool_addr));
        }
    }

    /**
     * @brief Allocate a buffer of the given size in the virtual memory pool.
     *
     * @param size The size of the buffer to allocate.
     * @param actual_size A pointer to a variable to receive the actual size of
     * the allocated buffer.
     * @return A pointer to the allocated buffer.
     */
    void* alloc(size_t size, size_t* actual_size) override {
        // round up the allocation size to the alignment to ensure that all
        // allocations are aligned for all data types
        const size_t alignment = 128;
        size = alignment * ((size + alignment - 1) / alignment);

        size_t avail = pool_size - pool_used;

        if (size > avail) {
            // round up to the next multiple of the granularity
            size_t reserve_size = size - avail;
            reserve_size =
                granularity * ((reserve_size + granularity - 1) / granularity);

            GGML_ASSERT(pool_size + reserve_size <= CANN_POOL_VMM_MAX_SIZE);

            // allocate more physical memory
            aclrtPhysicalMemProp prop = {};
            prop.handleType = ACL_MEM_HANDLE_TYPE_NONE;
            prop.allocationType = ACL_MEM_ALLOCATION_TYPE_PINNED;
            prop.memAttr = ACL_HBM_MEM_HUGE;
            prop.location.type = ACL_MEM_LOCATION_TYPE_DEVICE;
            prop.location.id = device;
            prop.reserve = 0;
            aclrtDrvMemHandle handle;
            ACL_CHECK(aclrtMallocPhysical(&handle, reserve_size, &prop, 0));

            // reserve virtual address space (if not already reserved)
            if (pool_addr == 0) {
                ACL_CHECK(aclrtReserveMemAddress(
                    &pool_addr, CANN_POOL_VMM_MAX_SIZE, 0, NULL, 1));
            }

            // map at the end of the pool
            ACL_CHECK(aclrtMapMem((char*)pool_addr + pool_size, reserve_size, 0,
                                  handle, 0));

            handles.push_back(handle);
            map_offsets.push_back((char*)pool_addr + pool_size);

            // add to the pool
            pool_size += reserve_size;

            // GGML_LOG_INFO("cann pool[%d]: size increased to %llu MB (
            //               reserved %llu MB)\n",
            //               device, (unsigned long long) (pool_size/1024/1024),
            //               (unsigned long long) (reserve_size/1024/1024));
        }

        GGML_ASSERT(pool_addr != 0);

        void* ptr = (void*)((char*)pool_addr + pool_used);
        *actual_size = size;
        pool_used += size;

#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO("cann pool[%d]: allocated %llu bytes at %llx\n", device,
                      (unsigned long long)size, (unsigned long long)ptr);
#endif
        return ptr;
    }

    /**
     * @brief Free a buffer and return it to the virtual memory pool.
     *
     * @param ptr Pointer to the buffer to free.
     * @param size Size of the buffer to free.
     */
    void free(void* ptr, size_t size) override {
#ifdef DEBUG_CANN_MALLOC
        GGML_LOG_INFO("cann pool[%d]: freed %llu bytes at %llx\n", device,
                      (unsigned long long)size, (unsigned long long)ptr);
#endif

        pool_used -= size;

        // all deallocations must be in reverse order of the allocations
        GGML_ASSERT(ptr == (void*)((char*)pool_addr + pool_used));
    }
};

/**
 * @brief Create a new CANN pool for a specific device.
 *
 * Factory method to create a new CANN pool object based on the device type.
 *
 * @param device The device ID for which to create the pool.
 * @return A unique pointer to the created CANN pool.
 */
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(
    int device) {
    // return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_leg(device));
    return std::unique_ptr<ggml_cann_pool>(new ggml_cann_pool_vmm(device));
}
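
The VMM pool is in effect a bump allocator over one large reserved address range: alloc() advances pool_used and maps more physical memory only when the range runs short, while free() rewinds pool_used, with the assert enforcing strict LIFO order. A simplified host-side model of that bookkeeping (illustrative only; no ACL calls, plain malloc stands in for the reserved range):

#include <cassert>
#include <cstddef>
#include <cstdlib>

struct bump_pool {
    char * base;
    size_t used = 0;

    // Returns the rounded-up size through *actual, as the real pool does.
    void * alloc(size_t size, size_t * actual) {
        *actual = 128 * ((size + 127) / 128);  // 128-byte alignment rounding
        void * p = base + used;
        used += *actual;
        return p;
    }
    void free_(void * p, size_t actual) {
        used -= actual;
        assert(p == base + used);  // deallocations must be LIFO
    }
};

int main() {
    bump_pool pool{(char *)std::malloc(1 << 20)};
    size_t a_sz, b_sz;
    void * a = pool.alloc(1000, &a_sz);
    void * b = pool.alloc(5000, &b_sz);
    pool.free_(b, b_sz);  // freeing out of order would trip the assert
    pool.free_(a, a_sz);
    std::free(pool.base);
}
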
// cann buffer
/**
 * @brief Context for managing a CANN buffer associated with a specific device.
 *
 * This structure holds information about a CANN buffer, including the device
 * ID, device pointer, and a name derived from GGML_CANN_NAME and the device ID.
 */
struct ggml_backend_cann_buffer_context {
    int32_t device;  ///< The device ID associated with this buffer context.
    void* dev_ptr =
        nullptr;  ///< Pointer to the device memory allocated for the buffer.

    /**
     * @brief Constructor to initialize the CANN buffer context.
     *
     * @param device The device ID associated with this buffer context.
     * @param dev_ptr Pointer to the device memory allocated for the buffer.
     */
    ggml_backend_cann_buffer_context(int32_t device, void* dev_ptr)
        : device(device),
          dev_ptr(dev_ptr) {}

    /**
     * @brief Destructor to free the device memory allocated for the buffer.
     */
    ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
};

/**
 * @brief Check if a buffer is a CANN buffer.
 *
 * This function checks whether a given buffer is a CANN buffer by checking
 * whether its buffer type is a CANN buffer type.
 *
 * @param buffer The buffer to check.
 * @return true if the buffer is a CANN buffer, false otherwise.
 */
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
static bool ggml_backend_buffer_is_cann(
    ggml_backend_buffer_t buffer) {
    return ggml_backend_buft_is_cann(buffer->buft);
}

/**
 * @brief Free resources associated with a CANN buffer.
 *
 * This function frees the resources associated with a CANN buffer, including
 * its context.
 *
 * @param buffer The CANN buffer to free.
 */
static void ggml_backend_cann_buffer_free_buffer(
    ggml_backend_buffer_t buffer) {
    ggml_backend_cann_buffer_context* ctx =
        (ggml_backend_cann_buffer_context*)buffer->context;
    delete ctx;
}

/**
 * @brief Retrieve the base pointer of a CANN buffer.
 *
 * This function returns the base pointer of a CANN buffer, which points to the
 * device memory allocated for the buffer.
 *
 * @param buffer The CANN buffer whose base pointer is to be retrieved.
 * @return A pointer to the base of the device memory allocated for the buffer.
 */
static void* ggml_backend_cann_buffer_get_base(
    ggml_backend_buffer_t buffer) {
    ggml_backend_cann_buffer_context* ctx =
        (ggml_backend_cann_buffer_context*)buffer->context;
    return ctx->dev_ptr;
}

/**
 * @brief Transform quantized Q4.0 tensor data into a format suitable for CANN
 * processing.
 *
 * This function transforms quantized Q4.0 tensor data into a format suitable
 * for CANN processing. It extracts quantization values and scales from the
 * source data and prepares them in a format expected by CANN operations.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source data in Q4.0 format.
 * @param dst Pointer to the destination buffer where transformed data will be
 * stored.
 */
static void ggml_backend_cann_transform_q4_0(ggml_tensor* tensor,
                                             const void* src,
                                             void* dst) {

    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t* quant_offset = (uint8_t*)dst;
    uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);

    for (int i = 0; i < groups; i++) {
        const block_q4_0* group =
            (const block_q4_0*)((const char*)src + i * sizeof(block_q4_0));
        *scale_offset = group->d;
        scale_offset++;

        // 0-15
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] & 0x0F);
            (*quant_offset) |= ((group->qs[j + 1] << 4));
            quant_offset++;
        }

        // 16-31
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            (*quant_offset) = (group->qs[j] >> 4);
            (*quant_offset) |= (group->qs[j + 1] & 0xF0);
            quant_offset++;
        }
    }

    // re-encode the unsigned nibbles (which carry an implicit -8 offset) as
    // signed int4 values by XOR-ing 0x8 into each nibble
    for (quant_offset = (uint8_t*)dst;
         quant_offset < (uint8_t*)dst + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
}

/**
 * @brief Transform CANN processed data back into quantized Q4.0 format.
 *
 * This function transforms CANN processed data back into quantized Q4.0 format.
 * It reverses the transformation performed by
 * ggml_backend_cann_transform_q4_0(), converting the data back into its
 * original quantized form.
 *
 * @param tensor Pointer to the tensor information.
 * @param src Pointer to the source buffer containing transformed data.
 * @param dst Pointer to the destination buffer where the Q4.0 formatted data
 * will be stored.
 */
static void ggml_backend_cann_transform_back_q4_0(
    const ggml_tensor* tensor, void* src, void* dst) {

    int64_t n_elems = ggml_nelements(tensor);
    int64_t groups = n_elems / QK4_0;
    size_t quant_bytes = n_elems * sizeof(uint8_t) / 2;

    uint8_t* quant_offset = (uint8_t*)src;
    uint16_t* scale_offset = (uint16_t*)((char*)src + quant_bytes);

    for (; quant_offset < (uint8_t*)src + quant_bytes; quant_offset++) {
        (*quant_offset) ^= 0x88;
    }
    quant_offset = (uint8_t*)src;

    for (int i = 0; i < groups; i++) {
        block_q4_0* group = (block_q4_0*)((char*)dst + i * sizeof(block_q4_0));
        group->d = *scale_offset;
        scale_offset++;

        // 0-15
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] = ((*quant_offset) & 0x0F);
            group->qs[j + 1] = ((*quant_offset) >> 4);
            quant_offset++;
        }

        // 16-31
        for (int j = 0; j < QK4_0 / 2; j += 2) {
            group->qs[j] |= ((*quant_offset) << 4);
            group->qs[j + 1] |= ((*quant_offset) & 0xF0);
            quant_offset++;
        }
    }
}
|
| 639 |
+
/**
|
| 640 |
+
* @brief Transform quantized Q8.0 tensor data into a format suitable for CANN
|
| 641 |
+
* processing.
|
| 642 |
+
*
|
| 643 |
+
* This function transforms quantized Q8.0 tensor data into a format suitable
|
| 644 |
+
* for CANN processing. It extracts quantization values and scales from the
|
| 645 |
+
* source data and prepares them in a format expected by CANN operations.
|
| 646 |
+
*
|
| 647 |
+
* @param tensor Pointer to the tensor information.
|
| 648 |
+
* @param src Pointer to the source data in Q8.0 format.
|
| 649 |
+
* @param dst Pointer to the destination buffer where transformed data will be
|
| 650 |
+
* stored.
|
| 651 |
+
*/
|
| 652 |
+
static void ggml_backend_cann_transform_q8_0(ggml_tensor* tensor,
|
| 653 |
+
const void* src,
|
| 654 |
+
void* dst) {
|
| 655 |
+
int64_t n_elems = ggml_nelements(tensor);
|
| 656 |
+
int64_t groups = n_elems / QK8_0;
|
| 657 |
+
size_t quant_bytes = n_elems * sizeof(uint8_t);
|
| 658 |
+
|
| 659 |
+
uint8_t* quant_offset = (uint8_t*)dst;
|
| 660 |
+
uint16_t* scale_offset = (uint16_t*)((char*)dst + quant_bytes);
|
| 661 |
+
|
| 662 |
+
for (int i = 0; i < groups; i++) {
|
| 663 |
+
const block_q8_0* group =
|
| 664 |
+
(const block_q8_0*)((const char*)src + i * sizeof(block_q8_0));
|
| 665 |
+
*scale_offset = group->d;
|
| 666 |
+
scale_offset++;
|
| 667 |
+
size_t group_quant_size = QK8_0 * sizeof(uint8_t);
|
| 668 |
+
memcpy(quant_offset, group->qs, group_quant_size);
|
| 669 |
+
quant_offset += group_quant_size;
|
| 670 |
+
}
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
/**
|
| 674 |
+
* @brief Transform CANN processed data back into quantized Q8.0 format.
|
| 675 |
+
*
|
| 676 |
+
* This function transforms CANN processed data back into quantized Q8.0 format.
|
| 677 |
+
* It reverses the transformation performed by
|
| 678 |
+
* ggml_backend_cann_transform_q8_0(), converting the data back into its
|
| 679 |
+
* original quantized form.
|
| 680 |
+
*
|
| 681 |
+
* @param tensor Pointer to the tensor information.
|
| 682 |
+
* @param src Pointer to the source buffer containing transformed data.
|
| 683 |
+
* @param dst Pointer to the destination buffer where the Q8.0 formatted data
|
| 684 |
+
* will be stored.
|
| 685 |
+
*/
|
| 686 |
+
static void ggml_backend_cann_transform_back_q8_0(
|
| 687 |
+
const ggml_tensor* tensor, const void* src, void* dst) {
|
| 688 |
+
int64_t n_elems = ggml_nelements(tensor);
|
| 689 |
+
int64_t groups = n_elems / QK8_0;
|
| 690 |
+
size_t quant_bytes = n_elems * sizeof(uint8_t);
|
| 691 |
+
|
| 692 |
+
const uint8_t* quant_offset = (const uint8_t*)src;
|
| 693 |
+
const uint16_t* scale_offset =
|
| 694 |
+
(const uint16_t*)((const char*)src + quant_bytes);
|
| 695 |
+
|
| 696 |
+
for (int i = 0; i < groups; i++) {
|
| 697 |
+
block_q8_0* group = (block_q8_0*)((char*)dst + i * sizeof(block_q8_0));
|
| 698 |
+
group->d = *scale_offset;
|
| 699 |
+
scale_offset++;
|
| 700 |
+
size_t group_quant_size = QK8_0 * sizeof(uint8_t);
|
| 701 |
+
memcpy(group->qs, quant_offset, group_quant_size);
|
| 702 |
+
quant_offset += group_quant_size;
|
| 703 |
+
}
|
| 704 |
+
}
|
| 705 |
+
|
| 706 |
+
/**
|
| 707 |
+
* @brief Transform tensor data based on its type for CANN processing.
|
| 708 |
+
*
|
| 709 |
+
* This function transforms tensor data based on its quantization type for CANN
|
| 710 |
+
* processing. It dispatches the transformation based on the tensor's type to
|
| 711 |
+
* specialized functions handling Q4.0 and Q8.0 formats.
|
| 712 |
+
*
|
| 713 |
+
* @param tensor Pointer to the tensor information.
|
| 714 |
+
* @param src Pointer to the source data to be transformed.
|
| 715 |
+
* @param dst Pointer to the destination buffer where transformed data will be
|
| 716 |
+
* stored.
|
| 717 |
+
*/
|
| 718 |
+
static void ggml_backend_cann_transform(ggml_tensor* tensor,
|
| 719 |
+
const void* src, void* dst) {
|
| 720 |
+
switch (tensor->type) {
|
| 721 |
+
case GGML_TYPE_Q4_0:
|
| 722 |
+
ggml_backend_cann_transform_q4_0(tensor, src, dst);
|
| 723 |
+
break;
|
| 724 |
+
case GGML_TYPE_Q8_0:
|
| 725 |
+
ggml_backend_cann_transform_q8_0(tensor, src, dst);
|
| 726 |
+
break;
|
| 727 |
+
default:
|
| 728 |
+
break;
|
| 729 |
+
}
|
| 730 |
+
}
|
| 731 |
+
|
| 732 |
+
/**
|
| 733 |
+
* @brief Transform CANN processed data back into tensor data based on its type.
|
| 734 |
+
*
|
| 735 |
+
* This function transforms CANN processed data back into tensor data based on
|
| 736 |
+
* its quantization type for Q4.0 and Q8.0 formats. It dispatches the
|
| 737 |
+
* transformation based on the tensor's type to specialized functions.
|
| 738 |
+
*
|
| 739 |
+
* @param tensor Pointer to the tensor information.
|
| 740 |
+
* @param src Pointer to the source data containing CANN processed data.
|
| 741 |
+
* @param dst Pointer to the destination buffer where transformed tensor data
|
| 742 |
+
* will be stored.
|
| 743 |
+
*/
|
| 744 |
+
static void ggml_backend_cann_transform_back(
|
| 745 |
+
const ggml_tensor* tensor, void* src, void* dst) {
|
| 746 |
+
switch (tensor->type) {
|
| 747 |
+
case GGML_TYPE_Q4_0:
|
| 748 |
+
ggml_backend_cann_transform_back_q4_0(tensor, src, dst);
|
| 749 |
+
break;
|
| 750 |
+
case GGML_TYPE_Q8_0:
|
| 751 |
+
ggml_backend_cann_transform_back_q8_0(tensor, src, dst);
|
| 752 |
+
break;
|
| 753 |
+
default:
|
| 754 |
+
break;
|
| 755 |
+
}
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
/**
|
| 759 |
+
* @brief Check if transformation is needed for a given tensor type.
|
| 760 |
+
*
|
| 761 |
+
* This function checks if transformation is needed for a given tensor type
|
| 762 |
+
* to prepare data for CANN processing.
|
| 763 |
+
*
|
| 764 |
+
* @param type The tensor type to check.
|
| 765 |
+
* @return true if transformation is needed, false otherwise.
|
| 766 |
+
*/
|
| 767 |
+
static bool need_transform(ggml_type type) {
|
| 768 |
+
switch (type) {
|
| 769 |
+
case GGML_TYPE_Q4_0:
|
| 770 |
+
case GGML_TYPE_Q8_0:
|
| 771 |
+
return true;
|
| 772 |
+
default:
|
| 773 |
+
return false;
|
| 774 |
+
}
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
/**
|
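
Both transforms write a split layout: the quantized values of every block are packed contiguously at the front of dst, and the fp16 scales follow as one contiguous array, instead of the interleaved block_q4_0/block_q8_0 host layout. need_transform() is the gate telling the buffer set/get paths below when this repacking applies. A quick size check of that layout (standalone, assuming QK4_0 = QK8_0 = 32 as in ggml-common.h):

#include <cstdio>
#include <cstdint>

int main() {
    const int QK4_0 = 32, QK8_0 = 32;  // block sizes from ggml-common.h
    int64_t n = 4096;                  // example element count

    int64_t q4_quant  = n / 2;             // two nibbles per byte
    int64_t q4_scales = (n / QK4_0) * 2;   // one fp16 scale per block
    int64_t q8_quant  = n;                 // one byte per weight
    int64_t q8_scales = (n / QK8_0) * 2;

    printf("q4_0: %lld quant bytes + %lld scale bytes\n",
           (long long)q4_quant, (long long)q4_scales);
    printf("q8_0: %lld quant bytes + %lld scale bytes\n",
           (long long)q8_quant, (long long)q8_scales);
}
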
| 778 |
+
* @brief Initialize a tensor using data from a CANN buffer.
|
| 779 |
+
*
|
| 780 |
+
* This function initializes a tensor using data from a CANN buffer.
|
| 781 |
+
* It handles special cases such as views and quantization.
|
| 782 |
+
*
|
| 783 |
+
* @param buffer The CANN buffer from which to initialize the tensor.
|
| 784 |
+
* @param tensor Pointer to the tensor to be initialized.
|
| 785 |
+
*/
|
| 786 |
+
static void ggml_backend_cann_buffer_init_tensor(
|
| 787 |
+
ggml_backend_buffer_t buffer, ggml_tensor* tensor) {
|
| 788 |
+
if (tensor->view_src != NULL && tensor->view_offs == 0) {
|
| 789 |
+
GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
|
| 790 |
+
return;
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
// TODO: can backend doesn't support quantized yet. Just leave the code
|
| 794 |
+
// here.
|
| 795 |
+
if (ggml_is_quantized(tensor->type)) {
|
| 796 |
+
// Initialize padding to 0 to avoid possible NaN values
|
| 797 |
+
size_t original_size = ggml_nbytes(tensor);
|
| 798 |
+
size_t padded_size =
|
| 799 |
+
ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
|
| 800 |
+
|
| 801 |
+
if (padded_size > original_size && tensor->view_src == nullptr) {
|
| 802 |
+
size_t memset_size = padded_size - original_size;
|
| 803 |
+
ACL_CHECK(aclrtMemset((char*)tensor->data + original_size,
|
| 804 |
+
memset_size, 0, memset_size));
|
| 805 |
+
}
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
// TODO: need handle tensor which has paddings.
|
| 810 |
+
/**
|
| 811 |
+
* @brief Set tensor data in a CANN buffer.
|
| 812 |
+
*
|
| 813 |
+
* This function sets tensor data in a CANN buffer, handling transformations
|
| 814 |
+
* if needed based on the tensor's type.
|
| 815 |
+
*
|
| 816 |
+
* @param buffer The CANN buffer where the tensor data will be set.
|
| 817 |
+
* @param tensor Pointer to the tensor whose data will be set.
|
| 818 |
+
* @param data Pointer to the source data to be copied into the tensor.
|
| 819 |
+
* @param offset Offset in the source data from where to start copying.
|
| 820 |
+
* @param size Size of the data to be copied, in bytes.
|
| 821 |
+
*/
|
| 822 |
+
static void ggml_backend_cann_buffer_set_tensor(
|
| 823 |
+
ggml_backend_buffer_t buffer, ggml_tensor *tensor, const void *data,
|
| 824 |
+
size_t offset, size_t size) {
|
| 825 |
+
ggml_backend_cann_buffer_context *ctx =
|
| 826 |
+
(ggml_backend_cann_buffer_context *)buffer->context;
|
| 827 |
+
|
| 828 |
+
ggml_cann_set_device(ctx->device);
|
| 829 |
+
// TODO: refer to cann(#6017), it use thread's default stream.
|
| 830 |
+
// For acl, synchronous functions use this default stream.
|
| 831 |
+
// Why aclrtSynchronizeDevice?
|
| 832 |
+
|
| 833 |
+
if (!need_transform(tensor->type)) {
|
| 834 |
+
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
|
| 835 |
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
| 836 |
+
} else {
|
| 837 |
+
void *transform_buffer = malloc(size);
|
| 838 |
+
ggml_backend_cann_transform(tensor, data, transform_buffer);
|
| 839 |
+
|
| 840 |
+
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
|
| 841 |
+
transform_buffer, size,
|
| 842 |
+
ACL_MEMCPY_HOST_TO_DEVICE));
|
| 843 |
+
free(transform_buffer);
|
| 844 |
+
}
|
| 845 |
+
}
|
| 846 |
+
|
| 847 |
+
/**
|
| 848 |
+
* @brief Get tensor data from a CANN buffer.
|
| 849 |
+
*
|
| 850 |
+
* This function retrieves tensor data from a CANN buffer, handling
|
| 851 |
+
* transformations if needed based on the tensor's type.
|
| 852 |
+
*
|
| 853 |
+
* @param buffer The CANN buffer from which to retrieve tensor data.
|
| 854 |
+
* @param tensor Pointer to the tensor whose data will be retrieved.
|
| 855 |
+
* @param data Pointer to the destination buffer where the tensor data will be
|
| 856 |
+
* copied.
|
| 857 |
+
* @param offset Offset in the destination buffer where to start copying.
|
| 858 |
+
* @param size Size of the data to be copied, in bytes.
|
| 859 |
+
*/
|
| 860 |
+
static void ggml_backend_cann_buffer_get_tensor(
|
| 861 |
+
ggml_backend_buffer_t buffer, const ggml_tensor* tensor, void* data,
|
| 862 |
+
size_t offset, size_t size) {
|
| 863 |
+
ggml_backend_cann_buffer_context* ctx =
|
| 864 |
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
| 865 |
+
|
| 866 |
+
ggml_cann_set_device(ctx->device);
|
| 867 |
+
|
| 868 |
+
if (!need_transform(tensor->type)) {
|
| 869 |
+
ACL_CHECK(aclrtMemcpy(data, size, (char*)tensor->data + offset, size,
|
| 870 |
+
ACL_MEMCPY_DEVICE_TO_HOST));
|
| 871 |
+
} else {
|
| 872 |
+
void* transform_buffer = malloc(size);
|
| 873 |
+
ACL_CHECK(aclrtMemcpy(transform_buffer, size,
|
| 874 |
+
(char*)tensor->data + offset, size,
|
| 875 |
+
ACL_MEMCPY_DEVICE_TO_HOST));
|
| 876 |
+
ggml_backend_cann_transform_back(tensor, transform_buffer, data);
|
| 877 |
+
free(transform_buffer);
|
| 878 |
+
}
|
| 879 |
+
}
|
| 880 |
+
|
| 881 |
+
/**
|
| 882 |
+
* @brief Copy tensor data between CANN buffers if possible.
|
| 883 |
+
*
|
| 884 |
+
* This function copies tensor data between CANN buffers if the source and
|
| 885 |
+
* destination buffers are CANN buffers and they meet the necessary conditions
|
| 886 |
+
* (same device or devices can access each other).
|
| 887 |
+
*
|
| 888 |
+
* @param buffer The destination CANN buffer where the tensor data will be
|
| 889 |
+
* copied.
|
| 890 |
+
* @param src Pointer to the source tensor whose data will be copied.
|
| 891 |
+
* @param dst Pointer to the destination tensor where the data will be copied.
|
| 892 |
+
* @return true if the copy operation succeeded, false otherwise.
|
| 893 |
+
*/
|
| 894 |
+
static bool ggml_backend_cann_buffer_cpy_tensor(
|
| 895 |
+
ggml_backend_buffer_t buffer, const ggml_tensor* src, ggml_tensor* dst) {
|
| 896 |
+
if (ggml_backend_buffer_is_cann(src->buffer)) {
|
| 897 |
+
ggml_backend_cann_buffer_context* src_ctx =
|
| 898 |
+
(ggml_backend_cann_buffer_context*)src->buffer->context;
|
| 899 |
+
ggml_backend_cann_buffer_context* dst_ctx =
|
| 900 |
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
| 901 |
+
|
| 902 |
+
size_t memcpy_size = ggml_nbytes(src);
|
| 903 |
+
// Same device.
|
| 904 |
+
if (src_ctx->device == dst_ctx->device) {
|
| 905 |
+
ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
|
| 906 |
+
(const char*)src->data, memcpy_size,
|
| 907 |
+
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
| 908 |
+
return true;
|
| 909 |
+
} else {
|
| 910 |
+
// Different device but can access by peer.
|
| 911 |
+
int32_t canAccessPeer = 0;
|
| 912 |
+
ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
|
| 913 |
+
dst_ctx->device));
|
| 914 |
+
if (canAccessPeer) {
|
| 915 |
+
ggml_cann_set_device(src_ctx->device);
|
| 916 |
+
ACL_CHECK(aclrtDeviceEnablePeerAccess(dst_ctx->device, 0));
|
| 917 |
+
ACL_CHECK(aclrtMemcpy((char*)dst->data, memcpy_size,
|
| 918 |
+
(const char*)src->data, memcpy_size,
|
| 919 |
+
ACL_MEMCPY_DEVICE_TO_DEVICE));
|
| 920 |
+
return true;
|
| 921 |
+
}
|
| 922 |
+
}
|
| 923 |
+
}
|
| 924 |
+
return false;
|
| 925 |
+
}
|
| 926 |
+
|
| 927 |
+
/**
|
| 928 |
+
* @brief Clear a CANN buffer by setting all its memory to a specified value.
|
| 929 |
+
*
|
| 930 |
+
* This function clears a CANN buffer by setting all its memory to a specified
|
| 931 |
+
* value.
|
| 932 |
+
*
|
| 933 |
+
* @param buffer The CANN buffer to be cleared.
|
| 934 |
+
* @param value The value to which each byte in the buffer will be set.
|
| 935 |
+
*/
|
| 936 |
+
static void ggml_backend_cann_buffer_clear(
|
| 937 |
+
ggml_backend_buffer_t buffer, uint8_t value) {
|
| 938 |
+
ggml_backend_cann_buffer_context* ctx =
|
| 939 |
+
(ggml_backend_cann_buffer_context*)buffer->context;
|
| 940 |
+
|
| 941 |
+
ggml_cann_set_device(ctx->device);
|
| 942 |
+
ACL_CHECK(aclrtMemset(ctx->dev_ptr, buffer->size, value, buffer->size));
|
| 943 |
+
}
|
| 944 |
+
|
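
cpy_tensor above implements a three-way decision: same-device copies always go through, cross-device copies require peer access, and anything else returns false so ggml's generic machinery can stage the copy through the host instead. The ladder, modeled as a standalone function (illustrative, not part of the diff):

#include <cstdio>

enum copy_path { COPY_SAME_DEVICE, COPY_PEER, COPY_UNSUPPORTED };

// Mirrors the decision ladder in ggml_backend_cann_buffer_cpy_tensor().
static copy_path choose_path(int src_dev, int dst_dev, bool can_access_peer) {
    if (src_dev == dst_dev) return COPY_SAME_DEVICE;
    if (can_access_peer)    return COPY_PEER;
    return COPY_UNSUPPORTED;  // the real function returns false here
}

int main() {
    printf("%d %d %d\n",
           choose_path(0, 0, false),   // 0: direct device-to-device copy
           choose_path(0, 1, true),    // 1: peer-to-peer copy
           choose_path(0, 1, false));  // 2: caller falls back to a host-staged copy
}
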
/**
 * @brief Interface for a CANN buffer in the backend.
 *
 * This structure defines function pointers to operations that can be performed
 * on a CANN buffer within the backend.
 */
static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
    /* .free_buffer   = */ ggml_backend_cann_buffer_free_buffer,
    /* .get_base      = */ ggml_backend_cann_buffer_get_base,
    /* .init_tensor   = */ ggml_backend_cann_buffer_init_tensor,
    /* .memset_tensor = */ NULL,
    /* .set_tensor    = */ ggml_backend_cann_buffer_set_tensor,
    /* .get_tensor    = */ ggml_backend_cann_buffer_get_tensor,
    /* .cpy_tensor    = */ ggml_backend_cann_buffer_cpy_tensor,
    /* .clear         = */ ggml_backend_cann_buffer_clear,
    /* .reset         = */ NULL,
};

// cann buffer type
/**
 * @brief Structure representing context information for a specific backend
 * buffer type.
 */
struct ggml_backend_cann_buffer_type_context {
    int32_t device;   /**< Device identifier associated with the buffer context. */
    std::string name; /**< Name associated with the buffer context. */
};

/**
 * @brief Retrieves the name associated with a CANN buffer type.
 *
 * This function returns the descriptive name associated with the specified
 * CANN buffer type context.
 *
 * @param buft Pointer to the buffer type context.
 * @return Const pointer to the C-style string containing the name.
 */
static const char* ggml_backend_cann_buffer_type_name(
    ggml_backend_buffer_type_t buft) {
    ggml_backend_cann_buffer_type_context* buft_ctx =
        (ggml_backend_cann_buffer_type_context*)buft->context;

    return buft_ctx->name.c_str();
}

/**
 * @brief Allocates a new CANN buffer of the specified type and size.
 *
 * This function allocates a new CANN buffer on the specified device with the
 * given size.
 *
 * @param buft Pointer to the buffer type context.
 * @param size Size in bytes of the buffer to allocate.
 * @return Pointer to the allocated buffer, or nullptr if allocation fails.
 */
static ggml_backend_buffer_t
ggml_backend_cann_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
                                           size_t size) {
    ggml_backend_cann_buffer_type_context* buft_ctx =
        (ggml_backend_cann_buffer_type_context*)buft->context;

    ggml_cann_set_device(buft_ctx->device);

    size = std::max(size, (size_t)1);

    void* dev_ptr;
    aclError err = aclrtMalloc(&dev_ptr, size, ACL_MEM_MALLOC_HUGE_FIRST);
    if (err != ACL_SUCCESS) {
        GGML_LOG_ERROR(
            "%s: allocating %.2f MiB on device %d: aclrtMalloc failed: %s\n",
            __func__, size / 1024.0 / 1024.0, buft_ctx->device,
            aclGetRecentErrMsg());
        return nullptr;
    }

    ggml_backend_cann_buffer_context* ctx =
        new ggml_backend_cann_buffer_context(buft_ctx->device, dev_ptr);

    return ggml_backend_buffer_init(buft, ggml_backend_cann_buffer_interface,
                                    ctx, size);
}

/**
 * @brief Retrieves the memory alignment requirement for CANN buffers of this
 * type.
 *
 * This function returns the alignment requirement in bytes for memory allocated
 * by the CANN buffer type.
 *
 * @param buft Pointer to the buffer type context (unused in this
 * implementation).
 * @return The alignment requirement in bytes (fixed at 128 bytes for CANN
 * buffers).
 */
static size_t ggml_backend_cann_buffer_type_get_alignment(
    ggml_backend_buffer_type_t buft) {
    return 128;

    GGML_UNUSED(buft);
}

/**
 * @brief Calculates the allocation size required for a tensor in a CANN buffer.
 *
 * Computes the total allocation size needed for storing the tensor's data in a
 * CANN buffer, considering any necessary padding or adjustments for quantized
 * types.
 *
 * @param buft Pointer to the buffer type context (unused in this
 * implementation).
 * @param tensor Pointer to the tensor for which the allocation size is
 * calculated.
 * @return The total allocation size in bytes required for the tensor in the
 * CANN buffer.
 */
static size_t ggml_backend_cann_buffer_type_get_alloc_size(
    ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
    size_t size = ggml_nbytes(tensor);
    int64_t ne0 = tensor->ne[0];

    // The last line must be bigger than 32 bytes, because every single op
    // deals with at least 32 bytes.
    // TODO: quantized types?
    // int64_t line_size = ne0 * ggml_element_size(tensor);
    // int64_t line_size_align_32 = (line_size + 31) & ~31;
    // size += (line_size_align_32 - line_size);

    // TODO: quantized types are not supported yet.
    // TODO: consider non-contiguous tensors.
    if (ggml_is_quantized(tensor->type)) {
        if (ne0 % MATRIX_ROW_PADDING != 0) {
            size += ggml_row_size(
                tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
        }
    }

    return size;

    GGML_UNUSED(buft);
}

static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
    return false;

    GGML_UNUSED(buft);
}

/**
 * @brief Interface for managing CANN buffer types in the GGML backend.
 *
 * Provides function pointers for allocating, querying properties, and managing
 * memory for CANN buffer types in the GGML backend.
 */
static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
    /* .get_name       = */ ggml_backend_cann_buffer_type_name,
    /* .alloc_buffer   = */ ggml_backend_cann_buffer_type_alloc_buffer,
    /* .get_alignment  = */ ggml_backend_cann_buffer_type_get_alignment,
    /* .get_max_size   = */ NULL,  // defaults to SIZE_MAX
    /* .get_alloc_size = */ ggml_backend_cann_buffer_type_get_alloc_size,
    /* .is_host        = */ ggml_backend_cann_buffer_type_is_host,
};

/**
 * @brief Retrieves the CANN buffer type for a specified device.
 *
 * This function initializes and returns the buffer type interface associated
 * with the given device. It ensures thread-safe access using a mutex.
 *
 * @param device The device index for which to retrieve the buffer type.
 * @return A pointer to the buffer type interface for the specified device, or
 * nullptr if the device index is out of range.
 */
ggml_backend_buffer_type_t
ggml_backend_cann_buffer_type(int32_t device) {
    static std::mutex mutex;
    std::lock_guard<std::mutex> lock(mutex);

    if (device >= ggml_backend_cann_get_device_count()) {
        return nullptr;
    }

    static ggml_backend_buffer_type
        ggml_backend_cann_buffer_types[GGML_CANN_MAX_DEVICES];

    static bool ggml_backend_cann_buffer_type_initialized = false;

    if (!ggml_backend_cann_buffer_type_initialized) {
        for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
            ggml_backend_cann_buffer_types[i] = {
                /* .iface = */ ggml_backend_cann_buffer_type_interface,
                /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), i),
                /* .context = */
                new ggml_backend_cann_buffer_type_context{
                    i, "CANN" + std::to_string(i)},
            };
        }
        ggml_backend_cann_buffer_type_initialized = true;
    }

    return &ggml_backend_cann_buffer_types[device];
}
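
get_alloc_size above pads quantized rows up to the next multiple of MATRIX_ROW_PADDING so that device kernels can read whole padded rows, and init_tensor earlier zero-fills exactly that tail to keep NaNs out of the padding. Illustrative arithmetic (standalone; MATRIX_ROW_PADDING's actual value comes from the backend headers, 512 is assumed here):

#include <cstdio>
#include <cstdint>

int main() {
    const int64_t MATRIX_ROW_PADDING = 512;  // assumed value, see ggml-cann/common.h
    int64_t ne0 = 1100;                      // example row length

    // Extra elements needed to reach the next multiple of the padding.
    int64_t pad = ne0 % MATRIX_ROW_PADDING
                ? MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING
                : 0;
    printf("ne0 = %lld -> pad by %lld elements (row rounded up to %lld)\n",
           (long long)ne0, (long long)pad, (long long)(ne0 + pad));
}
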
| 1148 |
+
/**
|
| 1149 |
+
* @brief Retrieves the name associated with a CANN host buffer type.
|
| 1150 |
+
*
|
| 1151 |
+
* This function returns the descriptive name associated with the specified
|
| 1152 |
+
* CANN host buffer type context.
|
| 1153 |
+
*
|
| 1154 |
+
* @param buft Pointer to the host buffer type context.
|
| 1155 |
+
* @return Const pointer to the C-style string containing the name.
|
| 1156 |
+
*/
|
| 1157 |
+
static const char * ggml_backend_cann_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
|
| 1158 |
+
return "CANN_Host";
|
| 1159 |
+
|
| 1160 |
+
GGML_UNUSED(buft);
|
| 1161 |
+
}
|
| 1162 |
+
|
| 1163 |
+
/**
|
| 1164 |
+
* @brief Retrieves the name associated with a CANN host buffer.
|
| 1165 |
+
*
|
| 1166 |
+
* This function returns the descriptive name associated with the specified
|
| 1167 |
+
* CANN host buffer context.
|
| 1168 |
+
*
|
| 1169 |
+
* @param buft Pointer to the host buffer context.
|
| 1170 |
+
* @return Const pointer to the C-style string containing the name.
|
| 1171 |
+
*/
|
| 1172 |
+
static const char * ggml_backend_cann_host_buffer_name(ggml_backend_buffer_t buffer) {
|
| 1173 |
+
return "CANN_Host";
|
| 1174 |
+
|
| 1175 |
+
GGML_UNUSED(buffer);
|
| 1176 |
+
}
|
| 1177 |
+
|
| 1178 |
+
/**
|
| 1179 |
+
 * @brief Free resources associated with a CANN host buffer.
 *
 * This function frees the resources associated with a CANN host buffer, including
 * its context.
 *
 * @param buffer The CANN host buffer to free.
 */
static void ggml_backend_cann_host_buffer_free(ggml_backend_buffer_t buffer) {
    ACL_CHECK(aclrtFreeHost(buffer->context));
}

/**
 * @brief Allocates a new CANN host buffer of the specified size.
 *
 * This function allocates a new CANN host buffer with the given size.
 *
 * @param size Size in bytes of the host buffer to allocate.
 * @return Pointer to the allocated host buffer, or nullptr if allocation fails.
 */
static void * ggml_cann_host_malloc(size_t size) {
    if (getenv("GGML_CANN_NO_PINNED") != nullptr) {
        return nullptr;
    }

    void * hostPtr = nullptr;
    aclError err = aclrtMallocHost((void **) &hostPtr, size);
    if (err != ACL_SUCCESS) {
        GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                      size / 1024.0 / 1024.0, aclGetRecentErrMsg());
        return nullptr;
    }
    return hostPtr;
}

/**
 * @brief Allocates a new CANN host buffer of the specified type and size.
 *
 * @param buft Pointer to the host buffer type context.
 * @param size Size in bytes of the host buffer to allocate.
 * @return Pointer to the allocated host buffer, or CPU buffer pointer if allocation fails.
 */
static ggml_backend_buffer_t ggml_backend_cann_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
    void * hostPtr = ggml_cann_host_malloc(size);

    if (hostPtr == nullptr) {
        // fallback to cpu buffer
        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
    }

    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(hostPtr, size);
    buffer->buft = buft;
    buffer->iface.free_buffer = ggml_backend_cann_host_buffer_free;

    return buffer;
}

/**
 * @brief Interface for managing CANN host buffer types in the GGML backend.
 *
 * Provides function pointers for allocating, querying properties, and managing
 * memory for CANN buffer types in the GGML backend.
 */
ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
    static struct ggml_backend_buffer_type ggml_backend_cann_buffer_type_host = {
        /* .iface    = */ {
            /* .get_name       = */ ggml_backend_cann_host_buffer_type_name,
            /* .alloc_buffer   = */ ggml_backend_cann_host_buffer_type_alloc_buffer,
            /* .get_alignment  = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
            /* .get_max_size   = */ NULL, // defaults to SIZE_MAX
            /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
            /* .is_host        = */ ggml_backend_cpu_buffer_type()->iface.is_host,
        },
        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
        /* .context  = */ nullptr,
    };

    return &ggml_backend_cann_buffer_type_host;
}
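As a quick orientation, a minimal usage sketch of the pinned host buffer type above (not part of the patch; the size is illustrative, and error handling is elided). If pinned allocation fails internally, the implementation transparently falls back to a plain CPU buffer:

// Hedged sketch: allocate and free a pinned host buffer through the new
// CANN host buffer type, using only APIs visible in this patch.
#include "ggml-backend.h"
#include "ggml-cann.h"

static void host_buffer_demo(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_cann_host_buffer_type();
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16 * 1024 * 1024);
    // ... stage tensor data in buf for fast host<->device transfers ...
    ggml_backend_buffer_free(buf);
}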
|
| 1258 |
+
/**
|
| 1259 |
+
* @brief Computes the forward operation for a given tensor using CANN
|
| 1260 |
+
* operations.
|
| 1261 |
+
*
|
| 1262 |
+
* This function selects the appropriate CANN operation based on the type of
|
| 1263 |
+
* operation specified in the tensor and performs the computation.
|
| 1264 |
+
*
|
| 1265 |
+
* @param ctx The CANN context containing necessary resources and
|
| 1266 |
+
* configurations.
|
| 1267 |
+
* @param dst The destination tensor where the result of the computation will be
|
| 1268 |
+
* stored.
|
| 1269 |
+
* @return true if the computation was successful; false otherwise.
|
| 1270 |
+
*/
|
| 1271 |
+
static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx,
|
| 1272 |
+
struct ggml_tensor* dst) {
|
| 1273 |
+
switch (dst->op) {
|
| 1274 |
+
case GGML_OP_REPEAT:
|
| 1275 |
+
ggml_cann_repeat(ctx, dst);
|
| 1276 |
+
break;
|
| 1277 |
+
case GGML_OP_GET_ROWS:
|
| 1278 |
+
ggml_cann_get_rows(ctx, dst);
|
| 1279 |
+
break;
|
| 1280 |
+
case GGML_OP_DUP:
|
| 1281 |
+
ggml_cann_dup(ctx, dst);
|
| 1282 |
+
break;
|
| 1283 |
+
case GGML_OP_ADD:
|
| 1284 |
+
ggml_cann_add(ctx, dst);
|
| 1285 |
+
break;
|
| 1286 |
+
case GGML_OP_ACC:
|
| 1287 |
+
ggml_cann_acc(ctx, dst);
|
| 1288 |
+
break;
|
| 1289 |
+
case GGML_OP_MUL:
|
| 1290 |
+
ggml_cann_mul_div<aclnnMulGetWorkspaceSize, aclnnMul>(ctx, dst);
|
| 1291 |
+
break;
|
| 1292 |
+
case GGML_OP_DIV:
|
| 1293 |
+
ggml_cann_mul_div<aclnnDivGetWorkspaceSize, aclnnDiv>(ctx, dst);
|
| 1294 |
+
break;
|
| 1295 |
+
case GGML_OP_UNARY:
|
| 1296 |
+
switch (ggml_get_unary_op(dst)) {
|
| 1297 |
+
case GGML_UNARY_OP_GELU:
|
| 1298 |
+
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
|
| 1299 |
+
ctx, dst);
|
| 1300 |
+
break;
|
| 1301 |
+
case GGML_UNARY_OP_SILU:
|
| 1302 |
+
ggml_cann_activation<aclnnSiluGetWorkspaceSize, aclnnSilu>(
|
| 1303 |
+
ctx, dst);
|
| 1304 |
+
break;
|
| 1305 |
+
// TODO: Use faster gelu??
|
| 1306 |
+
case GGML_UNARY_OP_GELU_QUICK:
|
| 1307 |
+
ggml_cann_activation<aclnnGeluGetWorkspaceSize, aclnnGelu>(
|
| 1308 |
+
ctx, dst);
|
| 1309 |
+
break;
|
| 1310 |
+
case GGML_UNARY_OP_TANH:
|
| 1311 |
+
ggml_cann_activation<aclnnTanhGetWorkspaceSize, aclnnTanh>(
|
| 1312 |
+
ctx, dst);
|
| 1313 |
+
break;
|
| 1314 |
+
case GGML_UNARY_OP_RELU:
|
| 1315 |
+
ggml_cann_activation<aclnnReluGetWorkspaceSize, aclnnRelu>(
|
| 1316 |
+
ctx, dst);
|
| 1317 |
+
break;
|
| 1318 |
+
case GGML_UNARY_OP_HARDSIGMOID:
|
| 1319 |
+
ggml_cann_activation<aclnnHardsigmoidGetWorkspaceSize,
|
| 1320 |
+
aclnnHardsigmoid>(ctx, dst);
|
| 1321 |
+
break;
|
| 1322 |
+
case GGML_UNARY_OP_HARDSWISH:
|
| 1323 |
+
ggml_cann_activation<aclnnHardswishGetWorkspaceSize,
|
| 1324 |
+
aclnnHardswish>(ctx, dst);
|
| 1325 |
+
break;
|
| 1326 |
+
default:
|
| 1327 |
+
return false;
|
| 1328 |
+
}
|
| 1329 |
+
break;
|
| 1330 |
+
case GGML_OP_NORM:
|
| 1331 |
+
ggml_cann_norm(ctx, dst);
|
| 1332 |
+
break;
|
| 1333 |
+
case GGML_OP_GROUP_NORM:
|
| 1334 |
+
ggml_cann_group_norm(ctx, dst);
|
| 1335 |
+
break;
|
| 1336 |
+
case GGML_OP_CONCAT:
|
| 1337 |
+
ggml_cann_concat(ctx, dst);
|
| 1338 |
+
break;
|
| 1339 |
+
case GGML_OP_UPSCALE:
|
| 1340 |
+
ggml_cann_upsample_nearest2d(ctx, dst);
|
| 1341 |
+
break;
|
| 1342 |
+
case GGML_OP_PAD:
|
| 1343 |
+
ggml_cann_pad(ctx, dst);
|
| 1344 |
+
break;
|
| 1345 |
+
case GGML_OP_ARANGE:
|
| 1346 |
+
ggml_cann_arange(ctx, dst);
|
| 1347 |
+
break;
|
| 1348 |
+
case GGML_OP_TIMESTEP_EMBEDDING:
|
| 1349 |
+
ggml_cann_timestep_embedding(ctx, dst);
|
| 1350 |
+
break;
|
| 1351 |
+
case GGML_OP_LEAKY_RELU:
|
| 1352 |
+
ggml_cann_leaky_relu(ctx, dst);
|
| 1353 |
+
break;
|
| 1354 |
+
case GGML_OP_RMS_NORM:
|
| 1355 |
+
ggml_cann_rms_norm(ctx, dst);
|
| 1356 |
+
break;
|
| 1357 |
+
case GGML_OP_MUL_MAT:
|
| 1358 |
+
ggml_cann_mul_mat(ctx, dst);
|
| 1359 |
+
break;
|
| 1360 |
+
case GGML_OP_MUL_MAT_ID:
|
| 1361 |
+
return false;
|
| 1362 |
+
case GGML_OP_SCALE:
|
| 1363 |
+
ggml_cann_scale(ctx, dst);
|
| 1364 |
+
break;
|
| 1365 |
+
case GGML_OP_SQR:
|
| 1366 |
+
ggml_cann_sqr(ctx, dst);
|
| 1367 |
+
break;
|
| 1368 |
+
case GGML_OP_CLAMP:
|
| 1369 |
+
ggml_cann_clamp(ctx, dst);
|
| 1370 |
+
break;
|
| 1371 |
+
case GGML_OP_CPY:
|
| 1372 |
+
ggml_cann_cpy(ctx, dst);
|
| 1373 |
+
break;
|
| 1374 |
+
case GGML_OP_CONT:
|
| 1375 |
+
ggml_cann_dup(ctx, dst);
|
| 1376 |
+
break;
|
| 1377 |
+
case GGML_OP_NONE:
|
| 1378 |
+
case GGML_OP_RESHAPE:
|
| 1379 |
+
case GGML_OP_VIEW:
|
| 1380 |
+
case GGML_OP_PERMUTE:
|
| 1381 |
+
case GGML_OP_TRANSPOSE:
|
| 1382 |
+
break;
|
| 1383 |
+
case GGML_OP_DIAG_MASK_INF:
|
| 1384 |
+
ggml_cann_diag_mask(ctx, dst, -INFINITY);
|
| 1385 |
+
break;
|
| 1386 |
+
case GGML_OP_SOFT_MAX:
|
| 1387 |
+
ggml_cann_softmax(ctx, dst);
|
| 1388 |
+
break;
|
| 1389 |
+
case GGML_OP_ROPE:
|
| 1390 |
+
ggml_cann_rope(ctx, dst);
|
| 1391 |
+
break;
|
| 1392 |
+
case GGML_OP_IM2COL:
|
| 1393 |
+
ggml_cann_im2col(ctx, dst);
|
| 1394 |
+
break;
|
| 1395 |
+
case GGML_OP_POOL_2D:
|
| 1396 |
+
ggml_cann_pool2d(ctx, dst);
|
| 1397 |
+
break;
|
| 1398 |
+
case GGML_OP_SUM_ROWS:
|
| 1399 |
+
ggml_cann_sum_rows(ctx, dst);
|
| 1400 |
+
break;
|
| 1401 |
+
case GGML_OP_ARGSORT:
|
| 1402 |
+
ggml_cann_argsort(ctx, dst);
|
| 1403 |
+
break;
|
| 1404 |
+
default:
|
| 1405 |
+
return false;
|
| 1406 |
+
}
|
| 1407 |
+
|
| 1408 |
+
return true;
|
| 1409 |
+
}
|
| 1410 |
+
// backend
/**
 * @brief Retrieves the name associated with the CANN backend.
 *
 * This function returns the name assigned to the CANN backend, which is stored
 * in the context of the provided backend structure.
 *
 * @param backend Pointer to the CANN backend structure.
 * @return A pointer to a constant string representing the backend name.
 */
static const char* ggml_backend_cann_name(ggml_backend_t backend) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;

    return cann_ctx->name.c_str();
}

/**
 * @brief Frees resources associated with the CANN backend.
 *
 * This function releases resources associated with the CANN backend context
 * and resets the device associated with the backend to its initial state.
 *
 * @param backend Pointer to the CANN backend structure to be freed.
 */
static void ggml_backend_cann_free(ggml_backend_t backend) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;
    ACL_CHECK(aclrtSynchronizeDevice());
    ACL_CHECK(aclrtResetDevice(cann_ctx->device));

    // finalize when last backend freed.
    if (cann_ctx->device == ggml_backend_cann_get_device_count() - 1) {
        ACL_CHECK(aclFinalize());
    }

    delete cann_ctx;
    delete backend;
}

/**
 * @brief Sets tensor data asynchronously in the CANN backend.
 *
 * This function asynchronously sets tensor data in the CANN backend. Depending
 * on the tensor type, it may perform data transformations before copying data
 * to the device.
 *
 * @param backend Pointer to the CANN backend structure.
 * @param tensor Pointer to the tensor structure to set data for.
 * @param data Pointer to the host data to copy to the tensor.
 * @param offset Offset in bytes within the host data.
 * @param size Size of the data to copy in bytes.
 */
static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
                                               ggml_tensor *tensor,
                                               const void *data,
                                               size_t offset,
                                               size_t size) {
    ggml_backend_cann_context *cann_ctx =
        (ggml_backend_cann_context *)backend->context;

    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpyAsync((char *)tensor->data + offset, size, data,
                                   size, ACL_MEMCPY_HOST_TO_DEVICE,
                                   cann_ctx->stream()));
    } else {
        void *transform_buffer = malloc(size);
        ggml_backend_cann_transform(tensor, data, transform_buffer);

        ACL_CHECK(aclrtMemcpyAsync(
            (char *)tensor->data + offset, size, transform_buffer, size,
            ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
        free(transform_buffer);
    }
}

static void ggml_backend_cann_get_tensor_async(
    ggml_backend_t backend, const ggml_tensor *tensor, void *data,
    size_t offset, size_t size) {
    ggml_backend_cann_context *cann_ctx =
        (ggml_backend_cann_context *)backend->context;
    ggml_backend_buffer_t buf =
        tensor->view_src ? tensor->view_src->buffer : tensor->buffer;

    GGML_ASSERT(buf->buft == ggml_backend_cann_buffer_type(cann_ctx->device) &&
                "unsupported buffer type");

    if (!need_transform(tensor->type)) {
        ACL_CHECK(aclrtMemcpyAsync(data, size, (char *)tensor->data + offset,
                                   size, ACL_MEMCPY_DEVICE_TO_HOST,
                                   cann_ctx->stream()));
    } else {
        void *transform_buffer = malloc(size);
        ACL_CHECK(aclrtMemcpyAsync(
            transform_buffer, size, (char *)tensor->data + offset, size,
            ACL_MEMCPY_DEVICE_TO_HOST, cann_ctx->stream()));
        ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
        ggml_backend_cann_transform_back(tensor, transform_buffer, data);
        free(transform_buffer);
    }
}

/**
 * @brief Asynchronously copies tensor data between CANN backends.
 *
 * This function copies tensor data asynchronously between two CANN backends. It
 * checks if both tensors reside in CANN buffers and whether the devices support
 * peer-to-peer access for direct copying. If not, it returns false.
 *
 * @param backend_src Pointer to the source CANN backend structure.
 * @param backend_dst Pointer to the destination CANN backend structure.
 * @param src Pointer to the source tensor to copy data from.
 * @param dst Pointer to the destination tensor to copy data to.
 * @return true if the copy operation succeeds, false otherwise.
 */
static bool ggml_backend_cann_cpy_tensor_async(
    ggml_backend_t backend_src, ggml_backend_t backend_dst,
    const ggml_tensor* src, ggml_tensor* dst) {
    GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
                ggml_backend_is_cann(backend_dst));

    if (!ggml_backend_buffer_is_cann(src->buffer) ||
        !ggml_backend_buffer_is_cann(dst->buffer)) {
        return false;
    }

    ggml_backend_buffer_t buf_src =
        src->view_src ? src->view_src->buffer : src->buffer;
    ggml_backend_buffer_t buf_dst =
        dst->view_src ? dst->view_src->buffer : dst->buffer;

    ggml_backend_cann_context* cann_ctx_src =
        (ggml_backend_cann_context*)backend_src->context;
    ggml_backend_cann_context* cann_ctx_dst =
        (ggml_backend_cann_context*)backend_dst->context;

    size_t copy_size = ggml_nbytes(dst);
    if (backend_src != backend_dst) {
        ggml_backend_cann_buffer_context* buf_ctx_src =
            (ggml_backend_cann_buffer_context*)buf_src->context;
        ggml_backend_cann_buffer_context* buf_ctx_dst =
            (ggml_backend_cann_buffer_context*)buf_dst->context;

        GGML_ASSERT(cann_ctx_src->device == buf_ctx_src->device);
        GGML_ASSERT(cann_ctx_dst->device == buf_ctx_dst->device);

        int32_t canAccessPeer = 0;
        ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, cann_ctx_src->device,
                                           cann_ctx_dst->device));
        if (!canAccessPeer) {
            return false;
        }

        // need to open both directions for memcpyasync between devices.
        ggml_cann_set_device(cann_ctx_dst->device);
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
        ggml_cann_set_device(cann_ctx_src->device);
        ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));

        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_src->stream()));

        // TODO: workaround because events don't work here.
        aclrtSynchronizeStream(cann_ctx_src->stream());
    } else {
        // src and dst are on the same backend
        ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
                                   ACL_MEMCPY_DEVICE_TO_DEVICE,
                                   cann_ctx_dst->stream()));
    }

    return true;
}

/**
 * @brief Synchronizes a CANN backend.
 *
 * This function synchronizes the specified CANN backend by waiting for all
 * operations in its associated stream to complete.
 *
 * @param backend Pointer to the CANN backend structure to synchronize.
 */
static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;

    ggml_cann_set_device(cann_ctx->device);

    ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
}

/**
 * @brief Computes a computational graph using a CANN backend.
 *
 * This function computes the operations defined in the computational graph
 * using the specified CANN backend.
 *
 * @param backend Pointer to the CANN backend structure to use for computation.
 * @param cgraph Pointer to the computational graph structure containing nodes
 * representing operations to be computed.
 * @return enum ggml_status Returns GGML_STATUS_SUCCESS if computation
 * completes successfully, otherwise an appropriate error status.
 */
static enum ggml_status ggml_backend_cann_graph_compute(
    ggml_backend_t backend, ggml_cgraph* cgraph) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;

    ggml_cann_set_device(cann_ctx->device);

    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_tensor* node = cgraph->nodes[i];

        if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
            continue;
        }

        bool ok = ggml_cann_compute_forward(*cann_ctx, node);

        if (!ok) {
            GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__,
                           node->name, ggml_op_name(node->op));
        }
        GGML_ASSERT(ok);
    }

    return GGML_STATUS_SUCCESS;
}
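Client code does not call this static function directly; it goes through the generic backend entry point. A hedged sketch (assuming `gf` was built and its tensors allocated in CANN buffers beforehand):

// Hedged sketch: run a prepared graph on the CANN backend via the public API.
// ggml_backend_graph_compute dispatches to the function above and
// synchronizes the backend stream before returning.
static enum ggml_status run_on_cann(ggml_backend_t cann_backend, struct ggml_cgraph * gf) {
    return ggml_backend_graph_compute(cann_backend, gf);
}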
/**
 * @brief Checks if the CANN backend supports a specific operation.
 *
 * This function checks whether the specified operation is supported by the
 * CANN backend.
 *
 * @param dev Pointer to the CANN device to check support on.
 * @param op Pointer to the tensor representing the operation to check.
 * @return bool Returns true if the operation is supported by the backend,
 * otherwise false.
 */
static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                                          const ggml_tensor* op) {
    switch (op->op) {
        case GGML_OP_UNARY:
            switch (ggml_get_unary_op(op)) {
                case GGML_UNARY_OP_GELU:
                case GGML_UNARY_OP_SILU:
                case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_HARDSIGMOID:
                case GGML_UNARY_OP_HARDSWISH:
                case GGML_UNARY_OP_GELU_QUICK:
                case GGML_UNARY_OP_TANH:
                    return true;
                default:
                    return false;
            }
        case GGML_OP_MUL_MAT: {
            switch (op->src[0]->type) {
                case GGML_TYPE_F16:
                case GGML_TYPE_F32:
                case GGML_TYPE_Q8_0:
                // TODO: fix me
                // Current groupsize should not be greater than k-1 in
                // aclnnWeightQuantBatchMatmulV2GetWorkspaceSize().
                case GGML_TYPE_Q4_0:
                    return true;
                default:
                    return false;
            }
        }
        case GGML_OP_MUL_MAT_ID:
            return false;
        // embedding
        case GGML_OP_GET_ROWS: {
            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q8_0:
                    return true;
                default:
                    return false;
            }
        } break;
        case GGML_OP_CPY: {
            switch (op->type) {
                case GGML_TYPE_F32:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q8_0:
                case GGML_TYPE_Q4_0:
                    return true;
                default:
                    return false;
            }
        }
        case GGML_OP_DUP:
        case GGML_OP_REPEAT:
        case GGML_OP_CONCAT:
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_NORM:
        case GGML_OP_ADD:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_RMS_NORM:
        case GGML_OP_SCALE:
        case GGML_OP_SQR:
        case GGML_OP_CLAMP:
        case GGML_OP_CONT:
        case GGML_OP_DIAG_MASK_INF:
        case GGML_OP_SOFT_MAX:
        case GGML_OP_ROPE:
        case GGML_OP_IM2COL:
        case GGML_OP_POOL_2D:
        case GGML_OP_SUM_ROWS:
        case GGML_OP_ARGSORT:
        case GGML_OP_ACC:
        case GGML_OP_GROUP_NORM:
        case GGML_OP_UPSCALE:
        case GGML_OP_PAD:
        case GGML_OP_ARANGE:
        case GGML_OP_TIMESTEP_EMBEDDING:
        case GGML_OP_LEAKY_RELU:
            return true;
        default:
            return false;
    }

    GGML_UNUSED(dev);
}
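Schedulers reach this check through the generic device API rather than this static function; a hedged one-liner sketch (`node` is assumed to be a tensor from an already-built graph):

// Hedged sketch: query op support through the public device API, which
// forwards to ggml_backend_cann_supports_op for CANN devices.
static bool can_run_on(ggml_backend_dev_t dev, const struct ggml_tensor * node) {
    return ggml_backend_dev_supports_op(dev, node);
}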
/**
 * @brief Checks if the backend buffer type is associated with the CANN backend.
 *
 * This function checks whether the provided backend buffer type is associated
 * with the CANN backend based on the comparison of its name retrieval function
 * pointer.
 *
 * @param buft Pointer to the backend buffer type to check.
 * @return bool Returns true if the buffer type is associated with the CANN
 * backend, otherwise false.
 */
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
    return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}

/**
 * @brief Determines if a tensor operation should be offloaded to the CANN
 * backend.
 *
 * This function checks if a given tensor operation should be offloaded to the
 * CANN backend based on the operation type and the size of the tensor. It
 * returns true if the second dimension (ne[1]) of the tensor is greater than or
 * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS.
 *
 * @param dev Pointer to the CANN device.
 * @param op Pointer to the tensor operation to check.
 * @return bool Returns true if the operation should be offloaded, otherwise
 * false.
 */
static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
                                         const ggml_tensor* op) {
    const int min_batch_size = 32;
    GGML_UNUSED(dev);

    return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
}

/**
 * @brief Records an event on the CANN backend stream.
 *
 * This function records the given event on the ACL runtime stream associated
 * with the backend context.
 *
 * @param backend Pointer to the backend whose stream is used.
 * @param event Pointer to the event structure to be recorded.
 */
static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;
    ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
}

/**
 * @brief Waits for a recorded event to complete on the CANN backend stream.
 *
 * This function makes the given backend wait for the event to complete on its
 * ACL runtime stream.
 *
 * @param backend Pointer to the backend structure.
 * @param event Pointer to the event structure that the backend needs to wait
 * for.
 */
static void ggml_backend_cann_event_wait(ggml_backend_t backend,
                                         ggml_backend_event_t event) {
    ggml_backend_cann_context* cann_ctx =
        (ggml_backend_cann_context*)backend->context;
    if (ggml_backend_is_cann(backend)) {
        ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
                                       (aclrtEvent)event->context));
    } else {
        GGML_ABORT("fatal error");
    }
}

/**
 * @brief Structure defining the interface for the CANN backend.
 *
 * This structure contains function pointers for various operations
 * supported by the CANN backend, including name retrieval, memory
 * management, tensor operations, synchronization, and event handling.
 */
static const ggml_backend_i ggml_backend_cann_interface = {
    /* .get_name           = */ ggml_backend_cann_name,
    /* .free               = */ ggml_backend_cann_free,
    /* .set_tensor_async   = */ ggml_backend_cann_set_tensor_async,
    /* .get_tensor_async   = */ ggml_backend_cann_get_tensor_async,
    /* .cpy_tensor_async   = */ ggml_backend_cann_cpy_tensor_async,
    /* .synchronize        = */ ggml_backend_cann_synchronize,
    /* .graph_plan_create  = */ NULL,
    /* .graph_plan_free    = */ NULL,
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_cann_graph_compute,
    /* .event_record       = */ ggml_backend_cann_event_record,
    /* .event_wait         = */ ggml_backend_cann_event_wait,
};

/**
 * @brief Return the hardcoded GUID for the CANN backend.
 *
 * This function returns a static GUID which uniquely identifies the CANN
 * backend.
 *
 * @return A pointer to the static GUID.
 */
static ggml_guid_t ggml_backend_cann_guid() {
    static ggml_guid guid = {0xa1, 0x94, 0xaf, 0xac, 0xbd, 0x4f, 0x47, 0x34,
                             0xbe, 0x1a, 0x9e, 0x71, 0x1f, 0x9e, 0xed, 0x64};
    return &guid;
}

// backend device
struct ggml_backend_cann_device_context {
    int device;
    std::string name;
    std::string description;
};

static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
    return ctx->name.c_str();
}

static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
    return ctx->description.c_str();
}

static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
    ggml_backend_cann_get_device_memory(ctx->device, free, total);
}

static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    return GGML_BACKEND_DEVICE_TYPE_GPU;
}

static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
    props->name        = ggml_backend_cann_device_get_name(dev);
    props->description = ggml_backend_cann_device_get_description(dev);
    props->type        = ggml_backend_cann_device_get_type(dev);
    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);

    bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;

    props->caps = {
        /* .async                = */ false,
        /* .host_buffer          = */ host_buffer,
        /* .buffer_from_host_ptr = */ false,
        /* .events               = */ true,
    };
}

static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
    GGML_UNUSED(params);
    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
    return ggml_backend_cann_init(ctx->device);
}

/**
 * @brief Checks if the CANN backend supports a specific backend buffer type.
 *
 * This function determines whether the CANN backend supports the given backend
 * buffer type by comparing the device context of the backend and buffer type.
 * It returns true if the devices are the same between the backend context and
 * buffer type context.
 *
 * @param dev Pointer to the CANN device.
 * @param buft Pointer to the backend buffer type to check.
 * @return bool Returns true if the CANN backend supports the buffer type,
 * otherwise false.
 */
static bool ggml_backend_cann_supports_buft(
    ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
    if (ggml_backend_buft_is_cann(buft)) {
        ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
        ggml_backend_cann_buffer_type_context * buft_ctx =
            (ggml_backend_cann_buffer_type_context *)buft->context;
        return buft_ctx->device == dev_ctx->device;
    }
    return false;
}

static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
    return ggml_backend_cann_buffer_type(ctx->device);
}

static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
    GGML_UNUSED(dev);
    return ggml_backend_cann_host_buffer_type();
}

/**
 * @brief Creates a new event for the CANN backend device.
 *
 * This function initializes a new event for the CANN backend by setting the
 * device and creating an ACL runtime event. The created event is then wrapped
 * in a ggml_backend_event structure and returned.
 *
 * @param dev Pointer to the CANN device.
 * @return ggml_backend_event_t Returns a pointer to the new event structure.
 */
static ggml_backend_event_t ggml_backend_cann_device_event_new(
    ggml_backend_dev_t dev) {
    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;

    ggml_cann_set_device(dev_ctx->device);

    aclrtEvent event;
    ACL_CHECK(aclrtCreateEvent(&event));

    return new ggml_backend_event{
        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
        /* .context = */ event,
    };
}

/**
 * @brief Frees a CANN backend event.
 *
 * This function destroys the ACL runtime event associated with the given CANN
 * backend event and then deletes the event structure itself.
 *
 * @param event Pointer to the event structure to be freed.
 */
static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));

    delete event;
    GGML_UNUSED(dev);
}

/**
 * @brief Synchronizes the given event on the CANN backend.
 *
 * This function waits for the specified event to complete on the ACL runtime.
 *
 * @param event Pointer to the event structure to be synchronized.
 */
static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));

    GGML_UNUSED(dev);
}

static const ggml_backend_device_i ggml_backend_cann_device_interface = {
    /* .get_name             = */ ggml_backend_cann_device_get_name,
    /* .get_description      = */ ggml_backend_cann_device_get_description,
    /* .get_memory           = */ ggml_backend_cann_device_get_memory,
    /* .get_type             = */ ggml_backend_cann_device_get_type,
    /* .get_props            = */ ggml_backend_cann_device_get_props,
    /* .init_backend         = */ ggml_backend_cann_device_init, // called for every card
    /* .get_buffer_type      = */ ggml_backend_cann_device_get_buffer_type,
    /* .get_host_buffer_type = */ ggml_backend_cann_device_get_host_buffer_type,
    /* .buffer_from_host_ptr = */ NULL, // not supported for CANN
    /* .supports_op          = */ ggml_backend_cann_supports_op,
    /* .supports_buft        = */ ggml_backend_cann_supports_buft,
    /* .offload_op           = */ ggml_backend_cann_offload_op,
    /* .event_new            = */ ggml_backend_cann_device_event_new,
    /* .event_free           = */ ggml_backend_cann_device_event_free,
    /* .event_synchronize    = */ ggml_backend_cann_device_event_synchronize,
};
// backend reg
struct ggml_backend_cann_reg_context {
    std::vector<ggml_backend_dev_t> devices;
};

static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
    GGML_UNUSED(reg);
    return GGML_CANN_NAME;
}

static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
    return ctx->devices.size();
}

static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
    GGML_ASSERT(index < ctx->devices.size());
    return ctx->devices[index];
}

static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
    GGML_UNUSED(reg);
    GGML_UNUSED(name);
    // reserved for future use
    return nullptr;
}

static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
    /* .get_name         = */ ggml_backend_cann_reg_get_name,
    /* .get_device_count = */ ggml_backend_cann_reg_get_device_count,
    /* .get_device       = */ ggml_backend_cann_reg_get_device,
    /* .get_proc_address = */ ggml_backend_cann_reg_get_proc_address,
};

// backend registry, called only once for cann backend
ggml_backend_reg_t ggml_backend_cann_reg() {
    static ggml_backend_reg reg;
    static bool initialized = false;

    {
        static std::mutex mutex;
        std::lock_guard<std::mutex> lock(mutex);
        if (!initialized) {
            aclInit(nullptr);
            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;

            for (int i = 0; i < ggml_cann_info().device_count; i++) {
                ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
                dev_ctx->description = aclrtGetSocName();
                dev_ctx->device = i;
                dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
                ggml_cann_set_device(i);
                ggml_backend_dev_t dev = new ggml_backend_device {
                    /* .interface = */ ggml_backend_cann_device_interface,
                    /* .reg       = */ &reg,
                    /* .context   = */ dev_ctx
                };
                ctx->devices.push_back(dev);
            }

            reg = ggml_backend_reg {
                /* .interface = */ ggml_backend_cann_reg_interface,
                /* .context   = */ ctx
            };
        }

        initialized = true;
    }

    return &reg;
}
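A hedged sketch of enumerating the devices exposed by this registry from client code, using the public accessors (ggml_backend_reg_dev_count/ggml_backend_reg_dev_get wrap the interface above):

// Hedged sketch: list CANN devices and their memory via the registry.
#include <cstdio>

static void list_cann_devices(void) {
    ggml_backend_reg_t reg = ggml_backend_cann_reg();
    for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        size_t free, total;
        ggml_backend_dev_memory(dev, &free, &total);
        printf("%s: %s (%zu/%zu MiB free)\n",
               ggml_backend_dev_name(dev), ggml_backend_dev_description(dev),
               free / 1024 / 1024, total / 1024 / 1024);
    }
}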
ggml_backend_t ggml_backend_cann_init(int32_t device) {
    aclInit(nullptr);
    if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
        GGML_LOG_ERROR("%s: error: invalid device %d\n", __func__, device);
        return nullptr;
    }

    ggml_backend_cann_context* ctx = new ggml_backend_cann_context(device);
    if (ctx == nullptr) {
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return nullptr;
    }
    ggml_cann_set_device(ctx->device);
    ggml_backend_t cann_backend =
        new ggml_backend{/* .guid      = */ ggml_backend_cann_guid(),
                         /* .interface = */ ggml_backend_cann_interface,
                         /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
                         /* .context   = */ ctx};

    return cann_backend;
}

bool ggml_backend_is_cann(ggml_backend_t backend) {
    return backend != NULL &&
           ggml_guid_matches(backend->guid, ggml_backend_cann_guid());
}

int32_t ggml_backend_cann_get_device_count() {
    return ggml_cann_info().device_count;
}

void ggml_backend_cann_get_device_description(
    int32_t device, char* description, size_t description_size) {
    ggml_cann_set_device(device);
    const char* soc_name = aclrtGetSocName();
    snprintf(description, description_size, "%s", soc_name);
}

void ggml_backend_cann_get_device_memory(int32_t device, size_t* free,
                                         size_t* total) {
    ggml_cann_set_device(device);
    ACL_CHECK(aclrtGetMemInfo(ACL_HBM_MEM, free, total));
}
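To tie the public entry points of this file together, a minimal end-to-end sketch (assuming at least one Ascend NPU is visible to the ACL runtime; graph and buffer setup are elided):

// Minimal sketch: initialize the first CANN device, report it, and free the
// backend. Uses only functions defined above plus ggml_backend_free.
#include <cstdio>
#include "ggml-backend.h"
#include "ggml-cann.h"

int main(void) {
    if (ggml_backend_cann_get_device_count() == 0) {
        fprintf(stderr, "no CANN devices found\n");
        return 1;
    }
    ggml_backend_t backend = ggml_backend_cann_init(0);
    if (backend == nullptr || !ggml_backend_is_cann(backend)) {
        return 1;
    }
    char desc[128];
    ggml_backend_cann_get_device_description(0, desc, sizeof(desc));
    printf("using CANN device 0: %s\n", desc);
    // ... build a ggml graph, allocate its tensors in a CANN buffer,
    //     and run it with ggml_backend_graph_compute(backend, gf) ...
    ggml_backend_free(backend);
    return 0;
}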
ggml/src/ggml-cpu/CMakeLists.txt
ADDED
@@ -0,0 +1,244 @@
add_library(ggml-cpu
            ggml-cpu.c
            ggml-cpu.cpp
            ggml-cpu-aarch64.c
            ggml-cpu-aarch64.h
            ggml-cpu-quants.c
            ggml-cpu-quants.h
            )

target_link_libraries(ggml-cpu PRIVATE ggml-base)
target_include_directories(ggml-cpu PRIVATE . ..)

if (APPLE AND GGML_ACCELERATE)
    find_library(ACCELERATE_FRAMEWORK Accelerate)
    if (ACCELERATE_FRAMEWORK)
        message(STATUS "Accelerate framework found")

        add_compile_definitions(GGML_USE_ACCELERATE)
        add_compile_definitions(ACCELERATE_NEW_LAPACK)
        add_compile_definitions(ACCELERATE_LAPACK_ILP64)

        target_link_libraries(ggml-cpu PRIVATE ${ACCELERATE_FRAMEWORK})
    else()
        message(WARNING "Accelerate framework not found")
    endif()
endif()

if (GGML_OPENMP)
    find_package(OpenMP)
    if (OpenMP_FOUND)
        message(STATUS "OpenMP found")

        add_compile_definitions(GGML_USE_OPENMP)

        target_link_libraries(ggml-cpu PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)

        # FIXME: should be replaced with a compiler id check
        #if (GGML_MUSA)
        #    list(APPEND GGML_CPU_EXTRA_INCLUDES     "/usr/lib/llvm-14/lib/clang/14.0.0/include")
        #    list(APPEND GGML_CPU_EXTRA_LIBS_PRIVATE "/usr/lib/llvm-14/lib/libomp.so")
        #endif()
    else()
        message(WARNING "OpenMP not found")
    endif()
endif()

if (GGML_LLAMAFILE)
    message(STATUS "Using llamafile")

    add_compile_definitions(GGML_USE_LLAMAFILE)

    target_sources(ggml-cpu PRIVATE
                   llamafile/sgemm.cpp
                   llamafile/sgemm.h)
endif()

if (GGML_CPU_HBM)
    find_library(memkind memkind REQUIRED)

    message(STATUS "Using memkind for CPU HBM")

    add_compile_definitions(GGML_USE_CPU_HBM)

    target_link_libraries(ggml-cpu PUBLIC memkind)
endif()

if (CMAKE_OSX_ARCHITECTURES STREQUAL "arm64" OR
    CMAKE_GENERATOR_PLATFORM_LWR STREQUAL "arm64" OR
    (NOT CMAKE_OSX_ARCHITECTURES AND
     NOT CMAKE_GENERATOR_PLATFORM_LWR AND
     CMAKE_SYSTEM_PROCESSOR MATCHES "^(aarch64|arm.*|ARM64)$"))

    message(STATUS "ARM detected")

    if (MSVC)
        add_compile_definitions(__aarch64__) # MSVC defines _M_ARM64 instead
        add_compile_definitions(__ARM_NEON)
        add_compile_definitions(__ARM_FEATURE_FMA)

        set(CMAKE_REQUIRED_FLAGS_PREV ${CMAKE_REQUIRED_FLAGS})
        string(JOIN " " CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS} "/arch:armv8.2")

        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vdotq_s32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_DOTPROD)
        if (GGML_COMPILER_SUPPORT_DOTPROD)
            add_compile_definitions(__ARM_FEATURE_DOTPROD)
        endif ()

        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { int8x16_t _a, _b; int32x4_t _s = vmlaq_f32(_s, _a, _b); return 0; }" GGML_COMPILER_SUPPORT_MATMUL_INT8)
        if (GGML_COMPILER_SUPPORT_MATMUL_INT8)
            add_compile_definitions(__ARM_FEATURE_MATMUL_INT8)
        endif ()

        check_cxx_source_compiles("#include <arm_neon.h>\nint main() { float16_t _a; float16x8_t _s = vdupq_n_f16(_a); return 0; }" GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
        if (GGML_COMPILER_SUPPORT_FP16_VECTOR_ARITHMETIC)
            add_compile_definitions(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
        endif ()

        set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_PREV})
    else()
        check_cxx_compiler_flag(-mfp16-format=ieee COMPILER_SUPPORTS_FP16_FORMAT_I3E)
        if (NOT "${COMPILER_SUPPORTS_FP16_FORMAT_I3E}" STREQUAL "")
            list(APPEND ARCH_FLAGS -mfp16-format=ieee)
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv6")
            # Raspberry Pi 1, Zero
            list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access)
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv7")
            if ("${CMAKE_SYSTEM_NAME}" STREQUAL "Android")
                # Android armeabi-v7a
                list(APPEND ARCH_FLAGS -mfpu=neon-vfpv4 -mno-unaligned-access -funsafe-math-optimizations)
            else()
                # Raspberry Pi 2
                list(APPEND ARCH_FLAGS -mfpu=neon-fp-armv8 -mno-unaligned-access -funsafe-math-optimizations)
            endif()
        endif()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "armv8")
            # Android arm64-v8a
            # Raspberry Pi 3, 4, Zero 2 (32-bit)
            list(APPEND ARCH_FLAGS -mno-unaligned-access)
        endif()
        if (GGML_SVE)
            list(APPEND ARCH_FLAGS -march=armv8.6-a+sve)
        endif()
    endif()
elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LWR MATCHES "^(x86_64|i686|amd64|x64|win32)$" OR
        (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_GENERATOR_PLATFORM_LWR AND
         CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|i686|AMD64)$"))
    message(STATUS "x86 detected")
    if (MSVC)
        # instruction set detection for MSVC only
        if (GGML_NATIVE)
            # TODO: improve, should not reference files from the parent folder
            include(cmake/FindSIMD.cmake)
        endif ()
        if (GGML_AVX512)
            list(APPEND ARCH_FLAGS /arch:AVX512)
            # MSVC has no compile-time flags enabling specific
            # AVX512 extensions, neither does it define the
            # macros corresponding to the extensions.
            # Do it manually.
            if (GGML_AVX512_VBMI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
            endif()
            if (GGML_AVX512_VNNI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
            endif()
            if (GGML_AVX512_BF16)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
            endif()
            if (GGML_AMX_TILE)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
            endif()
            if (GGML_AMX_INT8)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
            endif()
            if (GGML_AMX_BF16)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
            endif()
        elseif (GGML_AVX2)
            list(APPEND ARCH_FLAGS /arch:AVX2)
        elseif (GGML_AVX)
            list(APPEND ARCH_FLAGS /arch:AVX)
        endif()
    else()
        if (GGML_NATIVE)
            list(APPEND ARCH_FLAGS -march=native)
        endif()
        if (GGML_F16C)
            list(APPEND ARCH_FLAGS -mf16c)
        endif()
        if (GGML_FMA)
            list(APPEND ARCH_FLAGS -mfma)
        endif()
        if (GGML_AVX)
            list(APPEND ARCH_FLAGS -mavx)
        endif()
        if (GGML_AVX2)
            list(APPEND ARCH_FLAGS -mavx2)
        endif()
        if (GGML_AVX512)
            list(APPEND ARCH_FLAGS -mavx512f)
            list(APPEND ARCH_FLAGS -mavx512dq)
            list(APPEND ARCH_FLAGS -mavx512bw)
        endif()
        if (GGML_AVX512_VBMI)
            list(APPEND ARCH_FLAGS -mavx512vbmi)
        endif()
        if (GGML_AVX512_VNNI)
            list(APPEND ARCH_FLAGS -mavx512vnni)
        endif()
        if (GGML_AVX512_BF16)
            list(APPEND ARCH_FLAGS -mavx512bf16)
        endif()
        if (GGML_AMX_TILE)
            list(APPEND ARCH_FLAGS -mamx-tile)
        endif()
        if (GGML_AMX_INT8)
            list(APPEND ARCH_FLAGS -mamx-int8)
        endif()
        if (GGML_AMX_BF16)
            list(APPEND ARCH_FLAGS -mamx-bf16)
        endif()
    endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
    message(STATUS "PowerPC detected")
    execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
                    OUTPUT_VARIABLE POWER10_M)
    string(FIND ${POWER10_M} "POWER10" substring_index)
    if (${substring_index} GREATER_EQUAL 0)
        list(APPEND ARCH_FLAGS -mcpu=power10)
    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
        list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
    else()
        list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
        # TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10 (MMA) and query for big endian systems (ppc64/le/be)
    endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64")
    message(STATUS "loongarch64 detected")

    list(APPEND ARCH_FLAGS -march=loongarch64)
    if (GGML_LASX)
        list(APPEND ARCH_FLAGS -mlasx)
    endif()
    if (GGML_LSX)
        list(APPEND ARCH_FLAGS -mlsx)
    endif()
else()
    message(STATUS "Unknown architecture")
endif()

target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
target_compile_options(ggml-cpu PRIVATE "$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")

if (EMSCRIPTEN)
    set_target_properties(ggml-cpu PROPERTIES COMPILE_FLAGS "-msimd128")
endif()
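Since this change builds each backend as its own library target, a downstream CMake project might consume them roughly as follows. This is a hedged sketch: `my_app` is a hypothetical target, and only the `ggml`, `ggml-base`, and `ggml-cpu` target names come from this patch.

# Hedged sketch: link a hypothetical downstream target against the new
# per-backend libraries introduced by this change.
add_executable(my_app main.c)
target_link_libraries(my_app PRIVATE ggml ggml-base ggml-cpu)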
ggml/src/ggml-cpu/cmake/FindSIMD.cmake
ADDED
@@ -0,0 +1,100 @@
include(CheckCSourceRuns)

set(AVX_CODE "
    #include <immintrin.h>
    int main()
    {
        __m256 a;
        a = _mm256_set1_ps(0);
        return 0;
    }
")

set(AVX512_CODE "
    #include <immintrin.h>
    int main()
    {
        __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0,
                                    0, 0, 0, 0, 0, 0, 0, 0);
        __m512i b = a;
        __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
        return 0;
    }
")

set(AVX2_CODE "
    #include <immintrin.h>
    int main()
    {
        __m256i a = {0};
        a = _mm256_abs_epi16(a);
        __m256i x;
        _mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
        return 0;
    }
")

set(FMA_CODE "
    #include <immintrin.h>
    int main()
    {
        __m256 acc = _mm256_setzero_ps();
        const __m256 d = _mm256_setzero_ps();
        const __m256 p = _mm256_setzero_ps();
        acc = _mm256_fmadd_ps( d, p, acc );
        return 0;
    }
")

macro(check_sse type flags)
    set(__FLAG_I 1)
    set(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
    foreach (__FLAG ${flags})
        if (NOT ${type}_FOUND)
            set(CMAKE_REQUIRED_FLAGS ${__FLAG})
            check_c_source_runs("${${type}_CODE}" HAS_${type}_${__FLAG_I})
            if (HAS_${type}_${__FLAG_I})
                set(${type}_FOUND TRUE CACHE BOOL "${type} support")
                set(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
            endif()
            math(EXPR __FLAG_I "${__FLAG_I}+1")
        endif()
    endforeach()
    set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})

    if (NOT ${type}_FOUND)
        set(${type}_FOUND FALSE CACHE BOOL "${type} support")
        set(${type}_FLAGS "" CACHE STRING "${type} flags")
    endif()

    mark_as_advanced(${type}_FOUND ${type}_FLAGS)
endmacro()

# flags are for MSVC only!
check_sse("AVX" " ;/arch:AVX")
if (NOT ${AVX_FOUND})
    set(GGML_AVX OFF)
else()
    set(GGML_AVX ON)
endif()

check_sse("AVX2" " ;/arch:AVX2")
check_sse("FMA" " ;/arch:AVX2")
if ((NOT ${AVX2_FOUND}) OR (NOT ${FMA_FOUND}))
    set(GGML_AVX2 OFF)
else()
    set(GGML_AVX2 ON)
endif()

check_sse("AVX512" " ;/arch:AVX512")
if (NOT ${AVX512_FOUND})
    set(GGML_AVX512 OFF)
else()
    set(GGML_AVX512 ON)
endif()
ggml/src/ggml-cpu/ggml-cpu-aarch64.c
ADDED
The diff for this file is too large to render. See raw diff
ggml/src/ggml-cpu/ggml-cpu-aarch64.h
ADDED
@@ -0,0 +1,27 @@
#pragma once

#include "ggml.h"

// GGML internal header

#ifdef __cplusplus
extern "C" {
#endif

// Quantization
void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);

// GEMV
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

// GEMM
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

#ifdef __cplusplus
}
#endif
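The GEMV/GEMM kernels declared above consume activations in the interleaved q8_0 layout produced by quantize_mat_q8_0. A hedged sketch of calling it (the sizes are illustrative, and the output buffer is deliberately oversized rather than computed from the internal block structs; in-tree callers derive the exact size from the q8_0 row size):

// Hedged sketch: quantize a small f32 matrix into the interleaved q8_0
// layout. nrows must be a multiple of 4 and n_per_row a multiple of the
// q8_0 block size (32); blck_size_interleave is 4 or 8.
#include <string.h>
#include "ggml-cpu-aarch64.h"

static void quantize_demo(void) {
    float x[4 * 64];
    memset(x, 0, sizeof(x));
    char y[4096]; // oversized scratch; real callers size this exactly
    quantize_mat_q8_0(x, y, /*nrows=*/4, /*n_per_row=*/64, /*blck_size_interleave=*/4);
}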
ggml/src/ggml-cpu/ggml-cpu-impl.h
ADDED
@@ -0,0 +1,371 @@
+#pragma once
+
+// GGML CPU internal header
+
+#include "ggml.h"
+#include "ggml-impl.h"
+#include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
+//#include <stddef.h>
+#include <stdbool.h>
+#include <string.h> // memcpy
+#include <math.h>   // fabsf
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#if defined(_MSC_VER)
+
+#define m512bh(p) p
+#define m512i(p) p
+
+#else
+
+#define m512bh(p) (__m512bh)(p)
+#define m512i(p) (__m512i)(p)
+
+#endif
+
+// __FMA__ and __F16C__ are not defined in MSVC, however they are implied with AVX2/AVX512
+#if defined(_MSC_VER) && (defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __FMA__
+#define __FMA__
+#endif
+#ifndef __F16C__
+#define __F16C__
+#endif
+#endif
+
+// __SSE3__ and __SSSE3__ are not defined in MSVC, but SSE3/SSSE3 are present when AVX/AVX2/AVX512 are available
+#if defined(_MSC_VER) && (defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__))
+#ifndef __SSE3__
+#define __SSE3__
+#endif
+#ifndef __SSSE3__
+#define __SSSE3__
+#endif
+#endif
+
+#if defined(__ARM_FEATURE_SVE)
+#include <arm_sve.h>
+#include <sys/prctl.h>
+#endif
+
+// 16-bit float
+// on Arm, we use __fp16
+// on x86, we use uint16_t
+#if defined(__ARM_NEON)
+
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+
+#ifdef _MSC_VER
+
+typedef uint16_t ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { ((w) + ((uint64_t)(x) << 32)), ((y) + ((uint64_t)(z) << 32)) }
+
+#else
+
+typedef __fp16 ggml_fp16_internal_t;
+
+#define ggml_vld1q_u32(w,x,y,z) { (w), (x), (y), (z) }
+
+#endif // _MSC_VER
+
+#if !defined(__aarch64__)
+
+// 32-bit ARM compatibility
+
+// vaddlvq_s16
+// vpaddq_s16
+// vpaddq_s32
+// vaddvq_s32
+// vaddvq_f32
+// vmaxvq_f32
+// vcvtnq_s32_f32
+// vzip1_u8
+// vzip2_u8
+
+inline static int32_t vaddlvq_s16(int16x8_t v) {
+    int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v)));
+    return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2);
+}
+
+inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
+    int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
+    int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
+    return vcombine_s16(a0, b0);
+}
+
+inline static int32x4_t vpaddq_s32(int32x4_t a, int32x4_t b) {
+    int32x2_t a0 = vpadd_s32(vget_low_s32(a), vget_high_s32(a));
+    int32x2_t b0 = vpadd_s32(vget_low_s32(b), vget_high_s32(b));
+    return vcombine_s32(a0, b0);
+}
+
+inline static int32_t vaddvq_s32(int32x4_t v) {
+    return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
+}
+
+inline static float vaddvq_f32(float32x4_t v) {
+    return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
+}
+
+inline static float vmaxvq_f32(float32x4_t v) {
+    return
+        MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
+            MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
+}
+
+inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
+    int32x4_t res;
+
+    res[0] = roundf(vgetq_lane_f32(v, 0));
+    res[1] = roundf(vgetq_lane_f32(v, 1));
+    res[2] = roundf(vgetq_lane_f32(v, 2));
+    res[3] = roundf(vgetq_lane_f32(v, 3));
+
+    return res;
+}
+
+inline static uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[0]; res[1] = b[0];
+    res[2] = a[1]; res[3] = b[1];
+    res[4] = a[2]; res[5] = b[2];
+    res[6] = a[3]; res[7] = b[3];
+
+    return res;
+}
+
+inline static uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
+    uint8x8_t res;
+
+    res[0] = a[4]; res[1] = b[4];
+    res[2] = a[5]; res[3] = b[5];
+    res[4] = a[6]; res[5] = b[6];
+    res[6] = a[7]; res[7] = b[7];
+
+    return res;
+}
+
+// vld1q_s16_x2
+// vld1q_u8_x2
+// vld1q_u8_x4
+// vld1q_s8_x2
+// vld1q_s8_x4
+// TODO: double-check these work correctly
+
+typedef struct ggml_int16x8x2_t {
+    int16x8_t val[2];
+} ggml_int16x8x2_t;
+
+inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
+    ggml_int16x8x2_t res;
+
+    res.val[0] = vld1q_s16(ptr + 0);
+    res.val[1] = vld1q_s16(ptr + 8);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x2_t {
+    uint8x16_t val[2];
+} ggml_uint8x16x2_t;
+
+inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
+    ggml_uint8x16x2_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_uint8x16x4_t {
+    uint8x16_t val[4];
+} ggml_uint8x16x4_t;
+
+inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
+    ggml_uint8x16x4_t res;
+
+    res.val[0] = vld1q_u8(ptr + 0);
+    res.val[1] = vld1q_u8(ptr + 16);
+    res.val[2] = vld1q_u8(ptr + 32);
+    res.val[3] = vld1q_u8(ptr + 48);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x2_t {
+    int8x16_t val[2];
+} ggml_int8x16x2_t;
+
+inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
+    ggml_int8x16x2_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+
+    return res;
+}
+
+typedef struct ggml_int8x16x4_t {
+    int8x16_t val[4];
+} ggml_int8x16x4_t;
+
+inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
+    ggml_int8x16x4_t res;
+
+    res.val[0] = vld1q_s8(ptr + 0);
+    res.val[1] = vld1q_s8(ptr + 16);
+    res.val[2] = vld1q_s8(ptr + 32);
+    res.val[3] = vld1q_s8(ptr + 48);
+
+    return res;
+}
+
+// NOTE: not tested
+inline static int8x16_t ggml_vqtbl1q_s8(int8x16_t a, uint8x16_t b) {
+    int8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+// NOTE: not tested
+inline static uint8x16_t ggml_vqtbl1q_u8(uint8x16_t a, uint8x16_t b) {
+    uint8x16_t res;
+
+    res[ 0] = a[b[ 0]];
+    res[ 1] = a[b[ 1]];
+    res[ 2] = a[b[ 2]];
+    res[ 3] = a[b[ 3]];
+    res[ 4] = a[b[ 4]];
+    res[ 5] = a[b[ 5]];
+    res[ 6] = a[b[ 6]];
+    res[ 7] = a[b[ 7]];
+    res[ 8] = a[b[ 8]];
+    res[ 9] = a[b[ 9]];
+    res[10] = a[b[10]];
+    res[11] = a[b[11]];
+    res[12] = a[b[12]];
+    res[13] = a[b[13]];
+    res[14] = a[b[14]];
+    res[15] = a[b[15]];
+
+    return res;
+}
+
+#else
+
+#define ggml_int16x8x2_t  int16x8x2_t
+#define ggml_uint8x16x2_t uint8x16x2_t
+#define ggml_uint8x16x4_t uint8x16x4_t
+#define ggml_int8x16x2_t  int8x16x2_t
+#define ggml_int8x16x4_t  int8x16x4_t
+
+#define ggml_vld1q_s16_x2 vld1q_s16_x2
+#define ggml_vld1q_u8_x2  vld1q_u8_x2
+#define ggml_vld1q_u8_x4  vld1q_u8_x4
+#define ggml_vld1q_s8_x2  vld1q_s8_x2
+#define ggml_vld1q_s8_x4  vld1q_s8_x4
+#define ggml_vqtbl1q_s8   vqtbl1q_s8
+#define ggml_vqtbl1q_u8   vqtbl1q_u8
+
+#endif // !defined(__aarch64__)
+
+#if !defined(__ARM_FEATURE_DOTPROD)
+
+inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b) {
+    const int16x8_t p0 = vmull_s8(vget_low_s8 (a), vget_low_s8 (b));
+    const int16x8_t p1 = vmull_s8(vget_high_s8(a), vget_high_s8(b));
+
+    return vaddq_s32(acc, vaddq_s32(vpaddlq_s16(p0), vpaddlq_s16(p1)));
+}
+
+#else
+
+#define ggml_vdotq_s32(a, b, c) vdotq_s32(a, b, c)
+
+#endif // !defined(__ARM_FEATURE_DOTPROD)
+
+#endif // defined(__ARM_NEON)
+
+#ifdef __wasm_simd128__
+#include <wasm_simd128.h>
+#else
+#ifdef __POWER9_VECTOR__
+#include <altivec.h>
+#undef bool
+#define bool _Bool
+#else
+#if defined(_MSC_VER) || defined(__MINGW32__)
+#include <intrin.h>
+#else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__) || defined(__SSE__)
+#if !defined(__riscv)
+#include <immintrin.h>
+#endif
+#endif
+#endif
+#endif
+#endif
+
+#ifdef __riscv_v_intrinsic
+#include <riscv_vector.h>
+#endif
+
+#if defined(__loongarch64)
+#if defined(__loongarch_asx)
+#include <lasxintrin.h>
+#endif
+#if defined(__loongarch_sx)
+#include <lsxintrin.h>
+#endif
+#endif
+
+#if defined(__loongarch_asx)
+
+typedef union {
+    int32_t i;
+    float f;
+} ft_union;
+
+/* float type data load instructions */
+static __m128 __lsx_vreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
+}
+
+static __m256 __lasx_xvreplfr2vr_s(float val) {
+    ft_union fi_tmpval = {.f = val};
+    return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
+}
+#endif
+
+#ifdef __cplusplus
+}
+#endif
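One detail worth noting in the header above is the ggml_vdotq_s32 fallback for NEON targets without the dot-product extension. It widens the sixteen int8 products with vmull_s8 and folds them with pairwise adds, so the products are grouped across the four 32-bit lanes differently than a native vdotq_s32 would group them, but the horizontal sum over all four lanes comes out the same. A portable scalar model of the fallback, offered as a sketch for illustration rather than code from the patch:

    #include <cstdint>
    #include <cstdio>

    // Scalar model of the vmull_s8 + vpaddlq_s16 + vaddq_s32 fallback:
    // lane i accumulates products 2i, 2i+1 (low half) and 8+2i, 8+2i+1
    // (high half), instead of the native 4i..4i+3 grouping.
    static void vdot_fallback_model(const int8_t a[16], const int8_t b[16], int32_t acc[4]) {
        int16_t prod[16];
        for (int i = 0; i < 16; ++i) {
            prod[i] = (int16_t) ((int16_t) a[i] * (int16_t) b[i]);  // vmull_s8
        }
        for (int i = 0; i < 4; ++i) {
            acc[i] += prod[2*i] + prod[2*i + 1]            // vpaddlq_s16(p0)
                    + prod[8 + 2*i] + prod[8 + 2*i + 1];   // vpaddlq_s16(p1)
        }
    }

    int main() {
        int8_t a[16], b[16];
        for (int i = 0; i < 16; ++i) { a[i] = (int8_t) (i - 8); b[i] = (int8_t) (2 * i); }
        int32_t acc[4] = {0, 0, 0, 0};
        vdot_fallback_model(a, b, acc);
        // the horizontal sum matches what a native vdotq_s32 would yield
        printf("total = %d\n", acc[0] + acc[1] + acc[2] + acc[3]);
        return 0;
    }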
ggml/src/ggml-cpu/ggml-cpu-quants.c
ADDED
The diff for this file is too large to render. See raw diff.
ggml/src/ggml-cpu/ggml-cpu-quants.h
ADDED
@@ -0,0 +1,63 @@
+#pragma once
+
+#define GGML_COMMON_DECL_C
+#include "ggml-common.h"
+
+#include "ggml.h"
+
+// GGML CPU internal header
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// Quantization
+void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
+
+// Dot product
+void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq2_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq1_m_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_nl_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq4_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+void ggml_vec_dot_iq3_s_q8_K  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
+
+#ifdef __cplusplus
+}
+#endif
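Taken together, a quantize-then-dot round trip through this header looks roughly like the sketch below. It assumes linking against the new CPU backend sources and the q8_0 block layout from ggml-common.h (one fp16 scale plus 32 int8 values, 34 bytes per 32 floats); the zero stride arguments and nrc = 1 follow the single-row calling convention, and the buffer-sizing arithmetic is an assumption on my part, not code from this patch.

    #include <cstdint>
    #include <cstdio>
    #include <vector>
    #include "ggml-cpu-quants.h"

    int main() {
        const int64_t k = 64;  // must be a multiple of the q8_0 block size (32)
        std::vector<float> x(k), y(k);
        for (int64_t i = 0; i < k; ++i) { x[i] = 0.01f * i; y[i] = 1.0f - 0.01f * i; }

        // assumed sizing: 34 bytes (2-byte fp16 scale + 32 int8) per 32 values
        std::vector<uint8_t> xq((k / 32) * 34), yq((k / 32) * 34);
        quantize_row_q8_0(x.data(), xq.data(), k);
        quantize_row_q8_0(y.data(), yq.data(), k);

        float dot = 0.0f;
        ggml_vec_dot_q8_0_q8_0((int) k, &dot, 0, xq.data(), 0, yq.data(), 0, 1);
        printf("approx dot(x, y) = %f\n", dot);
        return 0;
    }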
ggml/src/ggml-cpu/ggml-cpu.c
ADDED
The diff for this file is too large to render. See raw diff.
ggml/src/ggml-cpu/ggml-cpu.cpp
ADDED
@@ -0,0 +1,575 @@
+#include "ggml-backend.h"
+#include "ggml-backend-impl.h"
+#include "ggml-cpu.h"
+#include "ggml-impl.h"
+#include <cctype>
+#include <string>
+#include <vector>
+
+#if defined(__APPLE__)
+#include <sys/types.h>
+#include <sys/sysctl.h>
+#endif
+
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+    #define NOMINMAX
+#endif
+#include <windows.h>
+#endif
+
+// ggml-backend interface
+
+#ifdef GGML_USE_CPU_HBM
+
+// buffer type HBM
+
+#include <hbwmalloc.h>
+
+static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "CPU_HBM";
+
+    GGML_UNUSED(buft);
+}
+
+static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    hbw_free(buffer->context);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr;
+    int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
+    if (result != 0) {
+        GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
+        return NULL;
+    }
+
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
+
+    return buffer;
+}
+
+ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
+    static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_cpu_hbm_buffer_type_get_name,
+            /* .alloc_buffer     = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type_get_alignment,
+            /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+            /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
+            /* .is_host          = */ ggml_backend_cpu_buffer_type_is_host,
+        },
+        /* .context  = */ NULL,
+    };
+
+    return &ggml_backend_cpu_buffer_type_hbm;
+}
+#endif
+
+static ggml_backend_buffer_type_t * ggml_backend_cpu_get_extra_bufts(ggml_backend_dev_t device) {
+    static ggml_backend_buffer_type_t bufts[] = {
+#ifdef GGML_USE_CPU_HBM
+        ggml_backend_cpu_hbm_buffer_type(),
+#endif
+        NULL
+    };
+
+    return bufts;
+
+    GGML_UNUSED(device);
+}
+
+// CPU backend - backend (stream)
+
+struct ggml_backend_cpu_context {
+    int                 n_threads;
+    ggml_threadpool_t   threadpool;
+
+    uint8_t *           work_data;
+    size_t              work_size;
+
+    ggml_abort_callback abort_callback;
+    void *              abort_callback_data;
+};
+
+static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
+    return "CPU";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_cpu_free(ggml_backend_t backend) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+    delete[] cpu_ctx->work_data;
+    delete cpu_ctx;
+    delete backend;
+}
+
+struct ggml_backend_plan_cpu {
+    struct ggml_cplan cplan;
+    struct ggml_cgraph cgraph;
+};
+
+static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_backend_plan_cpu * cpu_plan = new ggml_backend_plan_cpu;
+
+    cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+    cpu_plan->cgraph = *cgraph; // FIXME: deep copy
+
+    if (cpu_plan->cplan.work_size > 0) {
+        cpu_plan->cplan.work_data = new uint8_t[cpu_plan->cplan.work_size];
+        if (cpu_plan->cplan.work_data == NULL) {
+            delete cpu_plan;
+            return NULL;
+        }
+    }
+
+    cpu_plan->cplan.abort_callback      = cpu_ctx->abort_callback;
+    cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return cpu_plan;
+}
+
+static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    delete[] cpu_plan->cplan.work_data;
+    delete cpu_plan;
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+    struct ggml_backend_plan_cpu * cpu_plan = (struct ggml_backend_plan_cpu *)plan;
+
+    return ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
+
+    struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads, cpu_ctx->threadpool);
+
+    if (cpu_ctx->work_size < cplan.work_size) {
+        delete[] cpu_ctx->work_data;
+        cpu_ctx->work_data = new uint8_t[cplan.work_size];
+        if (cpu_ctx->work_data == NULL) {
+            cpu_ctx->work_size = 0;
+            return GGML_STATUS_ALLOC_FAILED;
+        }
+        cpu_ctx->work_size = cplan.work_size;
+    }
+    cplan.work_data = (uint8_t *)cpu_ctx->work_data;
+
+    cplan.abort_callback      = cpu_ctx->abort_callback;
+    cplan.abort_callback_data = cpu_ctx->abort_callback_data;
+
+    return ggml_graph_compute(cgraph, &cplan);
+}
+
+static const struct ggml_backend_i ggml_backend_cpu_i = {
+    /* .get_name                = */ ggml_backend_cpu_get_name,
+    /* .free                    = */ ggml_backend_cpu_free,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ ggml_backend_cpu_graph_plan_create,
+    /* .graph_plan_free         = */ ggml_backend_cpu_graph_plan_free,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ ggml_backend_cpu_graph_plan_compute,
+    /* .graph_compute           = */ ggml_backend_cpu_graph_compute,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_cpu_guid(void) {
+    static ggml_guid guid = { 0xaa, 0x67, 0xc7, 0x43, 0x96, 0xe6, 0xa3, 0x8a, 0xe3, 0xaf, 0xea, 0x92, 0x36, 0xbc, 0xfc, 0x89 };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_cpu_init(void) {
+    // initialize CPU backend now to avoid slowing the first graph computation
+    ggml_cpu_init();
+
+    struct ggml_backend_cpu_context * ctx = new ggml_backend_cpu_context;
+    if (ctx == NULL) {
+        return NULL;
+    }
+
+    ctx->n_threads           = GGML_DEFAULT_N_THREADS;
+    ctx->threadpool          = NULL;
+    ctx->work_data           = NULL;
+    ctx->work_size           = 0;
+    ctx->abort_callback      = NULL;
+    ctx->abort_callback_data = NULL;
+
+    ggml_backend_t cpu_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_cpu_guid(),
+        /* .interface = */ ggml_backend_cpu_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+    if (cpu_backend == NULL) {
+        delete ctx;
+        return NULL;
+    }
+
+    return cpu_backend;
+}
+
+bool ggml_backend_is_cpu(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_cpu_guid());
+}
+
+void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->n_threads = n_threads;
+}
+
+void ggml_backend_cpu_set_threadpool(ggml_backend_t backend_cpu, ggml_threadpool_t threadpool) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+
+    if (ctx->threadpool && ctx->threadpool != threadpool) {
+        // already had a different threadpool, pause/suspend it before switching
+        ggml_threadpool_pause(ctx->threadpool);
+    }
+    ctx->threadpool = threadpool;
+}
+
+void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
+    GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
+
+    struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
+    ctx->abort_callback      = abort_callback;
+    ctx->abort_callback_data = abort_callback_data;
+}
+
+// CPU backend - device
+
+struct ggml_backend_cpu_device_context {
+    std::string description = "CPU";
+
+    ggml_backend_cpu_device_context() {
+#ifdef __APPLE__
+        size_t len = 0;
+        if (!sysctlbyname("machdep.cpu.brand_string", NULL, &len, NULL, 0)) {
+            description.resize(len);
+            sysctlbyname("machdep.cpu.brand_string", &description[0], &len, NULL, 0); // NOLINT
+        }
+#elif defined(__linux__)
+        FILE * f = fopen("/proc/cpuinfo", "r");
+        if (f) {
+            char buf[1024];
+            while (fgets(buf, sizeof(buf), f)) {
+                if (strncmp(buf, "model name", 10) == 0) {
+                    char * p = strchr(buf, ':');
+                    if (p) {
+                        p++;
+                        while (std::isspace(*p)) {
+                            p++;
+                        }
+                        while (std::isspace(p[strlen(p) - 1])) {
+                            p[strlen(p) - 1] = '\0';
+                        }
+                        description = p;
+                        break;
+                    }
+                }
+            }
+            fclose(f);
+        }
+#elif defined(_WIN32)
+        HKEY hKey;
+        if (RegOpenKeyEx(HKEY_LOCAL_MACHINE,
+                        TEXT("HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"),
+                        0,
+                        KEY_READ,
+                        &hKey) == ERROR_SUCCESS) {
+            DWORD cpu_brand_size = 0;
+            if (RegQueryValueExA(hKey,
+                                TEXT("ProcessorNameString"),
+                                NULL,
+                                NULL,
+                                NULL,
+                                &cpu_brand_size) == ERROR_SUCCESS) {
+                description.resize(cpu_brand_size);
+                if (RegQueryValueExA(hKey,
+                                    TEXT("ProcessorNameString"),
+                                    NULL,
+                                    NULL,
+                                    (LPBYTE)&description[0], // NOLINT
+                                    &cpu_brand_size) == ERROR_SUCCESS) {
+                    if (description.find('\0') != std::string::npos) {
+                        description.resize(description.find('\0'));
+                    }
+                }
+            }
+            RegCloseKey(hKey);
+        }
+#endif
+    }
+};
+
+static const char * ggml_backend_cpu_device_get_name(ggml_backend_dev_t dev) {
+    return "CPU";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_cpu_device_get_description(ggml_backend_dev_t dev) {
+    struct ggml_backend_cpu_device_context * ctx = (struct ggml_backend_cpu_device_context *)dev->context;
+
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_cpu_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_cpu_device_get_name(dev);
+    props->description = ggml_backend_cpu_device_get_description(dev);
+    props->type        = ggml_backend_cpu_device_get_type(dev);
+    ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_cpu_device_init_backend(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_cpu_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cpu_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_cpu_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    switch (op->op) {
+        case GGML_OP_CPY:
+            return
+                op->type != GGML_TYPE_IQ2_XXS &&
+                op->type != GGML_TYPE_IQ2_XS  &&
+                op->type != GGML_TYPE_IQ1_S   &&
+                op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
+        case GGML_OP_MUL_MAT:
+            return op->src[1]->type == GGML_TYPE_F32;// FIXME || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
+        case GGML_OP_ROPE_BACK:
+            return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
+        case GGML_OP_IM2COL_BACK:
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
+        case GGML_OP_OUT_PROD:
+            return (op->src[0]->type == GGML_TYPE_F32 || ggml_is_quantized(op->src[0]->type)) && op->src[1]->type == GGML_TYPE_F32;
+        default:
+            return true;
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_cpu_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_cpu_device_i = {
+    /* .get_name             = */ ggml_backend_cpu_device_get_name,
+    /* .get_description      = */ ggml_backend_cpu_device_get_description,
+    /* .get_memory           = */ ggml_backend_cpu_device_get_memory,
+    /* .get_type             = */ ggml_backend_cpu_device_get_type,
+    /* .get_props            = */ ggml_backend_cpu_device_get_props,
+    /* .init_backend         = */ ggml_backend_cpu_device_init_backend,
+    /* .get_buffer_type      = */ ggml_backend_cpu_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_cpu_device_buffer_from_host_ptr,
+    /* .supports_op          = */ ggml_backend_cpu_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_cpu_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// CPU backend - backend (reg)
+
+static const char * ggml_backend_cpu_reg_get_name(ggml_backend_reg_t reg) {
+    return "CPU";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_cpu_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_cpu_device_context ctx;
+    static ggml_backend_device ggml_backend_cpu_device = {
+        /* .iface   = */ ggml_backend_cpu_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ &ctx,
+    };
+
+    return &ggml_backend_cpu_device;
+}
+
+struct ggml_backend_feature {
+    const char * name;
+    const char * value;
+};
+
+// Not used yet
+// This is intended to replace the the ggml_cpu_has_* functions when loading the CPU backend dynamically,
+// and additionally to allow other backends to expose their own list of features that applications can query using the same API.
+static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t reg) {
+    static std::vector<ggml_backend_feature> features = []() {
+        std::vector<ggml_backend_feature> features;
+        if (ggml_cpu_has_sse3()) {
+            features.push_back({ "SSE3", "1" });
+        }
+        if (ggml_cpu_has_ssse3()) {
+            features.push_back({ "SSSE3", "1" });
+        }
+        if (ggml_cpu_has_avx()) {
+            features.push_back({ "AVX", "1" });
+        }
+        if (ggml_cpu_has_avx2()) {
+            features.push_back({ "AVX2", "1" });
+        }
+        if (ggml_cpu_has_f16c()) {
+            features.push_back({ "F16C", "1" });
+        }
+        if (ggml_cpu_has_fma()) {
+            features.push_back({ "FMA", "1" });
+        }
+        if (ggml_cpu_has_avx_vnni()) {
+            features.push_back({ "AVX_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx512()) {
+            features.push_back({ "AVX512", "1" });
+        }
+        if (ggml_cpu_has_avx512_vbmi()) {
+            features.push_back({ "AVX512_VBMI", "1" });
+        }
+        if (ggml_cpu_has_avx512_vnni()) {
+            features.push_back({ "AVX512_VNNI", "1" });
+        }
+        if (ggml_cpu_has_avx512_bf16()) {
+            features.push_back({ "AVX512_BF16", "1" });
+        }
+        if (ggml_cpu_has_amx_int8()) {
+            features.push_back({ "AMX_INT8", "1" });
+        }
+        if (ggml_cpu_has_neon()) {
+            features.push_back({ "NEON", "1" });
+        }
+        if (ggml_cpu_has_arm_fma()) {
+            features.push_back({ "ARM_FMA", "1" });
+        }
+        if (ggml_cpu_has_fp16_va()) {
+            features.push_back({ "FP16_VA", "1" });
+        }
+        if (ggml_cpu_has_matmul_int8()) {
+            features.push_back({ "MATMUL_INT8", "1" });
+        }
+        if (ggml_cpu_has_sve()) {
+            features.push_back({ "SVE", "1" });
+        }
+        if (ggml_cpu_get_sve_cnt() > 0) {
+            static std::string sve_cnt = std::to_string(ggml_cpu_get_sve_cnt());
+            features.push_back({ "SVE_CNT", sve_cnt.c_str() });
+        }
+        if (ggml_cpu_has_riscv_v()) {
+            features.push_back({ "RISCV_V", "1" });
+        }
+        if (ggml_cpu_has_vsx()) {
+            features.push_back({ "VSX", "1" });
+        }
+        if (ggml_cpu_has_wasm_simd()) {
+            features.push_back({ "WASM_SIMD", "1" });
+        }
+        if (ggml_cpu_has_llamafile()) {
+            features.push_back({ "LLAMAFILE", "1" });
+        }
+
+        features.push_back({ nullptr, nullptr });
+
+        return features;
+    }();
+
+    return features.data();
+
+    GGML_UNUSED(reg);
+}
+
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_cpu_set_n_threads;
+    }
+    if (strcmp(name, "ggml_backend_dev_get_extra_bufts") == 0) {
+        return (void *)ggml_backend_cpu_get_extra_bufts;
+    }
+
+    return NULL;
+
+    GGML_UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
+    /* .get_name         = */ ggml_backend_cpu_reg_get_name,
+    /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_cpu_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_cpu_reg(void) {
+    static struct ggml_backend_reg ggml_backend_cpu_reg = {
+        /* .iface   = */ ggml_backend_cpu_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_cpu_reg;
+}
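For context, a minimal sketch of driving this backend through the public API, assuming a build with this patch applied; error handling and tensor data initialization are trimmed:

    #include "ggml.h"
    #include "ggml-backend.h"
    #include "ggml-cpu.h"

    int main() {
        ggml_backend_t backend = ggml_backend_cpu_init();
        ggml_backend_cpu_set_n_threads(backend, 8);

        // build a tiny graph: c = a * b
        ggml_init_params params = { /*mem_size=*/ 16u*1024*1024, /*mem_buffer=*/ NULL, /*no_alloc=*/ false };
        ggml_context * ctx = ggml_init(params);
        ggml_tensor  * a   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
        ggml_tensor  * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
        ggml_tensor  * c   = ggml_mul_mat(ctx, a, b);

        ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        // dispatches to ggml_backend_cpu_graph_compute() above, which sizes
        // the shared work buffer via ggml_graph_plan() before running
        ggml_backend_graph_compute(backend, gf);

        ggml_free(ctx);
        ggml_backend_free(backend);
        return 0;
    }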
ggml/src/ggml-cpu/llamafile/sgemm.cpp
ADDED
@@ -0,0 +1,1884 @@
| 1 |
+
// Copyright 2024 Mozilla Foundation
|
| 2 |
+
//
|
| 3 |
+
// Permission is hereby granted, free of charge, to any person obtaining
|
| 4 |
+
// a copy of this software and associated documentation files (the
|
| 5 |
+
// "Software"), to deal in the Software without restriction, including
|
| 6 |
+
// without limitation the rights to use, copy, modify, merge, publish,
|
| 7 |
+
// distribute, sublicense, and/or sell copies of the Software, and to
|
| 8 |
+
// permit persons to whom the Software is furnished to do so, subject to
|
| 9 |
+
// the following conditions:
|
| 10 |
+
//
|
| 11 |
+
// The above copyright notice and this permission notice shall be
|
| 12 |
+
// included in all copies or substantial portions of the Software.
|
| 13 |
+
//
|
| 14 |
+
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
| 15 |
+
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
| 16 |
+
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
| 17 |
+
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
|
| 18 |
+
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
|
| 19 |
+
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
| 20 |
+
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
// SOFTWARE.
|
| 22 |
+
|
| 23 |
+
//
|
| 24 |
+
// _ _ ___ _ _ ___
|
| 25 |
+
// | |_(_)_ _ _ _| _ ) | /_\ / __|
|
| 26 |
+
// | _| | ' \ || | _ \ |__ / _ \\__ \.
|
| 27 |
+
// \__|_|_||_\_, |___/____/_/ \_\___/
|
| 28 |
+
// |__/
|
| 29 |
+
//
|
| 30 |
+
// BASIC LINEAR ALGEBRA SUBPROGRAMS
|
| 31 |
+
//
|
| 32 |
+
//
|
| 33 |
+
// This file implements multithreaded CPU matrix multiplication for the
|
| 34 |
+
// common contiguous use case C = Aᵀ * B. These kernels are designed to
|
| 35 |
+
// have excellent performance[1] for matrices that fit in the CPU cache
|
| 36 |
+
// without imposing any overhead such as cache filling or malloc calls.
|
| 37 |
+
//
|
| 38 |
+
// This implementation does not guarantee any upper bound with rounding
|
| 39 |
+
// errors, which grow along with k. Our goal's to maximally exploit the
|
| 40 |
+
// hardware for performance, and then use whatever resources remain for
|
| 41 |
+
// improving numerical accuracy.
|
| 42 |
+
//
|
| 43 |
+
// [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online].
|
| 44 |
+
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
|
| 45 |
+
|
| 46 |
+
#if defined(__GNUC__)
|
| 47 |
+
#pragma GCC diagnostic ignored "-Wpedantic"
|
| 48 |
+
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
| 49 |
+
#endif
|
| 50 |
+
|
| 51 |
+
#include "sgemm.h"
|
| 52 |
+
#include "ggml-impl.h"
|
| 53 |
+
#include "ggml-cpu-impl.h"
|
| 54 |
+
#include "ggml-quants.h"
|
| 55 |
+
|
| 56 |
+
#ifdef _MSC_VER
|
| 57 |
+
#define NOINLINE __declspec(noinline)
|
| 58 |
+
#else
|
| 59 |
+
#define NOINLINE __attribute__((__noinline__))
|
| 60 |
+
#endif
|
| 61 |
+
|
| 62 |
+
#if defined(__ARM_NEON) || defined(__AVX512F__)
|
| 63 |
+
#define VECTOR_REGISTERS 32
|
| 64 |
+
#else
|
| 65 |
+
#define VECTOR_REGISTERS 16
|
| 66 |
+
#endif
|
| 67 |
+
|
| 68 |
+
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
| 69 |
+
|
| 70 |
+
namespace {
|
| 71 |
+
|
| 72 |
+
inline float unhalf(ggml_fp16_t d) {
|
| 73 |
+
return GGML_FP16_TO_FP32(d);
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 77 |
+
// VECTORIZED ARITHMETIC OPERATIONS
|
| 78 |
+
|
| 79 |
+
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 80 |
+
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
|
| 81 |
+
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
|
| 82 |
+
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
|
| 83 |
+
#endif // __SSE__
|
| 84 |
+
|
| 85 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 86 |
+
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
|
| 87 |
+
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
|
| 88 |
+
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
|
| 89 |
+
#endif // __AVX__
|
| 90 |
+
|
| 91 |
+
#if defined(__AVX512F__)
|
| 92 |
+
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
|
| 93 |
+
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
|
| 94 |
+
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
|
| 95 |
+
#endif // __AVX512F__
|
| 96 |
+
|
| 97 |
+
#if defined(__ARM_NEON)
|
| 98 |
+
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
|
| 99 |
+
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
|
| 100 |
+
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
|
| 101 |
+
#endif // __ARM_NEON
|
| 102 |
+
|
| 103 |
+
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
|
| 104 |
+
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
|
| 105 |
+
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
| 106 |
+
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
| 107 |
+
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 108 |
+
|
| 109 |
+
#if defined(__MMA__)
|
| 110 |
+
typedef vector unsigned char vec_t;
|
| 111 |
+
typedef __vector_quad acc_t;
|
| 112 |
+
#endif
|
| 113 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 114 |
+
// VECTORIZED FUSED MULTIPLY ADD
|
| 115 |
+
|
| 116 |
+
/**
|
| 117 |
+
* Computes a * b + c.
|
| 118 |
+
*/
|
| 119 |
+
template <typename T, typename U>
|
| 120 |
+
inline U madd(T a, T b, U c) {
|
| 121 |
+
return add(mul(a, b), c);
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
#if defined(__FMA__)
|
| 125 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 126 |
+
template <>
|
| 127 |
+
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
|
| 128 |
+
return _mm256_fmadd_ps(a, b, c);
|
| 129 |
+
}
|
| 130 |
+
#endif
|
| 131 |
+
#if defined(__AVX512F__)
|
| 132 |
+
template <>
|
| 133 |
+
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
|
| 134 |
+
return _mm512_fmadd_ps(a, b, c);
|
| 135 |
+
}
|
| 136 |
+
#endif
|
| 137 |
+
#endif
|
| 138 |
+
|
| 139 |
+
#if defined(__ARM_FEATURE_FMA)
|
| 140 |
+
template <>
|
| 141 |
+
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
|
| 142 |
+
return vfmaq_f32(c, b, a);
|
| 143 |
+
}
|
| 144 |
+
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
| 145 |
+
template <>
|
| 146 |
+
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
|
| 147 |
+
return vfmaq_f16(c, b, a);
|
| 148 |
+
}
|
| 149 |
+
#endif
|
| 150 |
+
#endif
|
| 151 |
+
|
| 152 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 153 |
+
// VECTORIZED HORIZONTAL SUM
|
| 154 |
+
|
| 155 |
+
#if defined(__ARM_NEON)
|
| 156 |
+
inline float hsum(float32x4_t x) {
|
| 157 |
+
return vaddvq_f32(x);
|
| 158 |
+
}
|
| 159 |
+
#endif // __ARM_NEON
|
| 160 |
+
|
| 161 |
+
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
|
| 162 |
+
inline float hsum(float16x8_t x) {
|
| 163 |
+
return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
|
| 164 |
+
vcvt_f32_f16(vget_high_f16(x))));
|
| 165 |
+
}
|
| 166 |
+
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
| 167 |
+
|
| 168 |
+
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 169 |
+
inline float hsum(__m128 x) {
|
| 170 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 171 |
+
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
|
| 172 |
+
x = _mm_add_ss(x, _mm_movehdup_ps(x));
|
| 173 |
+
#else
|
| 174 |
+
__m128 t;
|
| 175 |
+
t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
|
| 176 |
+
x = _mm_add_ps(x, t);
|
| 177 |
+
t = _mm_movehl_ps(t, x);
|
| 178 |
+
x = _mm_add_ss(x, t);
|
| 179 |
+
#endif
|
| 180 |
+
return _mm_cvtss_f32(x);
|
| 181 |
+
}
|
| 182 |
+
#endif
|
| 183 |
+
|
| 184 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 185 |
+
inline float hsum(__m256 x) {
|
| 186 |
+
return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
|
| 187 |
+
_mm256_castps256_ps128(x)));
|
| 188 |
+
}
|
| 189 |
+
#endif // __AVX__
|
| 190 |
+
|
| 191 |
+
#if defined(__AVX512F__)
|
| 192 |
+
inline float hsum(__m512 x) {
|
| 193 |
+
return _mm512_reduce_add_ps(x);
|
| 194 |
+
}
|
| 195 |
+
#endif // __AVX512F__
|
| 196 |
+
|
| 197 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 198 |
+
// VECTORIZED MEMORY LOADING
|
| 199 |
+
|
| 200 |
+
template <typename T, typename U> T load(const U *);
|
| 201 |
+
|
| 202 |
+
#if defined(__ARM_NEON)
|
| 203 |
+
template <> inline float32x4_t load(const float *p) {
|
| 204 |
+
return vld1q_f32(p);
|
| 205 |
+
}
|
| 206 |
+
#if !defined(_MSC_VER)
|
| 207 |
+
template <> inline float16x8_t load(const ggml_fp16_t *p) {
|
| 208 |
+
return vld1q_f16((const float16_t *)p);
|
| 209 |
+
}
|
| 210 |
+
template <> inline float32x4_t load(const ggml_fp16_t *p) {
|
| 211 |
+
return vcvt_f32_f16(vld1_f16((const float16_t *)p));
|
| 212 |
+
}
|
| 213 |
+
#endif // _MSC_VER
|
| 214 |
+
#endif // __ARM_NEON
|
| 215 |
+
|
| 216 |
+
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 217 |
+
template <> inline __m128 load(const float *p) {
|
| 218 |
+
return _mm_loadu_ps(p);
|
| 219 |
+
}
|
| 220 |
+
#endif // __SSE__
|
| 221 |
+
|
| 222 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 223 |
+
template <> inline __m256 load(const float *p) {
|
| 224 |
+
return _mm256_loadu_ps(p);
|
| 225 |
+
}
|
| 226 |
+
#endif // __AVX__
|
| 227 |
+
|
| 228 |
+
#if defined(__F16C__)
|
| 229 |
+
template <> inline __m256 load(const ggml_fp16_t *p) {
|
| 230 |
+
return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
|
| 231 |
+
}
|
| 232 |
+
#endif // __F16C__
|
| 233 |
+
|
| 234 |
+
#if defined(__AVX512F__)
|
| 235 |
+
template <> inline __m512 load(const float *p) {
|
| 236 |
+
return _mm512_loadu_ps(p);
|
| 237 |
+
}
|
| 238 |
+
template <> inline __m512 load(const ggml_fp16_t *p) {
|
| 239 |
+
return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
|
| 240 |
+
}
|
| 241 |
+
#endif // __AVX512F__
|
| 242 |
+
|
| 243 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 244 |
+
// CONSTANTS
|
| 245 |
+
|
| 246 |
+
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
|
| 247 |
+
static const int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
|
| 248 |
+
static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
| 249 |
+
#endif
|
| 250 |
+
|
| 251 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 252 |
+
// FLOATING POINT MATRIX MULTIPLICATION
|
| 253 |
+
|
| 254 |
+
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
|
| 255 |
+
class tinyBLAS {
|
| 256 |
+
public:
|
| 257 |
+
tinyBLAS(int64_t k,
|
| 258 |
+
const TA *A, int64_t lda,
|
| 259 |
+
const TB *B, int64_t ldb,
|
| 260 |
+
TC *C, int64_t ldc,
|
| 261 |
+
int ith, int nth)
|
| 262 |
+
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
void matmul(int64_t m, int64_t n) {
|
| 266 |
+
mnpack(0, m, 0, n);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
private:
|
| 270 |
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 271 |
+
int64_t mc, nc, mp, np;
|
| 272 |
+
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
|
| 273 |
+
#if VECTOR_REGISTERS == 32
|
| 274 |
+
case 0x55:
|
| 275 |
+
mc = 5;
|
| 276 |
+
nc = 5;
|
| 277 |
+
gemm<5, 5>(m0, m, n0, n);
|
| 278 |
+
break;
|
| 279 |
+
case 0x45:
|
| 280 |
+
mc = 4;
|
| 281 |
+
nc = 5;
|
| 282 |
+
gemm<4, 5>(m0, m, n0, n);
|
| 283 |
+
break;
|
| 284 |
+
case 0x54:
|
| 285 |
+
mc = 5;
|
| 286 |
+
nc = 4;
|
| 287 |
+
gemm<5, 4>(m0, m, n0, n);
|
| 288 |
+
break;
|
| 289 |
+
case 0x44:
|
| 290 |
+
mc = 4;
|
| 291 |
+
nc = 4;
|
| 292 |
+
gemm<4, 4>(m0, m, n0, n);
|
| 293 |
+
break;
|
| 294 |
+
case 0x53:
|
| 295 |
+
mc = 5;
|
| 296 |
+
nc = 3;
|
| 297 |
+
gemm<5, 3>(m0, m, n0, n);
|
| 298 |
+
break;
|
| 299 |
+
case 0x35:
|
| 300 |
+
mc = 3;
|
| 301 |
+
nc = 5;
|
| 302 |
+
gemm<3, 5>(m0, m, n0, n);
|
| 303 |
+
break;
|
| 304 |
+
case 0x43:
|
| 305 |
+
mc = 4;
|
| 306 |
+
nc = 3;
|
| 307 |
+
gemm<4, 3>(m0, m, n0, n);
|
| 308 |
+
break;
|
| 309 |
+
#else
|
| 310 |
+
case 0x55:
|
| 311 |
+
case 0x54:
|
| 312 |
+
case 0x53:
|
| 313 |
+
case 0x45:
|
| 314 |
+
case 0x44:
|
| 315 |
+
case 0x43:
|
| 316 |
+
mc = 4;
|
| 317 |
+
nc = 3;
|
| 318 |
+
gemm<4, 3>(m0, m, n0, n);
|
| 319 |
+
break;
|
| 320 |
+
case 0x35:
|
| 321 |
+
#endif
|
| 322 |
+
case 0x34:
|
| 323 |
+
mc = 3;
|
| 324 |
+
nc = 4;
|
| 325 |
+
gemm<3, 4>(m0, m, n0, n);
|
| 326 |
+
break;
|
| 327 |
+
case 0x52:
|
| 328 |
+
mc = 5;
|
| 329 |
+
nc = 2;
|
| 330 |
+
gemm<5, 2>(m0, m, n0, n);
|
| 331 |
+
break;
|
| 332 |
+
case 0x33:
|
| 333 |
+
mc = 3;
|
| 334 |
+
nc = 3;
|
| 335 |
+
gemm<3, 3>(m0, m, n0, n);
|
| 336 |
+
break;
|
| 337 |
+
case 0x25:
|
| 338 |
+
mc = 2;
|
| 339 |
+
nc = 5;
|
| 340 |
+
gemm<2, 5>(m0, m, n0, n);
|
| 341 |
+
break;
|
| 342 |
+
case 0x42:
|
| 343 |
+
mc = 4;
|
| 344 |
+
nc = 2;
|
| 345 |
+
gemm<4, 2>(m0, m, n0, n);
|
| 346 |
+
break;
|
| 347 |
+
case 0x24:
|
| 348 |
+
mc = 2;
|
| 349 |
+
nc = 4;
|
| 350 |
+
gemm<2, 4>(m0, m, n0, n);
|
| 351 |
+
break;
|
| 352 |
+
case 0x32:
|
| 353 |
+
mc = 3;
|
| 354 |
+
nc = 2;
|
| 355 |
+
gemm<3, 2>(m0, m, n0, n);
|
| 356 |
+
break;
|
| 357 |
+
case 0x23:
|
| 358 |
+
mc = 2;
|
| 359 |
+
nc = 3;
|
| 360 |
+
gemm<2, 3>(m0, m, n0, n);
|
| 361 |
+
break;
|
| 362 |
+
case 0x51:
|
| 363 |
+
mc = 5;
|
| 364 |
+
nc = 1;
|
| 365 |
+
gemm<5, 1>(m0, m, n0, n);
|
| 366 |
+
break;
|
| 367 |
+
case 0x41:
|
| 368 |
+
mc = 4;
|
| 369 |
+
nc = 1;
|
| 370 |
+
gemm<4, 1>(m0, m, n0, n);
|
| 371 |
+
break;
|
| 372 |
+
case 0x22:
|
| 373 |
+
mc = 2;
|
| 374 |
+
nc = 2;
|
| 375 |
+
gemm<2, 2>(m0, m, n0, n);
|
| 376 |
+
break;
|
| 377 |
+
case 0x15:
|
| 378 |
+
mc = 1;
|
| 379 |
+
nc = 5;
|
| 380 |
+
gemm<1, 5>(m0, m, n0, n);
|
| 381 |
+
break;
|
| 382 |
+
case 0x14:
|
| 383 |
+
mc = 1;
|
| 384 |
+
nc = 4;
|
| 385 |
+
gemm<1, 4>(m0, m, n0, n);
|
| 386 |
+
break;
|
| 387 |
+
case 0x31:
|
| 388 |
+
mc = 3;
|
| 389 |
+
nc = 1;
|
| 390 |
+
gemm<3, 1>(m0, m, n0, n);
|
| 391 |
+
break;
|
| 392 |
+
case 0x13:
|
| 393 |
+
mc = 1;
|
| 394 |
+
nc = 3;
|
| 395 |
+
gemm<1, 3>(m0, m, n0, n);
|
| 396 |
+
break;
|
| 397 |
+
case 0x21:
|
| 398 |
+
mc = 2;
|
| 399 |
+
nc = 1;
|
| 400 |
+
gemm<2, 1>(m0, m, n0, n);
|
| 401 |
+
break;
|
| 402 |
+
case 0x12:
|
| 403 |
+
mc = 1;
|
| 404 |
+
nc = 2;
|
| 405 |
+
gemm<1, 2>(m0, m, n0, n);
|
| 406 |
+
break;
|
| 407 |
+
case 0x11:
|
| 408 |
+
mc = 1;
|
| 409 |
+
nc = 1;
|
| 410 |
+
gemm<1, 1>(m0, m, n0, n);
|
| 411 |
+
break;
|
| 412 |
+
default:
|
| 413 |
+
return;
|
| 414 |
+
}
|
| 415 |
+
mp = m0 + (m - m0) / mc * mc;
|
| 416 |
+
np = n0 + (n - n0) / nc * nc;
|
| 417 |
+
mnpack(mp, m, n0, np);
|
| 418 |
+
mnpack(m0, m, np, n);
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
template <int RM, int RN>
|
| 422 |
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 423 |
+
int64_t ytiles = (m - m0) / RM;
|
| 424 |
+
int64_t xtiles = (n - n0) / RN;
|
| 425 |
+
int64_t tiles = xtiles * ytiles;
|
| 426 |
+
int64_t duty = (tiles + nth - 1) / nth;
|
| 427 |
+
int64_t start = duty * ith;
|
| 428 |
+
int64_t end = start + duty;
|
| 429 |
+
if (end > tiles)
|
| 430 |
+
end = tiles;
|
| 431 |
+
for (int64_t job = start; job < end; ++job) {
|
| 432 |
+
int64_t ii = m0 + job / xtiles * RM;
|
| 433 |
+
int64_t jj = n0 + job % xtiles * RN;
|
| 434 |
+
D Cv[RN][RM] = {};
|
| 435 |
+
for (int64_t l = 0; l < k; l += KN)
|
| 436 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 437 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 438 |
+
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
|
| 439 |
+
load<V>(B + ldb * (jj + j) + l),
|
| 440 |
+
Cv[j][i]);
|
| 441 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 442 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 443 |
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
const TA *const A;
|
| 448 |
+
const TB *const B;
|
| 449 |
+
TC *const C;
|
| 450 |
+
const int64_t k;
|
| 451 |
+
const int64_t lda;
|
| 452 |
+
const int64_t ldb;
|
| 453 |
+
const int64_t ldc;
|
| 454 |
+
const int ith;
|
| 455 |
+
const int nth;
|
| 456 |
+
};
|
| 457 |
+
|
| 458 |
+
//////////////////////////////////////////////////////////////////////////////////////////
|
| 459 |
+
// QUANT ZERO MATRIX MULTIPLICATION
|
| 460 |
+
|
| 461 |
+
#if defined(__ARM_FEATURE_DOTPROD)
|
| 462 |
+
template <typename TA>
|
| 463 |
+
class tinyBLAS_Q0_ARM {
|
| 464 |
+
public:
|
| 465 |
+
tinyBLAS_Q0_ARM(int64_t k,
|
| 466 |
+
const TA *A, int64_t lda,
|
| 467 |
+
const block_q8_0 *B, int64_t ldb,
|
| 468 |
+
float *C, int64_t ldc,
|
| 469 |
+
int ith, int nth)
|
| 470 |
+
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
void matmul(int64_t m, int64_t n) {
|
| 474 |
+
mnpack(0, m, 0, n);
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
private:
|
| 478 |
+
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 479 |
+
int64_t mc, nc, mp, np;
|
| 480 |
+
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
|
| 481 |
+
case 0x33:
|
| 482 |
+
mc = 3;
|
| 483 |
+
nc = 3;
|
| 484 |
+
gemm<3, 3>(m0, m, n0, n);
|
| 485 |
+
break;
|
| 486 |
+
case 0x32:
|
| 487 |
+
mc = 3;
|
| 488 |
+
nc = 2;
|
| 489 |
+
gemm<3, 2>(m0, m, n0, n);
|
| 490 |
+
break;
|
| 491 |
+
case 0x23:
|
| 492 |
+
mc = 2;
|
| 493 |
+
nc = 3;
|
| 494 |
+
gemm<2, 3>(m0, m, n0, n);
|
| 495 |
+
break;
|
| 496 |
+
case 0x22:
|
| 497 |
+
mc = 2;
|
| 498 |
+
nc = 2;
|
| 499 |
+
gemm<2, 2>(m0, m, n0, n);
|
| 500 |
+
break;
|
| 501 |
+
case 0x31:
|
| 502 |
+
mc = 3;
|
| 503 |
+
nc = 1;
|
| 504 |
+
gemm<3, 1>(m0, m, n0, n);
|
| 505 |
+
break;
|
| 506 |
+
case 0x13:
|
| 507 |
+
mc = 1;
|
| 508 |
+
nc = 3;
|
| 509 |
+
gemm<1, 3>(m0, m, n0, n);
|
| 510 |
+
break;
|
| 511 |
+
case 0x21:
|
| 512 |
+
mc = 2;
|
| 513 |
+
nc = 1;
|
| 514 |
+
gemm<2, 1>(m0, m, n0, n);
|
| 515 |
+
break;
|
| 516 |
+
case 0x12:
|
| 517 |
+
mc = 1;
|
| 518 |
+
nc = 2;
|
| 519 |
+
gemm<1, 2>(m0, m, n0, n);
|
| 520 |
+
break;
|
| 521 |
+
case 0x11:
|
| 522 |
+
mc = 1;
|
| 523 |
+
nc = 1;
|
| 524 |
+
gemm<1, 1>(m0, m, n0, n);
|
| 525 |
+
break;
|
| 526 |
+
default:
|
| 527 |
+
return;
|
| 528 |
+
}
|
| 529 |
+
mp = m0 + (m - m0) / mc * mc;
|
| 530 |
+
np = n0 + (n - n0) / nc * nc;
|
| 531 |
+
mnpack(mp, m, n0, np);
|
| 532 |
+
mnpack(m0, m, np, n);
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
template <int RM, int RN>
|
| 536 |
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 537 |
+
int64_t ytiles = (m - m0) / RM;
|
| 538 |
+
int64_t xtiles = (n - n0) / RN;
|
| 539 |
+
int64_t tiles = xtiles * ytiles;
|
| 540 |
+
int64_t duty = (tiles + nth - 1) / nth;
|
| 541 |
+
int64_t start = duty * ith;
|
| 542 |
+
int64_t end = start + duty;
|
| 543 |
+
if (end > tiles)
|
| 544 |
+
end = tiles;
|
| 545 |
+
for (int64_t job = start; job < end; ++job) {
|
| 546 |
+
int64_t ii = m0 + job / xtiles * RM;
|
| 547 |
+
int64_t jj = n0 + job % xtiles * RN;
|
| 548 |
+
float32x4_t Cv[RN][RM] = {};
|
| 549 |
+
for (int64_t l = 0; l < k; ++l)
|
| 550 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 551 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 552 |
+
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
|
| 553 |
+
vcvtq_f32_s32(vdotq_s32(
|
| 554 |
+
vdotq_s32(vdupq_n_s32(0),
|
| 555 |
+
load_lo(A + lda * (ii + i) + l),
|
| 556 |
+
load_lo(B + ldb * (jj + j) + l)),
|
| 557 |
+
load_hi(A + lda * (ii + i) + l),
|
| 558 |
+
load_hi(B + ldb * (jj + j) + l))),
|
| 559 |
+
unhalf(A[lda * (ii + i) + l].d) *
|
| 560 |
+
unhalf(B[ldb * (jj + j) + l].d));
|
| 561 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 562 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 563 |
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
| 564 |
+
}
|
| 565 |
+
}
|
| 566 |
+
|
| 567 |
+
inline int8x16_t load_lo(const block_q8_0 *b) {
|
| 568 |
+
return vld1q_s8(b->qs);
|
| 569 |
+
}
|
| 570 |
+
|
| 571 |
+
inline int8x16_t load_hi(const block_q8_0 *b) {
|
| 572 |
+
return vld1q_s8(b->qs + 16);
|
| 573 |
+
}
|
| 574 |
+
|
| 575 |
+
inline int8x16_t load_lo(const block_q4_0 *b) {
|
| 576 |
+
return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
|
| 577 |
+
vdupq_n_u8(0x0f))),
|
| 578 |
+
vdupq_n_s8(0x8));
|
| 579 |
+
}
|
| 580 |
+
|
| 581 |
+
inline int8x16_t load_hi(const block_q4_0 *b) {
|
| 582 |
+
return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
|
| 583 |
+
vdupq_n_s8(0x8));
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
const TA *const A;
|
| 587 |
+
const block_q8_0 *const B;
|
| 588 |
+
float *const C;
|
| 589 |
+
const int64_t k;
|
| 590 |
+
const int64_t lda;
|
| 591 |
+
const int64_t ldb;
|
| 592 |
+
const int64_t ldc;
|
| 593 |
+
const int ith;
|
| 594 |
+
const int nth;
|
| 595 |
+
};
|
| 596 |
+
#endif // __ARM_FEATURE_DOTPROD
|
| 597 |
+
|
| 598 |
+
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
|
| 599 |
+
template <typename TA, typename TB, typename TC>
|
| 600 |
+
class tinyBLAS_Q0_AVX {
|
| 601 |
+
public:
|
| 602 |
+
tinyBLAS_Q0_AVX(int64_t k,
|
| 603 |
+
const TA *A, int64_t lda,
|
| 604 |
+
const TB *B, int64_t ldb,
|
| 605 |
+
TC *C, int64_t ldc,
|
| 606 |
+
int ith, int nth)
|
| 607 |
+
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
| 608 |
+
}
|
| 609 |
+
|
| 610 |
+
void matmul(int64_t m, int64_t n) {
|
| 611 |
+
mnpack(0, m, 0, n);
|
| 612 |
+
}
|
| 613 |
+
|
| 614 |
+
private:
|
| 615 |
+
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 616 |
+
int64_t mc, nc, mp, np;
|
| 617 |
+
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
|
| 618 |
+
#if VECTOR_REGISTERS == 32
|
| 619 |
+
case 0x44:
|
| 620 |
+
mc = 4;
|
| 621 |
+
nc = 4;
|
| 622 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 623 |
+
gemm4xN<4>(m0, m, n0, n);
|
| 624 |
+
#else
|
| 625 |
+
gemm<4, 4>(m0, m, n0, n);
|
| 626 |
+
#endif
|
| 627 |
+
break;
|
| 628 |
+
case 0x43:
|
| 629 |
+
mc = 4;
|
| 630 |
+
nc = 3;
|
| 631 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 632 |
+
gemm4xN<3>(m0, m, n0, n);
|
| 633 |
+
#else
|
| 634 |
+
gemm<4, 3>(m0, m, n0, n);
|
| 635 |
+
#endif
|
| 636 |
+
break;
|
| 637 |
+
case 0x34:
|
| 638 |
+
mc = 3;
|
| 639 |
+
nc = 4;
|
| 640 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 641 |
+
gemmMx4<3>(m0, m, n0, n);
|
| 642 |
+
#else
|
| 643 |
+
gemm<3, 4>(m0, m, n0, n);
|
| 644 |
+
#endif
|
| 645 |
+
break;
|
| 646 |
+
case 0x33:
|
| 647 |
+
mc = 3;
|
| 648 |
+
nc = 3;
|
| 649 |
+
gemm<3, 3>(m0, m, n0, n);
|
| 650 |
+
break;
|
| 651 |
+
case 0x42:
|
| 652 |
+
mc = 4;
|
| 653 |
+
nc = 2;
|
| 654 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 655 |
+
gemm4xN<2>(m0, m, n0, n);
|
| 656 |
+
#else
|
| 657 |
+
gemm<4, 2>(m0, m, n0, n);
|
| 658 |
+
#endif
|
| 659 |
+
break;
|
| 660 |
+
case 0x24:
|
| 661 |
+
mc = 2;
|
| 662 |
+
nc = 4;
|
| 663 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 664 |
+
gemmMx4<2>(m0, m, n0, n);
|
| 665 |
+
#else
|
| 666 |
+
gemm<2, 4>(m0, m, n0, n);
|
| 667 |
+
#endif
|
| 668 |
+
break;
|
| 669 |
+
#else
|
| 670 |
+
case 0x44:
|
| 671 |
+
case 0x43:
|
| 672 |
+
case 0x42:
|
| 673 |
+
mc = 4;
|
| 674 |
+
nc = 2;
|
| 675 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 676 |
+
gemm4xN<2>(m0, m, n0, n);
|
| 677 |
+
#else
|
| 678 |
+
gemm<4, 2>(m0, m, n0, n);
|
| 679 |
+
#endif
|
| 680 |
+
break;
|
| 681 |
+
case 0x34:
|
| 682 |
+
case 0x24:
|
| 683 |
+
mc = 2;
|
| 684 |
+
nc = 4;
|
| 685 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 686 |
+
gemmMx4<2>(m0, m, n0, n);
|
| 687 |
+
#else
|
| 688 |
+
gemm<2, 4>(m0, m, n0, n);
|
| 689 |
+
#endif
|
| 690 |
+
break;
|
| 691 |
+
case 0x33:
|
| 692 |
+
#endif
|
| 693 |
+
case 0x32:
|
| 694 |
+
mc = 3;
|
| 695 |
+
nc = 2;
|
| 696 |
+
gemm<3, 2>(m0, m, n0, n);
|
| 697 |
+
break;
|
| 698 |
+
case 0x23:
|
| 699 |
+
mc = 2;
|
| 700 |
+
nc = 3;
|
| 701 |
+
gemm<2, 3>(m0, m, n0, n);
|
| 702 |
+
break;
|
| 703 |
+
case 0x41:
|
| 704 |
+
mc = 4;
|
| 705 |
+
nc = 1;
|
| 706 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 707 |
+
gemm4xN<1>(m0, m, n0, n);
|
| 708 |
+
#else
|
| 709 |
+
gemm<4, 1>(m0, m, n0, n);
|
| 710 |
+
#endif
|
| 711 |
+
break;
|
| 712 |
+
case 0x22:
|
| 713 |
+
mc = 2;
|
| 714 |
+
nc = 2;
|
| 715 |
+
gemm<2, 2>(m0, m, n0, n);
|
| 716 |
+
break;
|
| 717 |
+
case 0x14:
|
| 718 |
+
mc = 1;
|
| 719 |
+
nc = 4;
|
| 720 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 721 |
+
gemmMx4<1>(m0, m, n0, n);
|
| 722 |
+
#else
|
| 723 |
+
gemm<1, 4>(m0, m, n0, n);
|
| 724 |
+
#endif
|
| 725 |
+
break;
|
| 726 |
+
case 0x31:
|
| 727 |
+
mc = 3;
|
| 728 |
+
nc = 1;
|
| 729 |
+
gemm<3, 1>(m0, m, n0, n);
|
| 730 |
+
break;
|
| 731 |
+
case 0x13:
|
| 732 |
+
mc = 1;
|
| 733 |
+
nc = 3;
|
| 734 |
+
gemm<1, 3>(m0, m, n0, n);
|
| 735 |
+
break;
|
| 736 |
+
case 0x21:
|
| 737 |
+
mc = 2;
|
| 738 |
+
nc = 1;
|
| 739 |
+
gemm<2, 1>(m0, m, n0, n);
|
| 740 |
+
break;
|
| 741 |
+
case 0x12:
|
| 742 |
+
mc = 1;
|
| 743 |
+
nc = 2;
|
| 744 |
+
gemm<1, 2>(m0, m, n0, n);
|
| 745 |
+
break;
|
| 746 |
+
case 0x11:
|
| 747 |
+
mc = 1;
|
| 748 |
+
nc = 1;
|
| 749 |
+
gemm<1, 1>(m0, m, n0, n);
|
| 750 |
+
break;
|
| 751 |
+
default:
|
| 752 |
+
return;
|
| 753 |
+
}
|
| 754 |
+
mp = m0 + (m - m0) / mc * mc;
|
| 755 |
+
np = n0 + (n - n0) / nc * nc;
|
| 756 |
+
mnpack(mp, m, n0, np);
|
| 757 |
+
mnpack(m0, m, np, n);
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
#if defined(__AVX2__) && defined(__F16C__)
|
| 761 |
+
// Templated functions for gemm of dimensions 4xN
|
| 762 |
+
template <int RN>
|
| 763 |
+
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 764 |
+
int64_t ytiles = (m - m0) / 4;
|
| 765 |
+
int64_t xtiles = (n - n0) / RN;
|
| 766 |
+
int64_t tiles = xtiles * ytiles;
|
| 767 |
+
int64_t duty = (tiles + nth - 1) / nth;
|
| 768 |
+
int64_t start = duty * ith;
|
| 769 |
+
int64_t end = start + duty;
|
| 770 |
+
if (end > tiles)
|
| 771 |
+
end = tiles;
|
| 772 |
+
for (int64_t job = start; job < end; ++job) {
|
| 773 |
+
int64_t ii = m0 + job / xtiles * 4;
|
| 774 |
+
int64_t jj = n0 + job % xtiles * RN;
|
| 775 |
+
__m256 Cv[RN][4] = {};
|
| 776 |
+
for (int64_t l = 0; l < k; ++l) {
|
| 777 |
+
uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
|
| 778 |
+
// Convert delta values for four blocks to float values
|
| 779 |
+
__m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
|
| 780 |
+
__m256i avec0 = load(A + lda * (ii + 0) + l);
|
| 781 |
+
__m256i avec1 = load(A + lda * (ii + 1) + l);
|
| 782 |
+
__m256i avec2 = load(A + lda * (ii + 2) + l);
|
| 783 |
+
__m256i avec3 = load(A + lda * (ii + 3) + l);
|
| 784 |
+
for (int64_t j = 0; j < RN; ++j) {
|
| 785 |
+
__m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
|
| 786 |
+
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
| 787 |
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
| 788 |
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
| 789 |
+
// Computation of dot product and multiplication with appropriate delta value products
|
| 790 |
+
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
| 791 |
+
updot(_mm256_sign_epi8(avec0, avec0),
|
| 792 |
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
|
| 793 |
+
Cv[j][0]);
|
| 794 |
+
Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
| 795 |
+
updot(_mm256_sign_epi8(avec1, avec1),
|
| 796 |
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
|
| 797 |
+
Cv[j][1]);
|
| 798 |
+
Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
| 799 |
+
updot(_mm256_sign_epi8(avec2, avec2),
|
| 800 |
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
|
| 801 |
+
Cv[j][2]);
|
| 802 |
+
Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
| 803 |
+
updot(_mm256_sign_epi8(avec3, avec3),
|
| 804 |
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
|
| 805 |
+
Cv[j][3]);
|
| 806 |
+
}
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 810 |
+
for (int64_t i = 0; i < 4; ++i)
|
| 811 |
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
| 812 |
+
}
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
// Templated functions for gemm of dimensions Mx4
|
| 816 |
+
template <int RM>
|
| 817 |
+
NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 818 |
+
int64_t ytiles = (m - m0) / RM;
|
| 819 |
+
int64_t xtiles = (n - n0) / 4;
|
| 820 |
+
int64_t tiles = xtiles * ytiles;
|
| 821 |
+
int64_t duty = (tiles + nth - 1) / nth;
|
| 822 |
+
int64_t start = duty * ith;
|
| 823 |
+
int64_t end = start + duty;
|
| 824 |
+
if (end > tiles)
|
| 825 |
+
end = tiles;
|
| 826 |
+
for (int64_t job = start; job < end; ++job) {
|
| 827 |
+
int64_t ii = m0 + job / xtiles * RM;
|
| 828 |
+
int64_t jj = n0 + job % xtiles * 4;
|
| 829 |
+
__m256 Cv[4][RM] = {};
|
| 830 |
+
for (int64_t l = 0; l < k; ++l) {
|
| 831 |
+
uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
|
| 832 |
+
// Convert delta values for four blocks to float values
|
| 833 |
+
__m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
|
| 834 |
+
__m256i bvec0 = load(B + ldb * (jj + 0) + l);
|
| 835 |
+
__m256i bvec1 = load(B + ldb * (jj + 1) + l);
|
| 836 |
+
__m256i bvec2 = load(B + ldb * (jj + 2) + l);
|
| 837 |
+
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
|
| 838 |
+
for (int64_t i = 0; i < RM; ++i) {
|
| 839 |
+
__m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
|
| 840 |
+
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
| 841 |
+
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
| 842 |
+
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
| 843 |
+
// Computation of dot product and multiplication with appropriate delta value products
|
| 844 |
+
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
| 845 |
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
| 846 |
+
load(A + lda * (ii + i) + l)),
|
| 847 |
+
_mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
|
| 848 |
+
Cv[0][i]);
|
| 849 |
+
Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
| 850 |
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
| 851 |
+
load(A + lda * (ii + i) + l)),
|
| 852 |
+
_mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
|
| 853 |
+
Cv[1][i]);
|
| 854 |
+
Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
| 855 |
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
| 856 |
+
load(A + lda * (ii + i) + l)),
|
| 857 |
+
_mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
|
| 858 |
+
Cv[2][i]);
|
| 859 |
+
Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
| 860 |
+
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
| 861 |
+
load(A + lda * (ii + i) + l)),
|
| 862 |
+
_mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
|
| 863 |
+
Cv[3][i]);
|
| 864 |
+
}
|
| 865 |
+
}
|
| 866 |
+
for (int64_t j = 0; j < 4; ++j)
|
| 867 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 868 |
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
| 869 |
+
}
|
| 870 |
+
}
|
| 871 |
+
#endif
|
| 872 |
+
|
| 873 |
+
template <int RM, int RN>
|
| 874 |
+
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
| 875 |
+
int64_t ytiles = (m - m0) / RM;
|
| 876 |
+
int64_t xtiles = (n - n0) / RN;
|
| 877 |
+
int64_t tiles = xtiles * ytiles;
|
| 878 |
+
int64_t duty = (tiles + nth - 1) / nth;
|
| 879 |
+
int64_t start = duty * ith;
|
| 880 |
+
int64_t end = start + duty;
|
| 881 |
+
if (end > tiles)
|
| 882 |
+
end = tiles;
|
| 883 |
+
for (int64_t job = start; job < end; ++job) {
|
| 884 |
+
int64_t ii = m0 + job / xtiles * RM;
|
| 885 |
+
int64_t jj = n0 + job % xtiles * RN;
|
| 886 |
+
__m256 Cv[RN][RM] = {};
|
| 887 |
+
for (int64_t l = 0; l < k; ++l)
|
| 888 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 889 |
+
for (int64_t i = 0; i < RM; ++i) {
|
| 890 |
+
#if defined(__AVX2__)
|
| 891 |
+
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
| 892 |
+
load(A + lda * (ii + i) + l)),
|
| 893 |
+
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
|
| 894 |
+
load(A + lda * (ii + i) + l)));
|
| 895 |
+
#else
|
| 896 |
+
__m128i ali0 = load0(A + lda * (ii + i) + l);
|
| 897 |
+
__m128i ali1 = load1(A + lda * (ii + i) + l);
|
| 898 |
+
__m128i blj0 = load0(B + ldb * (jj + j) + l);
|
| 899 |
+
__m128i blj1 = load1(B + ldb * (jj + j) + l);
|
| 900 |
+
|
| 901 |
+
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
|
| 902 |
+
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
|
| 903 |
+
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
|
| 904 |
+
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
|
| 905 |
+
|
| 906 |
+
// updot
|
| 907 |
+
const __m128i oneFill = _mm_set1_epi16(1);
|
| 908 |
+
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
|
| 909 |
+
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
|
| 910 |
+
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
|
| 911 |
+
#endif
|
| 912 |
+
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
|
| 913 |
+
unhalf(B[ldb * (jj + j) + l].d)),
|
| 914 |
+
udTmp,
|
| 915 |
+
Cv[j][i]);
|
| 916 |
+
}
|
| 917 |
+
for (int64_t j = 0; j < RN; ++j)
|
| 918 |
+
for (int64_t i = 0; i < RM; ++i)
|
| 919 |
+
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
| 920 |
+
}
|
| 921 |
+
}
|
| 922 |
+
|
| 923 |
+
inline __m256i load(const block_q8_0 *b) {
|
| 924 |
+
return _mm256_loadu_si256((const __m256i *)b->qs);
|
| 925 |
+
}
|
| 926 |
+
|
| 927 |
+
inline __m128i load0(const block_q8_0 *b) {
|
| 928 |
+
return _mm_loadu_si128((const __m128i *)b->qs);
|
| 929 |
+
}
|
| 930 |
+
|
| 931 |
+
inline __m128i load1(const block_q8_0 *b) {
|
| 932 |
+
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
inline __m256i load(const block_q4_0 *b) {
|
| 936 |
+
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
|
| 937 |
+
}
|
| 938 |
+
|
| 939 |
+
inline __m128i load0(const block_q4_0 *b) {
|
| 940 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 941 |
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
|
| 942 |
+
}
|
| 943 |
+
|
| 944 |
+
inline __m128i load1(const block_q4_0 *b) {
|
| 945 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 946 |
+
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
|
| 947 |
+
}
|
| 948 |
+
|
| 949 |
+
inline __m256i load(const block_q5_0 *b) {
|
| 950 |
+
return _mm256_or_si256(denibble(b->qs), bittobyte(b->qh));
|
| 951 |
+
}
|
| 952 |
+
|
| 953 |
+
inline __m128i load0(const block_q5_0* b) {
|
| 954 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 955 |
+
uint32_t x32;
|
| 956 |
+
memcpy(&x32, b->qh, sizeof(uint32_t));
|
| 957 |
+
__m128i qxl = _mm_and_si128(_mm_set1_epi8(15), x);
|
| 958 |
+
__m128i bytesl = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
|
| 959 |
+
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
|
| 960 |
+
_mm_shuffle_epi8(_mm_set1_epi32(x32),
|
| 961 |
+
_mm_set_epi64x(0x0101010101010101, 0x0000000000000000))));
|
| 962 |
+
bytesl = _mm_andnot_si128(bytesl, _mm_set1_epi8((char)0xF0));
|
| 963 |
+
return _mm_or_si128(qxl, bytesl);
|
| 964 |
+
}
|
| 965 |
+
|
| 966 |
+
inline __m128i load1(const block_q5_0* b) {
|
| 967 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 968 |
+
uint32_t x32;
|
| 969 |
+
memcpy(&x32, b->qh, sizeof(uint32_t));
|
| 970 |
+
__m128i qxh = _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4));
|
| 971 |
+
__m128i bytesh = _mm_cmpeq_epi8(_mm_set1_epi64x(-1),
|
| 972 |
+
_mm_or_si128(_mm_set1_epi64x(0x7fbfdfeff7fbfdfe),
|
| 973 |
+
_mm_shuffle_epi8(_mm_set1_epi32(x32),
|
| 974 |
+
_mm_set_epi64x(0x0303030303030303, 0x0202020202020202))));
|
| 975 |
+
bytesh = _mm_andnot_si128(bytesh, _mm_set1_epi8((char)0xF0));
|
| 976 |
+
return _mm_or_si128(qxh, bytesh);
|
| 977 |
+
}
|
| 978 |
+
|
| 979 |
+
inline __m256i load(const block_iq4_nl *b) {
|
| 980 |
+
return MM256_SET_M128I(load1(b), load0(b));
|
| 981 |
+
}
|
| 982 |
+
|
| 983 |
+
inline __m128i load0(const block_iq4_nl *b) {
|
| 984 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 985 |
+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), x));
|
| 986 |
+
}
|
| 987 |
+
|
| 988 |
+
inline __m128i load1(const block_iq4_nl *b) {
|
| 989 |
+
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
|
| 990 |
+
return _mm_shuffle_epi8(iq4nlt, _mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)));
|
| 991 |
+
}
|
| 992 |
+
|
| 993 |
+
inline __m256 updot(__m256i u, __m256i s) {
|
| 994 |
+
__m256i res;
|
| 995 |
+
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
| 996 |
+
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
|
| 997 |
+
#else
|
| 998 |
+
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
| 999 |
+
#endif
|
| 1000 |
+
return _mm256_cvtepi32_ps(res);
|
| 1001 |
+
}
|
| 1002 |
+
|
| 1003 |
+
static inline __m256i denibble(const uint8_t *p) {
|
| 1004 |
+
__m128i x = _mm_loadu_si128((const __m128i *)p);
|
| 1005 |
+
return _mm256_and_si256(_mm256_set1_epi8(15),
|
| 1006 |
+
_mm256_insertf128_si256(_mm256_castsi128_si256(x),
|
| 1007 |
+
_mm_srli_epi16(x, 4), 1));
|
| 1008 |
+
}
|
| 1009 |
+
|
| 1010 |
+
static inline __m256i bittobyte(const uint8_t *p) {
|
| 1011 |
+
uint32_t x32;
|
| 1012 |
+
memcpy(&x32, p, sizeof(uint32_t));
|
| 1013 |
+
__m256i bytes = _mm256_cmpeq_epi8(_mm256_set1_epi64x(-1),
|
| 1014 |
+
_mm256_or_si256(_mm256_set1_epi64x(0x7fbfdfeff7fbfdfe),
|
| 1015 |
+
_mm256_shuffle_epi8(_mm256_set1_epi32(x32),
|
| 1016 |
+
_mm256_set_epi64x(0x0303030303030303, 0x0202020202020202,
|
| 1017 |
+
0x0101010101010101, 0x0000000000000000))));
|
| 1018 |
+
return _mm256_andnot_si256(bytes, _mm256_set1_epi8((char)0xF0));
|
| 1019 |
+
}
|
| 1020 |
+
|
| 1021 |
+
const TA *const A;
|
| 1022 |
+
const TB *const B;
|
| 1023 |
+
TC *const C;
|
| 1024 |
+
const int64_t k;
|
| 1025 |
+
const int64_t lda;
|
| 1026 |
+
const int64_t ldb;
|
| 1027 |
+
const int64_t ldc;
|
| 1028 |
+
const int ith;
|
| 1029 |
+
const int nth;
|
| 1030 |
+
};
|
| 1031 |
+
#endif // __AVX__
|
| 1032 |
+
|
| 1033 |
+
//PPC Implementation
|
| 1034 |
+
#if defined(__MMA__)
|
| 1035 |
+
|
| 1036 |
+
#define SAVE_ACC(ACC, ii, jj) \
|
| 1037 |
+
__builtin_mma_disassemble_acc(vec_C, ACC); \
|
| 1038 |
+
for (int I = 0; I < 4; I++) { \
|
| 1039 |
+
for (int J = 0; J < 4; J++) { \
|
| 1040 |
+
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
|
| 1041 |
+
} \
|
| 1042 |
+
} \
|
| 1043 |
+
|
| 1044 |
+
template <typename TA, typename TB, typename TC>
|
| 1045 |
+
class tinyBLAS_PPC {
|
| 1046 |
+
public:
|
| 1047 |
+
tinyBLAS_PPC(int64_t k,
|
| 1048 |
+
const TA *A, int64_t lda,
|
| 1049 |
+
const TB *B, int64_t ldb,
|
| 1050 |
+
TC *C, int64_t ldc,
|
| 1051 |
+
int ith, int nth)
|
| 1052 |
+
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
| 1053 |
+
}
|
| 1054 |
+
|
| 1055 |
+
void matmul(int64_t m, int64_t n) {
|
| 1056 |
+
mnpack(0, m, 0, n);
|
| 1057 |
+
}
|
| 1058 |
+
|
| 1059 |
+
private:
|
| 1060 |
+
|
| 1061 |
+
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
|
| 1062 |
+
|
| 1063 |
+
void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
|
| 1064 |
+
int64_t i, j;
|
| 1065 |
+
float *aoffset = NULL, *boffset = NULL;
|
| 1066 |
+
float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
| 1067 |
+
float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
| 1068 |
+
|
| 1069 |
+
aoffset = const_cast<float*>(a);
|
| 1070 |
+
boffset = vec;
|
| 1071 |
+
j = (rows >> 3);
|
| 1072 |
+
if (j > 0) {
|
| 1073 |
+
do {
|
| 1074 |
+
aoffset1 = aoffset;
|
| 1075 |
+
aoffset2 = aoffset1 + lda;
|
| 1076 |
+
aoffset3 = aoffset2 + lda;
|
| 1077 |
+
aoffset4 = aoffset3 + lda;
|
| 1078 |
+
aoffset5 = aoffset4 + lda;
|
| 1079 |
+
aoffset6 = aoffset5 + lda;
|
| 1080 |
+
aoffset7 = aoffset6 + lda;
|
| 1081 |
+
aoffset8 = aoffset7 + lda;
|
| 1082 |
+
aoffset += 8 * lda;
|
| 1083 |
+
i = (cols >> 3);
|
| 1084 |
+
if (i > 0) {
|
| 1085 |
+
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
| 1086 |
+
vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
|
| 1087 |
+
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
| 1088 |
+
do {
|
| 1089 |
+
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
| 1090 |
+
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
| 1091 |
+
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
| 1092 |
+
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
| 1093 |
+
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
|
| 1094 |
+
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
|
| 1095 |
+
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
|
| 1096 |
+
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
|
| 1097 |
+
__builtin_vsx_disassemble_pair(c1, &C1);
|
| 1098 |
+
__builtin_vsx_disassemble_pair(c2, &C2);
|
| 1099 |
+
__builtin_vsx_disassemble_pair(c3, &C3);
|
| 1100 |
+
__builtin_vsx_disassemble_pair(c4, &C4);
|
| 1101 |
+
__builtin_vsx_disassemble_pair(c5, &C5);
|
| 1102 |
+
__builtin_vsx_disassemble_pair(c6, &C6);
|
| 1103 |
+
__builtin_vsx_disassemble_pair(c7, &C7);
|
| 1104 |
+
__builtin_vsx_disassemble_pair(c8, &C8);
|
| 1105 |
+
|
| 1106 |
+
t1 = vec_mergeh(c1[0], c2[0]);
|
| 1107 |
+
t2 = vec_mergeh(c3[0], c4[0]);
|
| 1108 |
+
t3 = vec_mergeh(c5[0], c6[0]);
|
| 1109 |
+
t4 = vec_mergeh(c7[0], c8[0]);
|
| 1110 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1111 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1112 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1113 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1114 |
+
vec_xst(t5, 0, boffset);
|
| 1115 |
+
vec_xst(t6, 0, boffset+4);
|
| 1116 |
+
vec_xst(t7, 0, boffset+8);
|
| 1117 |
+
vec_xst(t8, 0, boffset+12);
|
| 1118 |
+
|
| 1119 |
+
t1 = vec_mergel(c1[0], c2[0]);
|
| 1120 |
+
t2 = vec_mergel(c3[0], c4[0]);
|
| 1121 |
+
t3 = vec_mergel(c5[0], c6[0]);
|
| 1122 |
+
t4 = vec_mergel(c7[0], c8[0]);
|
| 1123 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1124 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1125 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1126 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1127 |
+
vec_xst(t5, 0, boffset+16);
|
| 1128 |
+
vec_xst(t6, 0, boffset+20);
|
| 1129 |
+
vec_xst(t7, 0, boffset+24);
|
| 1130 |
+
vec_xst(t8, 0, boffset+28);
|
| 1131 |
+
|
| 1132 |
+
t1 = vec_mergeh(c1[1], c2[1]);
|
| 1133 |
+
t2 = vec_mergeh(c3[1], c4[1]);
|
| 1134 |
+
t3 = vec_mergeh(c5[1], c6[1]);
|
| 1135 |
+
t4 = vec_mergeh(c7[1], c8[1]);
|
| 1136 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1137 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1138 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1139 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1140 |
+
vec_xst(t5, 0, boffset+32);
|
| 1141 |
+
vec_xst(t6, 0, boffset+36);
|
| 1142 |
+
vec_xst(t7, 0, boffset+40);
|
| 1143 |
+
vec_xst(t8, 0, boffset+44);
|
| 1144 |
+
|
| 1145 |
+
t1 = vec_mergel(c1[1], c2[1]);
|
| 1146 |
+
t2 = vec_mergel(c3[1], c4[1]);
|
| 1147 |
+
t3 = vec_mergel(c5[1], c6[1]);
|
| 1148 |
+
t4 = vec_mergel(c7[1], c8[1]);
|
| 1149 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1150 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1151 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1152 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1153 |
+
vec_xst(t5, 0, boffset+48);
|
| 1154 |
+
vec_xst(t6, 0, boffset+52);
|
| 1155 |
+
vec_xst(t7, 0, boffset+56);
|
| 1156 |
+
vec_xst(t8, 0, boffset+60);
|
| 1157 |
+
|
| 1158 |
+
aoffset1 += 8*lda;
|
| 1159 |
+
aoffset2 += 8*lda;
|
| 1160 |
+
aoffset3 += 8*lda;
|
| 1161 |
+
aoffset4 += 8*lda;
|
| 1162 |
+
boffset += 64;
|
| 1163 |
+
i--;
|
| 1164 |
+
} while(i > 0);
|
| 1165 |
+
}
|
| 1166 |
+
if (cols & 4) {
|
| 1167 |
+
vector float c1, c2, c3, c4, c5, c6, c7, c8;
|
| 1168 |
+
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
| 1169 |
+
c1 = vec_xl(0, aoffset1);
|
| 1170 |
+
c2 = vec_xl(0, aoffset2);
|
| 1171 |
+
c3 = vec_xl(0, aoffset3);
|
| 1172 |
+
c4 = vec_xl(0, aoffset4);
|
| 1173 |
+
c5 = vec_xl(0, aoffset5);
|
| 1174 |
+
c6 = vec_xl(0, aoffset6);
|
| 1175 |
+
c7 = vec_xl(0, aoffset7);
|
| 1176 |
+
c8 = vec_xl(0, aoffset8);
|
| 1177 |
+
|
| 1178 |
+
t1 = vec_mergeh(c1, c2);
|
| 1179 |
+
t2 = vec_mergeh(c3, c4);
|
| 1180 |
+
t3 = vec_mergeh(c5, c6);
|
| 1181 |
+
t4 = vec_mergeh(c7, c8);
|
| 1182 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1183 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1184 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1185 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1186 |
+
vec_xst(t5, 0, boffset);
|
| 1187 |
+
vec_xst(t6, 0, boffset+4);
|
| 1188 |
+
vec_xst(t7, 0, boffset+8);
|
| 1189 |
+
vec_xst(t8, 0, boffset+12);
|
| 1190 |
+
|
| 1191 |
+
t1 = vec_mergel(c1, c2);
|
| 1192 |
+
t2 = vec_mergel(c3, c4);
|
| 1193 |
+
t3 = vec_mergel(c5, c6);
|
| 1194 |
+
t4 = vec_mergel(c7, c8);
|
| 1195 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1196 |
+
t6 = vec_xxpermdi(t3, t4, 0);
|
| 1197 |
+
t7 = vec_xxpermdi(t1, t2, 3);
|
| 1198 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1199 |
+
vec_xst(t5, 0, boffset+16);
|
| 1200 |
+
vec_xst(t6, 0, boffset+20);
|
| 1201 |
+
vec_xst(t7, 0, boffset+24);
|
| 1202 |
+
vec_xst(t8, 0, boffset+28);
|
| 1203 |
+
}
|
| 1204 |
+
j--;
|
| 1205 |
+
} while(j > 0);
|
| 1206 |
+
}
|
| 1207 |
+
|
| 1208 |
+
if (rows & 4) {
|
| 1209 |
+
aoffset1 = aoffset;
|
| 1210 |
+
aoffset2 = aoffset1 + lda;
|
| 1211 |
+
aoffset3 = aoffset2 + lda;
|
| 1212 |
+
aoffset4 = aoffset3 + lda;
|
| 1213 |
+
aoffset += 4 * lda;
|
| 1214 |
+
i = (cols >> 3);
|
| 1215 |
+
if (i > 0) {
|
| 1216 |
+
__vector_pair C1, C2, C3, C4;
|
| 1217 |
+
vector float c1[2], c2[2], c3[2], c4[2];
|
| 1218 |
+
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
| 1219 |
+
do {
|
| 1220 |
+
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
| 1221 |
+
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
| 1222 |
+
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
| 1223 |
+
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
| 1224 |
+
__builtin_vsx_disassemble_pair(c1, &C1);
|
| 1225 |
+
__builtin_vsx_disassemble_pair(c2, &C2);
|
| 1226 |
+
__builtin_vsx_disassemble_pair(c3, &C3);
|
| 1227 |
+
__builtin_vsx_disassemble_pair(c4, &C4);
|
| 1228 |
+
|
| 1229 |
+
t1 = vec_mergeh(c1[0], c2[0]);
|
| 1230 |
+
t2 = vec_mergeh(c3[0], c4[0]);
|
| 1231 |
+
t3 = vec_mergel(c1[0], c2[0]);
|
| 1232 |
+
t4 = vec_mergel(c3[0], c4[0]);
|
| 1233 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1234 |
+
t6 = vec_xxpermdi(t1, t2, 3);
|
| 1235 |
+
t7 = vec_xxpermdi(t3, t4, 0);
|
| 1236 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1237 |
+
vec_xst(t5, 0, boffset);
|
| 1238 |
+
vec_xst(t6, 0, boffset+4);
|
| 1239 |
+
vec_xst(t7, 0, boffset+8);
|
| 1240 |
+
vec_xst(t8, 0, boffset+12);
|
| 1241 |
+
|
| 1242 |
+
t1 = vec_mergeh(c1[1], c2[1]);
|
| 1243 |
+
t2 = vec_mergeh(c3[1], c4[1]);
|
| 1244 |
+
t3 = vec_mergel(c1[1], c2[1]);
|
| 1245 |
+
t4 = vec_mergel(c3[1], c4[1]);
|
| 1246 |
+
t5 = vec_xxpermdi(t1, t2, 0);
|
| 1247 |
+
t6 = vec_xxpermdi(t1, t2, 3);
|
| 1248 |
+
t7 = vec_xxpermdi(t3, t4, 0);
|
| 1249 |
+
t8 = vec_xxpermdi(t3, t4, 3);
|
| 1250 |
+
vec_xst(t5, 0, boffset+16);
|
| 1251 |
+
vec_xst(t6, 0, boffset+20);
|
| 1252 |
+
vec_xst(t7, 0, boffset+24);
|
| 1253 |
+
vec_xst(t8, 0, boffset+28);
|
| 1254 |
+
|
| 1255 |
+
aoffset1 += 8*lda;
|
| 1256 |
+
aoffset2 += 8*lda;
|
| 1257 |
+
aoffset3 += 8*lda;
|
| 1258 |
+
aoffset4 += 8*lda;
|
| 1259 |
+
boffset += 32;
|
| 1260 |
+
i--;
|
| 1261 |
+
} while(i > 0);
|
| 1262 |
+
}
|
| 1263 |
+
|
| 1264 |
+
if (cols & 4) {
|
| 1265 |
+
vector float c1, c2, c3, c4;
|
| 1266 |
+
vector float t1, t2, t3, t4;
|
| 1267 |
+
c1 = vec_xl(0, aoffset1);
|
| 1268 |
+
c2 = vec_xl(0, aoffset2);
|
| 1269 |
+
c3 = vec_xl(0, aoffset3);
|
| 1270 |
+
c4 = vec_xl(0, aoffset4);
|
| 1271 |
+
|
| 1272 |
+
t1 = vec_mergeh(c1, c2);
|
| 1273 |
+
t2 = vec_mergeh(c3, c4);
|
| 1274 |
+
t3 = vec_xxpermdi(t1, t2, 0);
|
| 1275 |
+
t4 = vec_xxpermdi(t1, t2, 3);
|
| 1276 |
+
vec_xst(t3, 0, boffset);
|
| 1277 |
+
vec_xst(t4, 0, boffset+4);
|
| 1278 |
+
|
| 1279 |
+
t1 = vec_mergel(c1, c2);
|
| 1280 |
+
t2 = vec_mergel(c3, c4);
|
| 1281 |
+
t3 = vec_xxpermdi(t1, t2, 0);
|
| 1282 |
+
t4 = vec_xxpermdi(t1, t2, 3);
|
| 1283 |
+
vec_xst(t3, 0, boffset+8);
|
| 1284 |
+
vec_xst(t4, 0, boffset+12);
|
| 1285 |
+
}
|
| 1286 |
+
}
|
| 1287 |
+
if (rows & 3) {
|
| 1288 |
+
aoffset1 = aoffset;
|
| 1289 |
+
aoffset2 = aoffset1 + lda;
|
| 1290 |
+
aoffset3 = aoffset2 + lda;
|
| 1291 |
+
if (cols & 4) {
|
| 1292 |
+
vector float c1, c2, c3, c4 = {0};
|
| 1293 |
+
vector float t1, t2, t3, t4;
|
| 1294 |
+
c1 = vec_xl(0, aoffset1);
|
| 1295 |
+
c2 = vec_xl(0, aoffset2);
|
| 1296 |
+
c3 = vec_xl(0, aoffset3);
|
| 1297 |
+
|
| 1298 |
+
t1 = vec_mergeh(c1, c2);
|
| 1299 |
+
t2 = vec_mergeh(c3, c4);
|
| 1300 |
+
t3 = vec_xxpermdi(t1, t2, 0);
|
| 1301 |
+
t4 = vec_xxpermdi(t1, t2, 3);
|
| 1302 |
+
vec_xst(t3, 0, boffset);
|
| 1303 |
+
vec_xst(t4, 0, boffset+4);
|
| 1304 |
+
|
| 1305 |
+
t1 = vec_mergel(c1, c2);
|
| 1306 |
+
t2 = vec_mergel(c3, c4);
|
| 1307 |
+
t3 = vec_xxpermdi(t1, t2, 0);
|
| 1308 |
+
t4 = vec_xxpermdi(t1, t2, 3);
|
| 1309 |
+
vec_xst(t3, 0, boffset+8);
|
| 1310 |
+
vec_xst(t4, 0, boffset+12);
|
| 1311 |
+
}
|
| 1312 |
+
}
|
| 1313 |
+
}
|
| 1314 |
+
|
| 1315 |
+
void KERNEL_4x4(int64_t ii, int64_t jj) {
|
| 1316 |
+
vec_t vec_A[4], vec_B[4], vec_C[4];
|
| 1317 |
+
acc_t acc_0;
|
| 1318 |
+
__builtin_mma_xxsetaccz(&acc_0);
|
| 1319 |
+
for (int l = 0; l < k; l+=4) {
|
| 1320 |
+
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
| 1321 |
+
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
| 1322 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
| 1323 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
| 1324 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
| 1325 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
| 1326 |
+
}
|
| 1327 |
+
SAVE_ACC(&acc_0, ii, jj);
|
| 1328 |
+
}
|
| 1329 |
+
|
| 1330 |
+
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
| 1331 |
+
vec_t vec_A[4], vec_B[8], vec_C[4];
|
| 1332 |
+
acc_t acc_0, acc_1;
|
| 1333 |
+
__builtin_mma_xxsetaccz(&acc_0);
|
| 1334 |
+
__builtin_mma_xxsetaccz(&acc_1);
|
| 1335 |
+
for (int64_t l = 0; l < k; l+=4) {
|
| 1336 |
+
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
| 1337 |
+
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
|
| 1338 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
|
| 1339 |
+
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
|
| 1340 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
|
| 1341 |
+
__builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
|
| 1342 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
|
| 1343 |
+
__builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
|
| 1344 |
+
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
|
| 1345 |
+
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
|
| 1346 |
+
}
|
| 1347 |
+
SAVE_ACC(&acc_0, ii, jj);
|
| 1348 |
+
SAVE_ACC(&acc_1, ii, jj+4);
|
| 1349 |
+
}
|
| 1350 |
+
|
| 1351 |
+
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
| 1352 |
+
vec_t vec_A[8], vec_B[4], vec_C[4];
|
| 1353 |
+
acc_t acc_0, acc_1;
|
| 1354 |
+
__builtin_mma_xxsetaccz(&acc_0);
|
| 1355 |
+
__builtin_mma_xxsetaccz(&acc_1);
|
| 1356 |
+
for (int64_t l = 0; l < k; l+=4) {
|
| 1357 |
+
READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
|
| 1358 |
+
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
| 1359 |
+
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
|
| 1360 |
+
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
|
| 1361 |
+
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
|
| 1362 |
+
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
|
| 1363 |
+
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
|
| 1364 |
+
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
|
| 1365 |
+
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
|
| 1366 |
+
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
|
| 1367 |
+
}
|
| 1368 |
+
SAVE_ACC(&acc_0, ii, jj);
|
| 1369 |
+
SAVE_ACC(&acc_1, ii+4, jj);
|
| 1370 |
+
}
|
| 1371 |
+
|
| 1372 |
+
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
| 1373 |
+
vec_t vec_A[16], vec_B[16], vec_C[4];
|
| 1374 |
+
acc_t acc_0, acc_1, acc_2, acc_3;
|
| 1375 |
+
__builtin_mma_xxsetaccz(&acc_0);
|
| 1376 |
+
__builtin_mma_xxsetaccz(&acc_1);
|
| 1377 |
+
__builtin_mma_xxsetaccz(&acc_2);
+        __builtin_mma_xxsetaccz(&acc_3);
+        for (int l = 0; l < k; l+=8) {
+            READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
+            READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
+            for(int x = 0; x < 16; x+=2) {
+                __builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
+                __builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
+                __builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
+            }
+        }
+        SAVE_ACC(&acc_0, ii, jj);
+        SAVE_ACC(&acc_1, ii, jj+4);
+        SAVE_ACC(&acc_2, ii+4, jj);
+        SAVE_ACC(&acc_3, ii+4, jj+4);
+    }
+
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        int m_rem = MIN(m - m0, 16);
+        int n_rem = MIN(n - n0, 16);
+        if (m_rem >= 16 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if(m_rem >= 8 && n_rem >= 16) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 8) {
+            mc = 8;
+            nc = 8;
+            gemm<8,8>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 8) {
+            mc = 4;
+            nc = 8;
+            gemm<4,8>(m0, m, n0, n);
+        } else if (m_rem >= 8 && n_rem >= 4) {
+            mc = 8;
+            nc = 4;
+            gemm<8,4>(m0, m, n0, n);
+        } else if (m_rem >= 4 && n_rem >= 4) {
+            mc = 4;
+            nc = 4;
+            gemm<4,4>(m0, m, n0, n);
+        } else if ((m_rem < 4) && (n_rem > 4)) {
+            nc = 4;
+            switch(m_rem) {
+                case 1:
+                    mc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    mc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    mc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else if ((m_rem > 4) && (n_rem < 4)) {
+            mc = 4;
+            switch(n_rem) {
+                case 1:
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 2:
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 3:
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        } else {
+            switch((m_rem << 4) | n_rem) {
+                case 0x43:
+                    mc = 4;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x42:
+                    mc = 4;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x41:
+                    mc = 4;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x34:
+                    mc = 3;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x33:
+                    mc = 3;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x32:
+                    mc = 3;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x31:
+                    mc = 3;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x24:
+                    mc = 2;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x23:
+                    mc = 2;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x22:
+                    mc = 2;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x21:
+                    mc = 2;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x14:
+                    mc = 1;
+                    nc = 4;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x13:
+                    mc = 1;
+                    nc = 3;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x12:
+                    mc = 1;
+                    nc = 2;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                case 0x11:
+                    mc = 1;
+                    nc = 1;
+                    gemm_small(m0, m, n0, n, mc, nc);
+                    break;
+                default:
+                    return;
+            }
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
+
+    void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            vec_t vec_C[4];
+            acc_t acc_0;
+            __builtin_mma_xxsetaccz(&acc_0);
+            vec_t vec_A[4], vec_B[4];
+            for (int l=0; l<k; l+=4) {
+                if (RN >= 4 && RM == 1) {
+                    float* a = const_cast<float*>(A+(ii)*lda+l);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
+                    vec_A[0] = (vec_t)vec_xl(0,a);
+                    vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
+                    vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
+                    vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
+                } else {
+                    READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
+                    READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
+                }
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
+                __builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
+            }
+            __builtin_mma_disassemble_acc(vec_C, &acc_0);
+            for (int I = 0; I < RM; I++) {
+                for (int J = 0; J < RN; J++) {
+                    *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
+                }
+            }
+        }
+    }
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (RM == 4 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x4;
+        } else if (RM == 4 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_4x8;
+        } else if (RM == 8 && RN == 4) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x4;
+        } else if (RM == 8 && RN == 8) {
+            kernel = &tinyBLAS_PPC::KERNEL_8x8;
+        }
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            (this->*kernel)(ii, jj);
+        }
+    }
+
+    const TA *const A;
+    const TB *const B;
+    TC *C;
+    TA *At;
+    TB *Bt;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif
+} // namespace
+
+/**
+ * Performs optimized matrix multiplication on CPU.
+ *
+ * This subroutine may compute C = Aᵀ * B with column major ordering.
+ * Despite its name, this isn't a generalized implementation. Work is
+ * only performed when a handwritten kernel is written and available.
+ * Otherwise the caller should fall back to a general matmul routine.
+ *
+ * For example, for single-threaded single-precision GEMM you can say
+ *
+ *     llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
+ *                     0, 1,
+ *                     GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
+ *
+ * @param m is rows in `A` and `C`
+ * @param n is cols in `B` and `C`
+ * @param k is cols in `A` and rows in `B`
+ * @param A is first input matrix (always transposed)
+ * @param lda is row stride of `A`
+ * @param B is second input matrix (never transposed)
+ * @param ldb is row stride of `B`
+ * @param C is input/output array of output matrices
+ * @param ldc is row stride of `C`
+ * @param ith is thread id (must be less than `nth`)
+ * @param nth is number of threads (must be greater than zero)
+ * @param Atype is GGML data type of `A`
+ * @param Btype is GGML data type of `B`
+ * @param Ctype is GGML data type of `C`
+ * @return true if this function was able to service the matmul request
+ */
+bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+
+    assert(m >= 0);
+    assert(n >= 0);
+    assert(k >= 0);
+    assert(lda >= k);
+    assert(ldb >= k);
+    assert(ldc >= m);
+    assert(nth > 0);
+    assert(ith < nth);
+
+    // only enable sgemm for prompt processing
+    if (n < 2)
+        return false;
+
+    if (Ctype != GGML_TYPE_F32)
+        return false;
+
+    switch (Atype) {
+
+    case GGML_TYPE_F32: {
+        if (Btype != GGML_TYPE_F32)
+            return false;
+#if defined(__AVX512F__)
+        if (k % 16)
+            return false;
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__AVX__) || defined(__AVX2__)
+        if (k % 8)
+            return false;
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_NEON)
+        if (n < 4)
+            return false;
+        if (k % 4)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__MMA__)
+        if (k % 8)
+            return false;
+        tinyBLAS_PPC<float, float, float> tb{
+            k, (const float *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_F16: {
+#if defined(__AVX512F__)
+        if (k % 16)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
+        if (n < 8)
+            return false;
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F16)
+            return false;
+        tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const ggml_fp16_t *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_NEON) && !defined(_MSC_VER)
+        if (k % 4)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_Q8_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+        tinyBLAS_Q0_ARM<block_q8_0> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_Q4_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#elif defined(__ARM_FEATURE_DOTPROD)
+        tinyBLAS_Q0_ARM<block_q4_0> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_Q5_0: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_q5_0, block_q8_0, float> tb{
+            k, (const block_q5_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    case GGML_TYPE_IQ4_NL: {
+        if (Btype != GGML_TYPE_Q8_0)
+            return false;
+#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
+        tinyBLAS_Q0_AVX<block_iq4_nl, block_q8_0, float> tb{
+            k, (const block_iq4_nl *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
+#else
+        return false;
+#endif
+    }
+
+    default:
+        return false;
+    }
+
+    (void)m;
+    (void)n;
+    (void)k;
+    (void)A;
+    (void)lda;
+    (void)B;
+    (void)ldb;
+    (void)C;
+    (void)ldc;
+    (void)ith;
+    (void)nth;
+    (void)Atype;
+    (void)Btype;
+    (void)Ctype;
+}
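For orientation, the mnpack routine above picks the largest hand-written kernel that fits the remaining rows and columns, covers the aligned region with it, and then recurses on the two remainder strips until the whole m x n output is tiled. The following is a minimal host-side C++ sketch of that recursion, not code from this patch; the simplified 8/4/1 tile policy and the name cover are illustrative stand-ins for mnpack plus its gemm/gemm_small calls.

#include <cstdint>
#include <cstdio>

// Illustrative only: shows how an mnpack-style recursion covers an m x n output.
// A tile size (mc, nc) covers rows [m0, mp) x cols [n0, np); the two remainder
// strips are then covered by recursing with the same policy.
static void cover(int64_t m0, int64_t m, int64_t n0, int64_t n) {
    if (m - m0 <= 0 || n - n0 <= 0) {
        return;
    }
    // simplified policy: largest of 8/4/1 that fits (the real code also has
    // 2x/3x kernels and a fused switch on the (m_rem, n_rem) pair)
    const int64_t mc = (m - m0 >= 8) ? 8 : (m - m0 >= 4) ? 4 : 1;
    const int64_t nc = (n - n0 >= 8) ? 8 : (n - n0 >= 4) ? 4 : 1;
    const int64_t mp = m0 + (m - m0) / mc * mc;   // rows covered at this tile size
    const int64_t np = n0 + (n - n0) / nc * nc;   // cols covered at this tile size
    printf("%lldx%lld tiles cover rows [%lld,%lld) x cols [%lld,%lld)\n",
           (long long)mc, (long long)nc,
           (long long)m0, (long long)mp, (long long)n0, (long long)np);
    cover(mp, m, n0, np);   // bottom remainder strip
    cover(m0, m, np, n);    // right remainder strip, full height
}

int main() {
    cover(0, 10, 0, 13);    // e.g. a 10x13 output matrix
    return 0;
}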
ggml/src/ggml-cpu/llamafile/sgemm.h
ADDED
@@ -0,0 +1,14 @@
+#pragma once
+#include <stdint.h>
+#include <stdbool.h>
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
+                     const void *, int64_t, void *, int64_t, int, int,
+                     int, int, int);
+
+#ifdef __cplusplus
+}
+#endif
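This header, together with the dispatcher at the end of sgemm.cpp, is the whole public surface of the llamafile kernels. A minimal sketch of a caller follows; it is not part of the patch and assumes ggml.h is available for the GGML_TYPE_F32 constant. The shapes are chosen so that k % 16 == 0, which satisfies the divisibility check on every SIMD path, and n >= 4 so the NEON path is not rejected.

#include <cstdio>
#include "ggml.h"    // for GGML_TYPE_F32
#include "sgemm.h"

int main() {
    const int64_t m = 4, n = 4, k = 16;
    // column-major per the doc comment: A holds m rows of k (lda = k),
    // B holds n cols of k (ldb = k), C is m x n with ldc = m
    static float A[m * k], B[k * n], C[m * n];
    for (int64_t i = 0; i < m * k; ++i) A[i] = 1.0f;
    for (int64_t i = 0; i < k * n; ++i) B[i] = 2.0f;
    // single thread: ith = 0, nth = 1; returns false when no kernel matches
    if (!llamafile_sgemm(m, n, k, A, /*lda=*/k, B, /*ldb=*/k, C, /*ldc=*/m,
                         0, 1, GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32)) {
        fprintf(stderr, "no sgemm kernel available, use the generic matmul instead\n");
        return 1;
    }
    printf("C[0][0] = %g\n", (double)C[0]);   // expect 32 = sum over k of 1*2
    return 0;
}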
ggml/src/ggml-cuda/common.cuh
CHANGED
@@ -6,7 +6,7 @@
 #include <cstdint>
 #include <memory>
 
-#if defined(GGML_USE_HIPBLAS)
+#if defined(GGML_USE_HIP)
 #define GGML_COMMON_DECL_HIP
 #define GGML_COMMON_IMPL_HIP
 #else
@@ -26,13 +26,13 @@
 #include <string>
 #include <vector>
 
-#if defined(GGML_USE_HIPBLAS)
+#if defined(GGML_USE_HIP)
 #include "vendors/hip.h"
 #elif defined(GGML_USE_MUSA)
 #include "vendors/musa.h"
 #else
 #include "vendors/cuda.h"
-#endif // defined(GGML_USE_HIPBLAS)
+#endif // defined(GGML_USE_HIP)
 
 #define STRINGIZE_IMPL(...) #__VA_ARGS__
 #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
@@ -97,7 +97,7 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
 
 #define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
 
-#if !defined(GGML_USE_HIPBLAS)
+#if !defined(GGML_USE_HIP)
 static const char * cu_get_error_str(CUresult err) {
     const char * err_str;
     cuGetErrorString(err, &err_str);
@@ -120,21 +120,21 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
-#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 #define FP16_AVAILABLE
-#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
 
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 #define INT8_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
 #define FLASH_ATTN_AVAILABLE
@@ -156,14 +156,14 @@ static constexpr bool int8_mma_available(const int cc) {
 static __device__ void no_device_code(
     const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
            file_name, line, function_name, arch);
     GGML_UNUSED(arch_list);
 #else
     printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
            file_name, line, function_name, arch, arch_list);
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     __trap();
 
     GGML_UNUSED(no_device_code); // suppress unused function warning
@@ -176,7 +176,7 @@ static __device__ void no_device_code(
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
@@ -184,7 +184,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
         x += __shfl_xor_sync(0xffffffff, x, mask, 32);
     }
     return x;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
 }
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
@@ -207,7 +207,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
 static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #ifdef FP16_AVAILABLE
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         const half2 a_other = __shfl_xor_sync(0xffffffff, a, mask, 32);
@@ -221,7 +221,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
         a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
     }
     return a;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #else
     NO_DEVICE_CODE;
@@ -240,11 +240,11 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
     return __float2half(fmaxf(__half2float(a), __half2float(b)));
 #else
     return __hmax(a, b);
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
 
 #else
     NO_DEVICE_CODE;
@@ -254,7 +254,7 @@ static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b
 }
 
 static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
 #if CUDART_VERSION >= CUDART_HMAX
     return __hmax2(a, b);
@@ -269,11 +269,11 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
         x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
@@ -282,7 +282,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #else
     GGML_UNUSED(x);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -294,7 +294,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
 #endif // CUDART_VERSION < CUDART_HMASK
 
 static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
     c = __builtin_amdgcn_sdot4(a, b, c, false);
 #elif defined(RDNA3)
@@ -320,7 +320,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 #endif
     return c;
 
-#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A
     return __dp4a(a, b, c);
@@ -330,7 +330,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 // TODO: move to ggml-common.h
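All of the warp_reduce_* helpers touched above share one idea: an xor butterfly in which each lane exchanges its value with the lane whose index differs in the current mask bit, halving the mask each step, so after log2(32) = 5 steps every lane holds the reduction over all 32 lanes. A plain C++ simulation of that access pattern follows; it is illustrative only, since on the device the exchange is performed by __shfl_xor_sync rather than through an array.

#include <cstdio>

// Host-side simulation of the 32-lane xor-shuffle butterfly used by
// warp_reduce_sum: every lanes[i] ends up holding the sum of all 32 lanes.
int main() {
    int lanes[32];
    for (int i = 0; i < 32; ++i) lanes[i] = i + 1;   // total = 528
    for (int mask = 16; mask > 0; mask >>= 1) {      // 5 steps, like the unrolled loop
        int next[32];
        for (int i = 0; i < 32; ++i) {
            next[i] = lanes[i] + lanes[i ^ mask];    // models __shfl_xor_sync(0xffffffff, x, mask, 32)
        }
        for (int i = 0; i < 32; ++i) lanes[i] = next[i];
    }
    printf("lane 0 = %d (expect 528)\n", lanes[0]);
    return 0;
}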
ggml/src/ggml-cuda/fattn-common.cuh
CHANGED
@@ -517,9 +517,9 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
 }
 
 template<int D, int parallel_blocks> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_combine_results(
     const float * __restrict__ VKQ_parts,
     const float2 * __restrict__ VKQ_meta,
ggml/src/ggml-cuda/fattn-tile-f16.cu
CHANGED
@@ -5,9 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F16 64
 
 template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-tile-f32.cu
CHANGED
@@ -5,9 +5,9 @@
 #define FATTN_KQ_STRIDE_TILE_F32 32
 
 template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_tile_ext_f32(
     const char * __restrict__ Q,
     const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-vec-f16.cuh
CHANGED
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-vec-f32.cuh
CHANGED
@@ -2,9 +2,9 @@
 #include "fattn-common.cuh"
 
 template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_vec_ext_f32(
     const char * __restrict__ Q,
     const char * __restrict__ K,
ggml/src/ggml-cuda/fattn-wmma-f16.cuh
CHANGED
@@ -7,9 +7,9 @@
 
 // D == head size, VKQ_stride == num VKQ rows calculated in parallel:
 template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(nwarps*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void flash_attn_ext_f16(
     const char * __restrict__ Q,
     const char * __restrict__ K,
ggml/src/ggml-cuda/ggml-cuda.cu
ADDED
The diff for this file is too large to render. See raw diff.
ggml/src/ggml-cuda/ggml/CMakeLists.txt
ADDED
@@ -0,0 +1,165 @@
+cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
+
+find_package(CUDAToolkit)
+
+if (CUDAToolkit_FOUND)
+    message(STATUS "CUDA Toolkit found")
+
+    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # 52 == lowest CUDA 12 standard
+        # 60 == FP16 CUDA intrinsics
+        # 61 == integer CUDA intrinsics
+        # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster
+        if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75")
+        else()
+            set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75")
+            #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work
+        endif()
+    endif()
+    message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
+
+    enable_language(CUDA)
+
+    file(GLOB   GGML_HEADERS_CUDA "*.cuh")
+    list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
+
+    file(GLOB   GGML_SOURCES_CUDA "*.cu")
+    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    file(GLOB   SRCS "template-instances/mmq*.cu")
+    list(APPEND GGML_SOURCES_CUDA ${SRCS})
+
+    if (GGML_CUDA_FA_ALL_QUANTS)
+        file(GLOB   SRCS "template-instances/fattn-vec*.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+    else()
+        file(GLOB   SRCS "template-instances/fattn-vec*q4_0-q4_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "template-instances/fattn-vec*q8_0-q8_0.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+        file(GLOB   SRCS "template-instances/fattn-vec*f16-f16.cu")
+        list(APPEND GGML_SOURCES_CUDA ${SRCS})
+    endif()
+
+    add_library(ggml-cuda
+                ${GGML_HEADERS_CUDA}
+                ${GGML_SOURCES_CUDA}
+               )
+
+    target_link_libraries(ggml-cuda PRIVATE ggml-base)
+    target_include_directories(ggml-cuda PRIVATE . ..)
+
+    # TODO: change the definitions to this target only
+
+    add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+    add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
+
+    if (GGML_CUDA_GRAPHS)
+        add_compile_definitions(GGML_CUDA_USE_GRAPHS)
+    endif()
+
+    if (GGML_CUDA_FORCE_DMMV)
+        add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+    endif()
+
+    if (GGML_CUDA_FORCE_MMQ)
+        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+    endif()
+
+    if (GGML_CUDA_FORCE_CUBLAS)
+        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        add_compile_definitions(GGML_CUDA_NO_VMM)
+    endif()
+
+    if (DEFINED GGML_CUDA_DMMV_Y)
+        add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility
+    endif()
+
+    if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16)
+        add_compile_definitions(GGML_CUDA_F16)
+    endif()
+
+    if (GGML_CUDA_NO_PEER_COPY)
+        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+    endif()
+
+    if (GGML_STATIC)
+        if (WIN32)
+            # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
+            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
+        else ()
+            target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
+        endif()
+    else()
+        target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt)
+    endif()
+
+    if (GGML_CUDA_NO_VMM)
+        # No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
+    else()
+        target_link_libraries(ggml-cuda PRIVATE CUDA::cuda_driver)
+    endif()
+
+    set(CUDA_CXX_FLAGS "")
+
+    set(CUDA_FLAGS -use_fast_math)
+
+    if (GGML_FATAL_WARNINGS)
+        list(APPEND CUDA_FLAGS -Werror all-warnings)
+    endif()
+
+    if (GGML_ALL_WARNINGS AND NOT MSVC)
+        set(NVCC_CMD ${CMAKE_CUDA_COMPILER} .c)
+        if (NOT CMAKE_CUDA_HOST_COMPILER STREQUAL "")
+            list(APPEND NVCC_CMD -ccbin ${CMAKE_CUDA_HOST_COMPILER})
+        endif()
+
+        execute_process(
+            COMMAND ${NVCC_CMD} -Xcompiler --version
+            OUTPUT_VARIABLE CUDA_CCFULLVER
+            ERROR_QUIET
+        )
+
+        if (NOT CUDA_CCFULLVER MATCHES clang)
+            set(CUDA_CCID "GNU")
+            execute_process(
+                COMMAND ${NVCC_CMD} -Xcompiler "-dumpfullversion -dumpversion"
+                OUTPUT_VARIABLE CUDA_CCVER
+                ERROR_QUIET
+            )
+        else()
+            if (CUDA_CCFULLVER MATCHES Apple)
+                set(CUDA_CCID "AppleClang")
+            else()
+                set(CUDA_CCID "Clang")
+            endif()
+            string(REGEX REPLACE "^.* version ([0-9.]*).*$" "\\1" CUDA_CCVER ${CUDA_CCFULLVER})
+        endif()
+
+        message("-- CUDA host compiler is ${CUDA_CCID} ${CUDA_CCVER}")
+
+        get_flags(${CUDA_CCID} ${CUDA_CCVER})
+        list(APPEND CUDA_CXX_FLAGS ${CXX_FLAGS} ${GF_CXX_FLAGS}) # This is passed to -Xcompiler later
+    endif()
+
+    if (NOT MSVC)
+        list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
+    endif()
+
+    list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED)  # pass host compiler flags as a single argument
+
+    if (NOT CUDA_CXX_FLAGS_JOINED STREQUAL "")
+        list(APPEND CUDA_FLAGS -Xcompiler ${CUDA_CXX_FLAGS_JOINED})
+    endif()
+
+    add_compile_options("$<$<COMPILE_LANGUAGE:CUDA>:${CUDA_FLAGS}>")
+else()
+    message(FATAL_ERROR "CUDA Toolkit not found")
+endif()
ggml/src/ggml-cuda/mmq.cuh
CHANGED
@@ -100,9 +100,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 128;
 #else // INT8_MMA_AVAILABLE
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
     return 128;
-#else // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
 #if __CUDA_ARCH__ >= CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
@@ -115,7 +115,7 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #endif // INT8_MMA_AVAILABLE
 }
 
@@ -124,7 +124,7 @@ static constexpr int get_mmq_y_host(const int cc) {
 }
 
 static constexpr __device__ int get_mmq_y_device() {
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA1)
     return 64;
 #else
@@ -136,7 +136,7 @@ static constexpr __device__ int get_mmq_y_device() {
 #else
     return 64;
 #endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
 #define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
@@ -2569,7 +2569,7 @@ static __device__ void mul_mat_q_process_tile(
 // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #if defined(RDNA3) || defined(RDNA2)
 __launch_bounds__(WARP_SIZE*nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2)
@@ -2579,7 +2579,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #else
 __launch_bounds__(WARP_SIZE*nwarps, 2)
 #endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
     const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
@@ -2594,7 +2594,7 @@ static __global__ void mul_mat_q(
     constexpr int mmq_y = get_mmq_y_device();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
     {
        constexpr bool fixup = false;
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
@@ -2602,7 +2602,7 @@ static __global__ void mul_mat_q(
            blockIdx.x, blockIdx.y, 0, ne00/qk);
        return;
    }
-#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
 
     const int64_t blocks_per_ne00 = ne00 / qk;
     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
@@ -2765,14 +2765,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 
     const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
        shmem_limit_raised[id] = true;
    }
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
     const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
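On the stream-k comment in the hunk above: rather than assigning each thread block one output tile, stream-k flattens all tiles' inner k-iterations into one range and hands each block a contiguous slice, so a block may start or stop mid-tile; partial results go to tmp_fixup and are combined in a later pass. A toy C++ sketch of that partitioning follows, with hypothetical sizes; the real kernel's bookkeeping is considerably more involved.

#include <cstdio>

// Toy stream-k partitioning: total work = ntiles * iters_per_tile inner
// iterations, split evenly over nblocks blocks. A block whose slice starts
// or ends mid-tile produces a partial result that needs a fixup pass.
int main() {
    const int ntiles = 7, iters_per_tile = 10, nblocks = 4;
    const int total = ntiles * iters_per_tile;
    for (int b = 0; b < nblocks; ++b) {
        const int begin = (total *  b)      / nblocks;
        const int end   = (total * (b + 1)) / nblocks;
        printf("block %d: iters [%2d,%2d) -> tiles %d..%d%s\n",
               b, begin, end, begin / iters_per_tile, (end - 1) / iters_per_tile,
               (begin % iters_per_tile || end % iters_per_tile) ? " (partial, needs fixup)" : "");
    }
    return 0;
}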
ggml/src/ggml-cuda/mmvq.cu
CHANGED
@@ -48,10 +48,10 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
 }
 
 template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 // tell the compiler to use as many registers as it wants, see nwarps definition below
 __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 static __global__ void mul_mat_vec_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
 
     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
 
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
     constexpr int nwarps              = 1;
     constexpr int rows_per_cuda_block = 1;
 #else
     constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
     constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
 
     const int tid  = WARP_SIZE*threadIdx.y + threadIdx.x;
     const int row0 = rows_per_cuda_block*blockIdx.x;
ggml/src/ggml-cuda/sum.cu
CHANGED
@@ -1,6 +1,6 @@
-#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
 #define USE_CUB
-#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
+#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
 
 #ifdef USE_CUB
 // On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
ggml/src/ggml-hip/CMakeLists.txt
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if (NOT EXISTS $ENV{ROCM_PATH})
|
| 2 |
+
if (NOT EXISTS /opt/rocm)
|
| 3 |
+
set(ROCM_PATH /usr)
|
| 4 |
+
else()
|
| 5 |
+
set(ROCM_PATH /opt/rocm)
|
| 6 |
+
endif()
|
| 7 |
+
else()
|
| 8 |
+
set(ROCM_PATH $ENV{ROCM_PATH})
|
| 9 |
+
endif()
|
| 10 |
+
|
| 11 |
+
list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
|
| 12 |
+
list(APPEND CMAKE_PREFIX_PATH "${ROCM_PATH}/lib64/cmake")
|
| 13 |
+
|
| 14 |
+
# CMake on Windows doesn't support the HIP language yet
|
| 15 |
+
if (WIN32)
|
| 16 |
+
set(CXX_IS_HIPCC TRUE)
|
| 17 |
+
else()
|
| 18 |
+
string(REGEX MATCH "hipcc(\.bat)?$" CXX_IS_HIPCC "${CMAKE_CXX_COMPILER}")
|
| 19 |
+
endif()
|
| 20 |
+
|
| 21 |
+
if (CXX_IS_HIPCC)
|
| 22 |
+
if (LINUX)
|
| 23 |
+
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
|
| 24 |
+
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
|
| 25 |
+
endif()
|
| 26 |
+
|
| 27 |
+
+        message(WARNING "Setting hipcc as the C++ compiler is legacy behavior."
+                        " Prefer setting the HIP compiler directly. See README for details.")
+    endif()
+else()
+    # Forward AMDGPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
+    if (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+        set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
+    endif()
+    cmake_minimum_required(VERSION 3.21)
+    enable_language(HIP)
+endif()
+
+find_package(hip     REQUIRED)
+find_package(hipblas REQUIRED)
+find_package(rocblas REQUIRED)
+
+message(STATUS "HIP and hipBLAS found")
+
+file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
+list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
+
+file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
+file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
+file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
+
+if (GGML_CUDA_FA_ALL_QUANTS)
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
+    list(APPEND GGML_SOURCES_ROCM ${SRCS})
+    add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
+else()
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
+    list(APPEND GGML_SOURCES_ROCM ${SRCS})
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
+    list(APPEND GGML_SOURCES_ROCM ${SRCS})
+    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
+    list(APPEND GGML_SOURCES_ROCM ${SRCS})
+endif()
+
+add_library(ggml-hip
+            ${GGML_HEADERS_ROCM}
+            ${GGML_SOURCES_ROCM})
+
+target_link_libraries(ggml-hip PRIVATE ggml-base)
+target_include_directories(ggml-hip PRIVATE . ..)
+
+# TODO: do not use CUDA definitions for HIP
+target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
+
+add_compile_definitions(GGML_USE_HIP)
+add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
+add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
+add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
+
+if (GGML_HIP_UMA)
+    add_compile_definitions(GGML_HIP_UMA)
+endif()
+
+if (GGML_CUDA_FORCE_DMMV)
+    add_compile_definitions(GGML_CUDA_FORCE_DMMV)
+endif()
+
+if (GGML_CUDA_FORCE_MMQ)
+    add_compile_definitions(GGML_CUDA_FORCE_MMQ)
+endif()
+
+if (GGML_CUDA_FORCE_CUBLAS)
+    add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
+endif()
+
+if (GGML_CUDA_NO_PEER_COPY)
+    add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
+endif()
+
+if (CXX_IS_HIPCC)
+    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
+    target_link_libraries(ggml-hip PRIVATE hip::device)
+else()
+    set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE HIP)
+endif()
+
+if (GGML_STATIC)
+    message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+endif()
+
+target_link_libraries(ggml-hip PRIVATE ggml-base hip::host roc::rocblas roc::hipblas)
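Note that the HIP backend is built from the CUDA sources: the globs above pull in ../ggml-cuda/*.cu, and the target still defines GGML_USE_CUDA alongside the new GGML_USE_HIP. As a hedged sketch (illustrative only, not actual ggml source), a source file shared between the two backends can select the right runtime using the definitions set by this script:

    /* Illustrative guard only -- not taken from ggml. Selects the runtime
       header based on the compile definitions added by the CMake above;
       HIP mirrors most of the CUDA runtime API, so shared .cu sources
       can compile under both. */
    #if defined(GGML_USE_HIP)
    #include <hip/hip_runtime.h>
    #else
    #include <cuda_runtime.h>
    #endif
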
ggml/src/ggml-impl.h
CHANGED
@@ -3,13 +3,29 @@
 // GGML internal header
 
 #include "ggml.h"
-
 #include <assert.h>
+#include <math.h>
 #include <stdlib.h> // load `stdlib.h` before other headers to work around MinGW bug: https://sourceforge.net/p/mingw-w64/bugs/192/
 #include <stdbool.h>
 #include <stdint.h>
 #include <string.h>
 
+#ifdef __ARM_FEATURE_SVE
+#include <arm_sve.h>
+#endif // __ARM_FEATURE_SVE
+
+#if defined(__ARM_NEON)
+// if YCM cannot find <arm_neon.h>, make a symbolic link to it, for example:
+//
+//   $ ln -sfn /Library/Developer/CommandLineTools/usr/lib/clang/13.1.6/include/arm_neon.h ./src/
+//
+#include <arm_neon.h>
+#endif
+
+#if defined(__F16C__)
+#include <immintrin.h>
+#endif
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -28,13 +44,13 @@ extern "C" {
 // if C99 - static_assert is noop
 // ref: https://stackoverflow.com/a/53923785/4039976
 #ifndef __cplusplus
-#ifndef static_assert
-#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
-#define static_assert(cond, msg) _Static_assert(cond, msg)
-#else
-#define static_assert(cond, msg) struct global_scope_noop_trick
-#endif
-#endif
+    #ifndef static_assert
+        #if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+            #define static_assert(cond, msg) _Static_assert(cond, msg)
+        #else
+            #define static_assert(cond, msg) struct global_scope_noop_trick
+        #endif
+    #endif
 #endif
 
 static inline int ggml_up32(int n) {
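This shim is only meaningful from C11 onward, where static_assert maps to the real _Static_assert; on older C compilers every assertion degrades to the same harmless struct declaration, so the condition is silently unchecked. A minimal sketch (the assertion itself is a hypothetical example):

    /* C11 and later: checked at compile time via _Static_assert.
       Pre-C11:       expands to `struct global_scope_noop_trick;` (a no-op). */
    static_assert(sizeof(uint16_t) == 2, "ggml_fp16_t must be 2 bytes");
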
@@ -120,14 +136,12 @@ struct ggml_map_custom1_op_params {
     void * userdata;
 };
 
-
 struct ggml_map_custom2_op_params {
     ggml_custom2_op_t fun;
     int n_tasks;
     void * userdata;
 };
 
-
 struct ggml_map_custom3_op_params {
     ggml_custom3_op_t fun;
     int n_tasks;
@@ -287,9 +301,249 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
 void * ggml_aligned_malloc(size_t size);
 void ggml_aligned_free(void * ptr, size_t size);
 
-//
-
-
+// FP16 to FP32 conversion
+
+#if defined(__ARM_NEON)
+#ifdef _MSC_VER
+typedef uint16_t ggml_fp16_internal_t;
+#else
+typedef __fp16 ggml_fp16_internal_t;
+#endif
+#endif
+
+#if defined(__ARM_NEON) && !defined(_MSC_VER)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#define GGML_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    ggml_fp16_internal_t tmp;
+    memcpy(&tmp, &h, sizeof(ggml_fp16_t));
+    return (float)tmp;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    ggml_fp16_t res;
+    ggml_fp16_internal_t tmp = f;
+    memcpy(&res, &tmp, sizeof(ggml_fp16_t));
+    return res;
+}
+
+#elif defined(__F16C__)
+
+#ifdef _MSC_VER
+#define GGML_COMPUTE_FP16_TO_FP32(x) _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(x)))
+#define GGML_COMPUTE_FP32_TO_FP16(x) _mm_extract_epi16(_mm_cvtps_ph(_mm_set_ss(x), 0), 0)
+#else
+#define GGML_COMPUTE_FP16_TO_FP32(x) _cvtsh_ss(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) _cvtss_sh(x, 0)
+#endif
+
+#elif defined(__POWER9_VECTOR__)
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+/* the inline asm below is about 12% faster than the lookup method */
+#define GGML_FP16_TO_FP32(x) GGML_COMPUTE_FP16_TO_FP32(x)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    register float f;
+    register double d;
+    __asm__(
+        "mtfprd %0,%2\n"
+        "xscvhpdp %0,%0\n"
+        "frsp %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=f"(f):
+        /* in */   "r"(h));
+    return f;
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+    register double d;
+    register ggml_fp16_t r;
+    __asm__( /* xscvdphp can work on double or single precision */
+        "xscvdphp %0,%2\n"
+        "mffprd %1,%0\n" :
+        /* temp */ "=d"(d),
+        /* out */  "=r"(r):
+        /* in */   "f"(f));
+    return r;
+}
+
+#else
+
+// FP16 <-> FP32
+// ref: https://github.com/Maratyszcza/FP16
+
+static inline float fp32_from_bits(uint32_t w) {
+    union {
+        uint32_t as_bits;
+        float as_value;
+    } fp32;
+    fp32.as_bits = w;
+    return fp32.as_value;
+}
+
+static inline uint32_t fp32_to_bits(float f) {
+    union {
+        float as_value;
+        uint32_t as_bits;
+    } fp32;
+    fp32.as_value = f;
+    return fp32.as_bits;
+}
+
+static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
+    const uint32_t w = (uint32_t) h << 16;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    const uint32_t two_w = w + w;
+
+    const uint32_t exp_offset = UINT32_C(0xE0) << 23;
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float exp_scale = 0x1.0p-112f;
+#else
+    const float exp_scale = fp32_from_bits(UINT32_C(0x7800000));
+#endif
+    const float normalized_value = fp32_from_bits((two_w >> 4) + exp_offset) * exp_scale;
+
+    const uint32_t magic_mask = UINT32_C(126) << 23;
+    const float magic_bias = 0.5f;
+    const float denormalized_value = fp32_from_bits((two_w >> 17) | magic_mask) - magic_bias;
+
+    const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
+    const uint32_t result = sign |
+        (two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
+    return fp32_from_bits(result);
+}
+
+static inline ggml_fp16_t ggml_compute_fp32_to_fp16(float f) {
+#if (defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) || defined(__GNUC__) && !defined(__STRICT_ANSI__)) && (!defined(__cplusplus) || __cplusplus >= 201703L)
+    const float scale_to_inf = 0x1.0p+112f;
+    const float scale_to_zero = 0x1.0p-110f;
+#else
+    const float scale_to_inf = fp32_from_bits(UINT32_C(0x77800000));
+    const float scale_to_zero = fp32_from_bits(UINT32_C(0x08800000));
+#endif
+    float base = (fabsf(f) * scale_to_inf) * scale_to_zero;
+
+    const uint32_t w = fp32_to_bits(f);
+    const uint32_t shl1_w = w + w;
+    const uint32_t sign = w & UINT32_C(0x80000000);
+    uint32_t bias = shl1_w & UINT32_C(0xFF000000);
+    if (bias < UINT32_C(0x71000000)) {
+        bias = UINT32_C(0x71000000);
+    }
+
+    base = fp32_from_bits((bias >> 1) + UINT32_C(0x07800000)) + base;
+    const uint32_t bits = fp32_to_bits(base);
+    const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00);
+    const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF);
+    const uint32_t nonsign = exp_bits + mantissa_bits;
+    return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign);
+}
+
+#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
+#define GGML_COMPUTE_FP32_TO_FP16(x) ggml_compute_fp32_to_fp16(x)
+
+#endif // defined(__ARM_NEON) && (!defined(__MSC_VER)
+
+// precomputed f32 table for f16 (256 KB)
+// defined in ggml.c, initialized in ggml_init()
+GGML_API float ggml_table_f32_f16[1 << 16];
+
+// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
+// so we define GGML_FP16_TO_FP32 and GGML_FP32_TO_FP16 elsewhere for NEON.
+// This is also true for POWER9.
+#if !defined(GGML_FP16_TO_FP32)
+inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
+    uint16_t s;
+    memcpy(&s, &f, sizeof(uint16_t));
+    return ggml_table_f32_f16[s];
+}
+
+#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
+#endif
+
+#if !defined(GGML_FP32_TO_FP16)
+#define GGML_FP32_TO_FP16(x) GGML_COMPUTE_FP32_TO_FP16(x)
+#endif
+
+/**
+ * Converts brain16 to float32.
+ *
+ * The bfloat16 floating point format has the following structure:
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───┐
+ *     0b0000000000000000 brain16
+ *
+ * Since bf16 has the same number of exponent bits as a 32bit float,
+ * encoding and decoding numbers becomes relatively straightforward.
+ *
+ *       ┌sign
+ *       │
+ *       │   ┌exponent
+ *       │   │
+ *       │   │      ┌mantissa
+ *       │   │      │
+ *       │┌──┴───┐┌─┴───────────────────┐
+ *     0b00000000000000000000000000000000 IEEE binary32
+ *
+ * For comparison, the standard fp16 format has fewer exponent bits.
+ *
+ *       ┌sign
+ *       │
+ *       │  ┌exponent
+ *       │  │
+ *       │  │    ┌mantissa
+ *       │  │    │
+ *       │┌─┴─┐┌─┴──────┐
+ *     0b0000000000000000 IEEE binary16
+ *
+ * @see IEEE 754-2008
+ */
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
+/**
+ * Converts float32 to brain16.
+ *
+ * This is binary identical with Google Brain float conversion.
+ * Floats shall round to nearest even, and NANs shall be quiet.
+ * Subnormals aren't flushed to zero, except perhaps when used.
+ * This code should vectorize nicely if using modern compilers.
+ */
+static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) {
+    ggml_bf16_t h;
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.f = s;
+    if ((u.i & 0x7fffffff) > 0x7f800000) { /* nan */
+        h.bits = (u.i >> 16) | 64; /* force to quiet */
+        return h;
+    }
+    h.bits = (u.i + (0x7fff + ((u.i >> 16) & 1))) >> 16;
+    return h;
+}
+
+#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
+#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
 
 #ifdef __cplusplus
 }
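The bf16 helpers added above exploit the fact that bfloat16 is simply the top 16 bits of an IEEE binary32 value. A small self-contained sketch of the same round trip (it mirrors ggml_compute_fp32_to_bf16 / ggml_compute_bf16_to_fp32 but uses a local stand-in type and memcpy instead of ggml_bf16_t and the union, so it is not ggml code itself):

    #include <stdint.h>
    #include <stdio.h>
    #include <string.h>

    /* Local stand-in for ggml_bf16_t. */
    typedef struct { uint16_t bits; } bf16_t;

    static bf16_t fp32_to_bf16(float s) {
        uint32_t i;
        bf16_t h;
        memcpy(&i, &s, sizeof(i));
        if ((i & 0x7fffffff) > 0x7f800000) {   /* NaN: force quiet */
            h.bits = (uint16_t)((i >> 16) | 64);
            return h;
        }
        /* round to nearest even before truncating the low 16 bits */
        h.bits = (uint16_t)((i + (0x7fff + ((i >> 16) & 1))) >> 16);
        return h;
    }

    static float bf16_to_fp32(bf16_t h) {
        uint32_t i = (uint32_t)h.bits << 16;
        float f;
        memcpy(&f, &i, sizeof(f));
        return f;
    }

    int main(void) {
        float x = 3.14159265f;
        bf16_t h = fp32_to_bf16(x);
        /* prints: 3.141593 -> 0x4049 -> 3.140625 (bf16 keeps 7 explicit mantissa bits) */
        printf("%f -> 0x%04x -> %f\n", x, h.bits, bf16_to_fp32(h));
        return 0;
    }
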
ggml/src/ggml-kompute/CMakeLists.txt
ADDED
@@ -0,0 +1,162 @@
+
+find_package(Vulkan COMPONENTS glslc REQUIRED)
+find_program(glslc_executable NAMES glslc HINTS Vulkan::glslc)
+
+if (NOT glslc_executable)
+    message(FATAL_ERROR "glslc not found")
+endif()
+
+add_library(ggml-kompute
+            ggml-kompute.cpp
+            ../../include/ggml-kompute.h
+           )
+
+target_link_libraries(ggml-kompute PRIVATE ggml-base kompute)
+target_include_directories(ggml-kompute PRIVATE . .. ${CMAKE_CURRENT_BINARY_DIR})
+
+add_compile_definitions(VULKAN_HPP_DISPATCH_LOADER_DYNAMIC=1)
+
+function(compile_shader)
+    set(options)
+    set(oneValueArgs)
+    set(multiValueArgs SOURCES)
+    cmake_parse_arguments(compile_shader "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    foreach(source ${compile_shader_SOURCES})
+        get_filename_component(filename ${source} NAME)
+        set(spv_file ${filename}.spv)
+        add_custom_command(
+            OUTPUT ${spv_file}
+            DEPENDS ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/common.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_getrows.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n_pre.comp
+                ${CMAKE_CURRENT_SOURCE_DIR}/kompute-shaders/op_mul_mv_q_n.comp
+            COMMAND ${glslc_executable} --target-env=vulkan1.2 -o ${spv_file} ${CMAKE_CURRENT_SOURCE_DIR}/${source}
+            COMMENT "Compiling ${source} to ${spv_file}"
+            )
+
+        get_filename_component(RAW_FILE_NAME ${spv_file} NAME)
+        set(FILE_NAME "shader${RAW_FILE_NAME}")
+        string(REPLACE ".comp.spv" ".h" HEADER_FILE ${FILE_NAME})
+        string(TOUPPER ${HEADER_FILE} HEADER_FILE_DEFINE)
+        string(REPLACE "." "_" HEADER_FILE_DEFINE "${HEADER_FILE_DEFINE}")
+        set(OUTPUT_HEADER_FILE "${HEADER_FILE}")
+        message(STATUS "${HEADER_FILE} generating ${HEADER_FILE_DEFINE}")
+        if(CMAKE_GENERATOR MATCHES "Visual Studio")
+            add_custom_command(
+                OUTPUT ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                DEPENDS ${spv_file} xxd
+                COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/$<CONFIG>/xxd"
+                )
+        else()
+            add_custom_command(
+                OUTPUT ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "/*THIS FILE HAS BEEN AUTOMATICALLY GENERATED - DO NOT EDIT*/" > ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#ifndef ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "namespace kp {" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "namespace shader_data {" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_BINARY_DIR}/bin/xxd -i ${RAW_FILE_NAME} >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo "}}" >> ${OUTPUT_HEADER_FILE}
+                COMMAND ${CMAKE_COMMAND} -E echo \"\#endif // define ${HEADER_FILE_DEFINE}\" >> ${OUTPUT_HEADER_FILE}
+                DEPENDS ${spv_file} xxd
+                COMMENT "Converting to hpp: ${FILE_NAME} ${CMAKE_BINARY_DIR}/bin/xxd"
+                )
+        endif()
+    endforeach()
+endfunction()
+
+if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
+    message(STATUS "Kompute found")
+    set(KOMPUTE_OPT_LOG_LEVEL Error CACHE STRING "Kompute log level")
+    add_subdirectory(kompute)
+
+    # Compile our shaders
+    compile_shader(SOURCES
+        kompute-shaders/op_scale.comp
+        kompute-shaders/op_scale_8.comp
+        kompute-shaders/op_add.comp
+        kompute-shaders/op_addrow.comp
+        kompute-shaders/op_mul.comp
+        kompute-shaders/op_silu.comp
+        kompute-shaders/op_relu.comp
+        kompute-shaders/op_gelu.comp
+        kompute-shaders/op_softmax.comp
+        kompute-shaders/op_norm.comp
+        kompute-shaders/op_rmsnorm.comp
+        kompute-shaders/op_diagmask.comp
+        kompute-shaders/op_mul_mat_mat_f32.comp
+        kompute-shaders/op_mul_mat_f16.comp
+        kompute-shaders/op_mul_mat_q8_0.comp
+        kompute-shaders/op_mul_mat_q4_0.comp
+        kompute-shaders/op_mul_mat_q4_1.comp
+        kompute-shaders/op_mul_mat_q4_k.comp
+        kompute-shaders/op_mul_mat_q6_k.comp
+        kompute-shaders/op_getrows_f32.comp
+        kompute-shaders/op_getrows_f16.comp
+        kompute-shaders/op_getrows_q4_0.comp
+        kompute-shaders/op_getrows_q4_1.comp
+        kompute-shaders/op_getrows_q6_k.comp
+        kompute-shaders/op_rope_f16.comp
+        kompute-shaders/op_rope_f32.comp
+        kompute-shaders/op_cpy_f16_f16.comp
+        kompute-shaders/op_cpy_f16_f32.comp
+        kompute-shaders/op_cpy_f32_f16.comp
+        kompute-shaders/op_cpy_f32_f32.comp
+    )
+
+    # Create a custom target for our generated shaders
+    add_custom_target(generated_shaders DEPENDS
+        shaderop_scale.h
+        shaderop_scale_8.h
+        shaderop_add.h
+        shaderop_addrow.h
+        shaderop_mul.h
+        shaderop_silu.h
+        shaderop_relu.h
+        shaderop_gelu.h
+        shaderop_softmax.h
+        shaderop_norm.h
+        shaderop_rmsnorm.h
+        shaderop_diagmask.h
+        shaderop_mul_mat_mat_f32.h
+        shaderop_mul_mat_f16.h
+        shaderop_mul_mat_q8_0.h
+        shaderop_mul_mat_q4_0.h
+        shaderop_mul_mat_q4_1.h
+        shaderop_mul_mat_q4_k.h
+        shaderop_mul_mat_q6_k.h
+        shaderop_getrows_f32.h
+        shaderop_getrows_f16.h
+        shaderop_getrows_q4_0.h
+        shaderop_getrows_q4_1.h
+        shaderop_getrows_q6_k.h
+        shaderop_rope_f16.h
+        shaderop_rope_f32.h
+        shaderop_cpy_f16_f16.h
+        shaderop_cpy_f16_f32.h
+        shaderop_cpy_f32_f16.h
+        shaderop_cpy_f32_f32.h
+    )
+
+    # Create a custom command that depends on the generated_shaders
+    add_custom_command(
+        OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+        COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp
+        DEPENDS generated_shaders
+        COMMENT "Ensuring shaders are generated before compiling ggml-kompute.cpp"
+        )
+
+    # Add the stamp to the main sources to ensure dependency tracking
+    target_sources(ggml-kompute PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/ggml-kompute.stamp)
+else()
+    message(WARNING "Kompute not found")
+endif()
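For reference, compile_shader first turns each .comp file into SPIR-V with glslc, then embeds the binary into a header: the echo commands wrap an xxd -i dump in the kp::shader_data namespaces. The generated header therefore has roughly the following shape (a sketch only: the array name is derived from the .spv filename by xxd, and the bytes and length shown are placeholders, not real shader contents -- the first four bytes are just the little-endian SPIR-V magic number 0x07230203):

    /* shaderop_scale.h -- sketch of the generated layout, not actual contents */
    #ifndef SHADEROP_SCALE_H
    #define SHADEROP_SCALE_H
    namespace kp {
    namespace shader_data {
    unsigned char op_scale_comp_spv[] = { 0x03, 0x02, 0x23, 0x07, /* ... */ };
    unsigned int  op_scale_comp_spv_len = 4;
    }}
    #endif // define SHADEROP_SCALE_H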